From f04b502ab616318934de6b3a0d95b65974538978 Mon Sep 17 00:00:00 2001
From: silveroxides
Date: Tue, 25 Mar 2025 21:38:11 +0100
Subject: [PATCH] get_modulations added from blepping and minor changes

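Replace the hard-coded distribute_modulations() lookup table with blepping's
get_modulations(), which computes each slice offset directly from the model
config (params.depth_single_blocks and params.depth) instead of assuming 38
single and 19 double blocks. The modulation tensor produced by the distilled
guidance layer is laid out along dim 1 as [single | double_img | double_txt |
final]: 3 vectors (shift, scale, gate) per single block, 6 vectors per double
block for each of the img and txt sides (two shift/scale/gate triples), and 2
trailing vectors for the final layer.

A rough sketch of the offset arithmetic for the default 38/19 config; the
mod_offset() helper below is illustrative only and not part of the patch:

    depth_single_blocks = 38  # single blocks, 3 modulation vectors each
    depth = 19                # double blocks, 6 vectors each for img and txt

    def mod_offset(block_type: str, idx: int) -> int:
        # Mirrors get_modulations(): [single | double_img | double_txt | final]
        if block_type == "single":
            return 3 * idx
        offset = 6 * idx + 3 * depth_single_blocks
        if block_type == "double_txt":
            offset += 6 * depth
        return offset

    total = 3 * depth_single_blocks + 2 * 6 * depth + 2  # 344 vectors
    assert mod_offset("single", 37) == 111      # last single block
    assert mod_offset("double_img", 0) == 114   # first double block, img side
    assert mod_offset("double_txt", 0) == 228   # first double block, txt side

Each double block call returns two ModulationOut slices (at offset and
offset + 3); "final" simply takes the last two vectors of the tensor.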
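model_base.py: Chroma.extra_conds() now resizes the attention mask to token
resolution before handing it to the model. Since the model pads the latent to
the patch size and then divides (a ceil division), the image part of the mask
has to cover ceil(H / patch_size) * ceil(W / patch_size) tokens. A quick
worked example, assuming a 128x128 latent and patch_size 2 (the numbers are
illustrative):

    import math

    latent_h, latent_w = 128, 128  # kwargs["noise"].shape[2:4]
    patch_size = 2                 # self.diffusion_model.patch_size
    h_tok = math.ceil(latent_h / patch_size)  # 64
    w_tok = math.ceil(latent_w / patch_size)  # 64
    # utils.upscale_dit_mask then rescales the image portion of the mask
    # from mask_ref_size to h_tok * w_tok = 4096 image tokens.

supported_models.py also gains a clip_target() so the T5-XXL text encoder is
detected from the checkpoint, memory_usage_factor drops from 2.8 to 1.8, and
the unused chroma_model_mode flag is removed from model_base.py.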
---
 comfy/ldm/chroma/model.py | 171 ++++++++++----------------------------
 comfy/model_base.py       |  11 ++-
 comfy/supported_models.py |  10 ++-
 3 files changed, 60 insertions(+), 132 deletions(-)

diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index b3b03dcd..a956ad05 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -39,6 +39,16 @@ class ChromaParams:
     n_layers: int
 
 
+class ChromaModulationOut(ModulationOut):
+    @classmethod
+    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
+        return cls(
+            shift=tensor[:, offset : offset + 1, :],
+            scale=tensor[:, offset + 1 : offset + 2, :],
+            gate=tensor[:, offset + 2 : offset + 3, :],
+        )
+
+
 class Chroma(nn.Module):
     """
     Transformer model for flow matching on sequences.
@@ -108,118 +118,34 @@ class Chroma(nn.Module):
         self.skip_mmdit = []
         self.skip_dit = []
         self.lite = False
 
-    @staticmethod
-    def distribute_modulations(tensor: torch.Tensor, single_block_count: int = 38, double_blocks_count: int = 19):
-        """
-        Distributes slices of the tensor into the block_dict as ModulationOut objects.
-        Args:
-            tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim].
-        """
-        batch_size, vectors, dim = tensor.shape
+    def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
+        # This function slices up the modulations tensor which has the following layout:
+        #   single     : num_single_blocks * 3 elements
+        #   double_img : num_double_blocks * 6 elements
+        #   double_txt : num_double_blocks * 6 elements
+        #   final      : 2 elements
+        if block_type == "final":
+            return (tensor[:, -2:-1, :], tensor[:, -1:, :])
+        single_block_count = self.params.depth_single_blocks
+        double_block_count = self.params.depth
+        offset = 3 * idx
+        if block_type == "single":
+            return ChromaModulationOut.from_offset(tensor, offset)
+        # Double block modulations are 6 elements so we double 3 * idx.
+        offset *= 2
+        if block_type in {"double_img", "double_txt"}:
+            # Advance past the single block modulations.
+            offset += 3 * single_block_count
+            if block_type == "double_txt":
+                # Advance past the double block img modulations.
+                offset += 6 * double_block_count
+            return (
+                ChromaModulationOut.from_offset(tensor, offset),
+                ChromaModulationOut.from_offset(tensor, offset + 3),
+            )
+        raise ValueError("Bad block_type")
 
-        block_dict = {}
-
-        # HARD CODED VALUES! lookup table for the generated vectors
-        # Add 38 single mod blocks
-        for i in range(single_block_count):
-            key = f"single_blocks.{i}.modulation.lin"
-            block_dict[key] = None
-
-        # Add 19 image double blocks
-        for i in range(double_blocks_count):
-            key = f"double_blocks.{i}.img_mod.lin"
-            block_dict[key] = None
-
-        # Add 19 text double blocks
-        for i in range(double_blocks_count):
-            key = f"double_blocks.{i}.txt_mod.lin"
-            block_dict[key] = None
-
-        # Add the final layer
-        block_dict["final_layer.adaLN_modulation.1"] = None
-        # # 6.2b version
-        # block_dict["lite_double_blocks.4.img_mod.lin"] = None
-        # block_dict["lite_double_blocks.4.txt_mod.lin"] = None
-
-
-        idx = 0 # Index to keep track of the vector slices
-
-        for key in block_dict.keys():
-            if "single_blocks" in key:
-                # Single block: 1 ModulationOut
-                block_dict[key] = ModulationOut(
-                    shift=tensor[:, idx:idx+1, :],
-                    scale=tensor[:, idx+1:idx+2, :],
-                    gate=tensor[:, idx+2:idx+3, :]
-                )
-                idx += 3 # Advance by 3 vectors
-
-            elif "img_mod" in key:
-                # Double block: List of 2 ModulationOut
-                double_block = []
-                for _ in range(2): # Create 2 ModulationOut objects
-                    double_block.append(
-                        ModulationOut(
-                            shift=tensor[:, idx:idx+1, :],
-                            scale=tensor[:, idx+1:idx+2, :],
-                            gate=tensor[:, idx+2:idx+3, :]
-                        )
-                    )
-                    idx += 3 # Advance by 3 vectors per ModulationOut
-                block_dict[key] = double_block
-
-            elif "txt_mod" in key:
-                # Double block: List of 2 ModulationOut
-                double_block = []
-                for _ in range(2): # Create 2 ModulationOut objects
-                    double_block.append(
-                        ModulationOut(
-                            shift=tensor[:, idx:idx+1, :],
-                            scale=tensor[:, idx+1:idx+2, :],
-                            gate=tensor[:, idx+2:idx+3, :]
-                        )
-                    )
-                    idx += 3 # Advance by 3 vectors per ModulationOut
-                block_dict[key] = double_block
-
-            elif "final_layer" in key:
-                # Final layer: 1 ModulationOut
-                block_dict[key] = [
-                    tensor[:, idx:idx+1, :],
-                    tensor[:, idx+1:idx+2, :],
-                ]
-                idx += 2 # Advance by 2 vectors
-
-            # elif "lite_double_blocks.4.img_mod" in key:
-            #     # Double block: List of 2 ModulationOut
-            #     double_block = []
-            #     for _ in range(2): # Create 2 ModulationOut objects
-            #         double_block.append(
-            #             ModulationOut(
-            #                 shift=tensor[:, idx:idx+1, :],
-            #                 scale=tensor[:, idx+1:idx+2, :],
-            #                 gate=tensor[:, idx+2:idx+3, :]
-            #             )
-            #         )
-            #         idx += 3 # Advance by 3 vectors per ModulationOut
-            #     block_dict[key] = double_block
-
-            # elif "lite_double_blocks.4.txt_mod" in key:
-            #     # Double block: List of 2 ModulationOut
-            #     double_block = []
-            #     for _ in range(2): # Create 2 ModulationOut objects
-            #         double_block.append(
-            #             ModulationOut(
-            #                 shift=tensor[:, idx:idx+1, :],
-            #                 scale=tensor[:, idx+1:idx+2, :],
-            #                 gate=tensor[:, idx+2:idx+3, :]
-            #             )
-            #         )
-            #         idx += 3 # Advance by 3 vectors per ModulationOut
-            #     block_dict[key] = double_block
-
-        return block_dict
 
     def forward_orig(
         self,
@@ -257,8 +183,6 @@ class Chroma(nn.Module):
 
         mod_vectors = self.distilled_guidance_layer(input_vec)
 
-        mod_vectors_dict = self.distribute_modulations(mod_vectors, 38, 19)
-
         txt = self.txt_in(txt)
 
         ids = torch.cat((txt_ids, img_ids), dim=1)
@@ -267,21 +191,10 @@ class Chroma(nn.Module):
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.double_blocks):
             if i not in self.skip_mmdit:
-                guidance_index = i
-                # if lite we change block 4 guidance with lite guidance
-                # and offset the guidance by 11 blocks after block 4
-                # if self.lite and i == 4:
-                #     img_mod = mod_vectors_dict[f"lite_double_blocks.4.img_mod.lin"]
-                #     txt_mod = mod_vectors_dict[f"lite_double_blocks.4.txt_mod.lin"]
-                # elif self.lite and i > 4:
-                #     guidance_index = i + 11
-                #     img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
-                #     txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
-                # else:
-                img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
-                txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
-                double_mod = [img_mod, txt_mod]
-
+                double_mod = (
+                    self.get_modulations(mod_vectors, "double_img", idx=i),
+                    self.get_modulations(mod_vectors, "double_txt", idx=i),
+                )
                 if ("double_block", i) in blocks_replace:
                     def block_wrap(args):
                         out = {}
@@ -318,7 +231,7 @@ class Chroma(nn.Module):
 
         for i, block in enumerate(self.single_blocks):
             if i not in self.skip_dit:
-                single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
+                single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                 if ("single_block", i) in blocks_replace:
                     def block_wrap(args):
                         out = {}
@@ -345,7 +258,7 @@ class Chroma(nn.Module):
             img[:, txt.shape[1] :, ...] += add
 
         img = img[:, txt.shape[1] :, ...]
-        final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
+        final_mod = self.get_modulations(mod_vectors, "final")
         img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
         return img
 
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 05e242b8..13349f71 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1049,8 +1049,6 @@ class Hunyuan3Dv2(BaseModel):
         return out
 
 class Chroma(BaseModel):
-    chroma_model_mode=False
-
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
 
@@ -1098,6 +1096,15 @@ class Chroma(BaseModel):
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         # upscale the attention mask, since now we
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            shape = kwargs["noise"].shape
+            mask_ref_size = kwargs["attention_mask_img_shape"]
+            # the model will pad to the patch size, and then divide
+            # essentially dividing and rounding up
+            (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
+            attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
+            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
         guidance = 0.0
         out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor((guidance,)))
         return out
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index da5f3abc..c113a023 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1025,14 +1025,22 @@
         "multiplier": 1.0,
         "shift": 1.0,
     }
+
     latent_format = comfy.latent_formats.Flux
-    memory_usage_factor = 2.8
+
+    memory_usage_factor = 1.8
+
     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device)
         return out
 
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.chroma.ChromaTokenizer, comfy.text_encoders.chroma.chroma_te(**t5_detect))
+
 models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma]
 
 models += [SVD_img2vid]