get_modulations added from blepping and minor changes

silveroxides 2025-03-25 21:38:11 +01:00
parent 9f70cfbc42
commit f04b502ab6
3 changed files with 60 additions and 132 deletions

View File

@@ -39,6 +39,16 @@ class ChromaParams:
n_layers: int
class ChromaModulationOut(ModulationOut):
@classmethod
def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
return cls(
shift=tensor[:, offset : offset + 1, :],
scale=tensor[:, offset + 1 : offset + 2, :],
gate=tensor[:, offset + 2 : offset + 3, :],
)
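# Each returned slice keeps the sequence dimension, so shift, scale and gate
# are all [batch_size, 1, dim] views taken at consecutive positions offset,
# offset + 1 and offset + 2 of the modulation tensor.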
class Chroma(nn.Module):
"""
Transformer model for flow matching on sequences.
@@ -108,118 +118,34 @@ class Chroma(nn.Module):
self.skip_mmdit = []
self.skip_dit = []
self.lite = False
@staticmethod
def distribute_modulations(tensor: torch.Tensor, single_block_count: int = 38, double_blocks_count: int = 19):
"""
Distributes slices of the tensor into the block_dict as ModulationOut objects.
Args:
tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim].
"""
batch_size, vectors, dim = tensor.shape
def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
# This function slices up the modulations tensor which has the following layout:
# single : num_single_blocks * 3 elements
# double_img : num_double_blocks * 6 elements
# double_txt : num_double_blocks * 6 elements
# final : 2 elements
if block_type == "final":
return (tensor[:, -2:-1, :], tensor[:, -1:, :])
single_block_count = self.params.depth_single_blocks
double_block_count = self.params.depth
offset = 3 * idx
if block_type == "single":
return ChromaModulationOut.from_offset(tensor, offset)
# Double block modulations are 6 elements so we double 3 * idx.
offset *= 2
if block_type in {"double_img", "double_txt"}:
# Advance past the single block modulations.
offset += 3 * single_block_count
if block_type == "double_txt":
# Advance past the double block img modulations.
offset += 6 * double_block_count
return (
ChromaModulationOut.from_offset(tensor, offset),
ChromaModulationOut.from_offset(tensor, offset + 3),
)
raise ValueError("Bad block_type")
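# Worked example of the layout above (illustration only, using the counts the
# old distribute_modulations hard-coded: 38 single blocks, 19 double blocks):
#   single block i      -> offset 3 * i                        (3 vectors)
#   double_img block i  -> offset 3 * 38 + 6 * i               (6 vectors)
#   double_txt block i  -> offset 3 * 38 + 6 * 19 + 6 * i      (6 vectors)
#   final               -> the last 2 vectors
# so the distilled guidance layer must emit 3 * 38 + 6 * 19 + 6 * 19 + 2 = 344
# modulation vectors in total.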
block_dict = {}
# HARD CODED VALUES! lookup table for the generated vectors
# Add 38 single mod blocks
for i in range(single_block_count):
key = f"single_blocks.{i}.modulation.lin"
block_dict[key] = None
# Add 19 image double blocks
for i in range(double_blocks_count):
key = f"double_blocks.{i}.img_mod.lin"
block_dict[key] = None
# Add 19 text double blocks
for i in range(double_blocks_count):
key = f"double_blocks.{i}.txt_mod.lin"
block_dict[key] = None
# Add the final layer
block_dict["final_layer.adaLN_modulation.1"] = None
# # 6.2b version
# block_dict["lite_double_blocks.4.img_mod.lin"] = None
# block_dict["lite_double_blocks.4.txt_mod.lin"] = None
idx = 0 # Index to keep track of the vector slices
for key in block_dict.keys():
if "single_blocks" in key:
# Single block: 1 ModulationOut
block_dict[key] = ModulationOut(
shift=tensor[:, idx:idx+1, :],
scale=tensor[:, idx+1:idx+2, :],
gate=tensor[:, idx+2:idx+3, :]
)
idx += 3 # Advance by 3 vectors
elif "img_mod" in key:
# Double block: List of 2 ModulationOut
double_block = []
for _ in range(2): # Create 2 ModulationOut objects
double_block.append(
ModulationOut(
shift=tensor[:, idx:idx+1, :],
scale=tensor[:, idx+1:idx+2, :],
gate=tensor[:, idx+2:idx+3, :]
)
)
idx += 3 # Advance by 3 vectors per ModulationOut
block_dict[key] = double_block
elif "txt_mod" in key:
# Double block: List of 2 ModulationOut
double_block = []
for _ in range(2): # Create 2 ModulationOut objects
double_block.append(
ModulationOut(
shift=tensor[:, idx:idx+1, :],
scale=tensor[:, idx+1:idx+2, :],
gate=tensor[:, idx+2:idx+3, :]
)
)
idx += 3 # Advance by 3 vectors per ModulationOut
block_dict[key] = double_block
elif "final_layer" in key:
# Final layer: 1 ModulationOut
block_dict[key] = [
tensor[:, idx:idx+1, :],
tensor[:, idx+1:idx+2, :],
]
idx += 2 # Advance by 2 vectors
# elif "lite_double_blocks.4.img_mod" in key:
# # Double block: List of 2 ModulationOut
# double_block = []
# for _ in range(2): # Create 2 ModulationOut objects
# double_block.append(
# ModulationOut(
# shift=tensor[:, idx:idx+1, :],
# scale=tensor[:, idx+1:idx+2, :],
# gate=tensor[:, idx+2:idx+3, :]
# )
# )
# idx += 3 # Advance by 3 vectors per ModulationOut
# block_dict[key] = double_block
# elif "lite_double_blocks.4.txt_mod" in key:
# # Double block: List of 2 ModulationOut
# double_block = []
# for _ in range(2): # Create 2 ModulationOut objects
# double_block.append(
# ModulationOut(
# shift=tensor[:, idx:idx+1, :],
# scale=tensor[:, idx+1:idx+2, :],
# gate=tensor[:, idx+2:idx+3, :]
# )
# )
# idx += 3 # Advance by 3 vectors per ModulationOut
# block_dict[key] = double_block
return block_dict
def forward_orig(
self,
@@ -257,8 +183,6 @@ class Chroma(nn.Module):
mod_vectors = self.distilled_guidance_layer(input_vec)
mod_vectors_dict = self.distribute_modulations(mod_vectors, 38, 19)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
@@ -267,21 +191,10 @@ class Chroma(nn.Module):
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if i not in self.skip_mmdit:
guidance_index = i
# if lite we change block 4 guidance with lite guidance
# and offset the guidance by 11 blocks after block 4
# if self.lite and i == 4:
# img_mod = mod_vectors_dict[f"lite_double_blocks.4.img_mod.lin"]
# txt_mod = mod_vectors_dict[f"lite_double_blocks.4.txt_mod.lin"]
# elif self.lite and i > 4:
# guidance_index = i + 11
# img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
# txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
# else:
img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
double_mod = [img_mod, txt_mod]
double_mod = (
self.get_modulations(mod_vectors, "double_img", idx=i),
self.get_modulations(mod_vectors, "double_txt", idx=i),
)
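# get_modulations returns a (mod1, mod2) pair per stream here, the same two
# ModulationOut objects the old distribute_modulations dict stored for each
# double block's img_mod/txt_mod entry.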
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -318,7 +231,7 @@ class Chroma(nn.Module):
for i, block in enumerate(self.single_blocks):
if i not in self.skip_dit:
single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -345,7 +258,7 @@ class Chroma(nn.Module):
img[:, txt.shape[1] :, ...] += add
img = img[:, txt.shape[1] :, ...]
final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
final_mod = self.get_modulations(mod_vectors, "final")
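# "final" is just a raw (shift, scale) pair of tensor slices with no gate,
# matching the 2 vectors the old lookup reserved for
# final_layer.adaLN_modulation.1.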
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
return img

View File

@@ -1049,8 +1049,6 @@ class Hunyuan3Dv2(BaseModel):
return out
class Chroma(BaseModel):
chroma_model_mode=False
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
@@ -1098,6 +1096,15 @@ class Chroma(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
# upscale the attention mask, since now we know the latent shape the model will actually see
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
shape = kwargs["noise"].shape
mask_ref_size = kwargs["attention_mask_img_shape"]
# the model will pad to the patch size, and then divide
# essentially dividing and rounding up
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
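# Size example (illustration only, assuming the usual Flux-style patch_size of 2):
# a 128x128 latent gives h_tok = w_tok = math.ceil(128 / 2) = 64, so the mask
# provided at mask_ref_size is upscaled to a 64x64 token grid before being
# passed to the model.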
guidance = 0.0
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor((guidance,)))
return out

View File

@@ -1025,14 +1025,22 @@ class Chroma(supported_models_base.BASE):
"multiplier": 1.0,
"shift": 1.0,
}
latent_format = comfy.latent_formats.Flux
memory_usage_factor = 2.8
memory_usage_factor = 1.8
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.chroma.ChromaTokenizer, comfy.text_encoders.chroma.chroma_te(**t5_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma]
models += [SVD_img2vid]