mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-16 16:43:36 +00:00
get_mdulations added from blepping and minor changes
This commit is contained in:
parent
9f70cfbc42
commit
f04b502ab6
@ -39,6 +39,16 @@ class ChromaParams:
|
||||
n_layers: int
|
||||
|
||||
|
||||
class ChromaModulationOut(ModulationOut):
|
||||
@classmethod
|
||||
def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
|
||||
return cls(
|
||||
shift=tensor[:, offset : offset + 1, :],
|
||||
scale=tensor[:, offset + 1 : offset + 2, :],
|
||||
gate=tensor[:, offset + 2 : offset + 3, :],
|
||||
)
|
||||
|
||||
|
||||
class Chroma(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
@ -108,118 +118,34 @@ class Chroma(nn.Module):
|
||||
self.skip_mmdit = []
|
||||
self.skip_dit = []
|
||||
self.lite = False
|
||||
@staticmethod
|
||||
def distribute_modulations(tensor: torch.Tensor, single_block_count: int = 38, double_blocks_count: int = 19):
|
||||
"""
|
||||
Distributes slices of the tensor into the block_dict as ModulationOut objects.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim].
|
||||
"""
|
||||
batch_size, vectors, dim = tensor.shape
|
||||
def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
|
||||
# This function slices up the modulations tensor which has the following layout:
|
||||
# single : num_single_blocks * 3 elements
|
||||
# double_img : num_double_blocks * 6 elements
|
||||
# double_txt : num_double_blocks * 6 elements
|
||||
# final : 2 elements
|
||||
if block_type == "final":
|
||||
return (tensor[:, -2:-1, :], tensor[:, -1:, :])
|
||||
single_block_count = self.params.depth_single_blocks
|
||||
double_block_count = self.params.depth
|
||||
offset = 3 * idx
|
||||
if block_type == "single":
|
||||
return ChromaModulationOut.from_offset(tensor, offset)
|
||||
# Double block modulations are 6 elements so we double 3 * idx.
|
||||
offset *= 2
|
||||
if block_type in {"double_img", "double_txt"}:
|
||||
# Advance past the single block modulations.
|
||||
offset += 3 * single_block_count
|
||||
if block_type == "double_txt":
|
||||
# Advance past the double block img modulations.
|
||||
offset += 6 * double_block_count
|
||||
return (
|
||||
ChromaModulationOut.from_offset(tensor, offset),
|
||||
ChromaModulationOut.from_offset(tensor, offset + 3),
|
||||
)
|
||||
raise ValueError("Bad block_type")
|
||||
|
||||
block_dict = {}
|
||||
|
||||
# HARD CODED VALUES! lookup table for the generated vectors
|
||||
# Add 38 single mod blocks
|
||||
for i in range(single_block_count):
|
||||
key = f"single_blocks.{i}.modulation.lin"
|
||||
block_dict[key] = None
|
||||
|
||||
# Add 19 image double blocks
|
||||
for i in range(double_blocks_count):
|
||||
key = f"double_blocks.{i}.img_mod.lin"
|
||||
block_dict[key] = None
|
||||
|
||||
# Add 19 text double blocks
|
||||
for i in range(double_blocks_count):
|
||||
key = f"double_blocks.{i}.txt_mod.lin"
|
||||
block_dict[key] = None
|
||||
|
||||
# Add the final layer
|
||||
block_dict["final_layer.adaLN_modulation.1"] = None
|
||||
# # 6.2b version
|
||||
# block_dict["lite_double_blocks.4.img_mod.lin"] = None
|
||||
# block_dict["lite_double_blocks.4.txt_mod.lin"] = None
|
||||
|
||||
|
||||
idx = 0 # Index to keep track of the vector slices
|
||||
|
||||
for key in block_dict.keys():
|
||||
if "single_blocks" in key:
|
||||
# Single block: 1 ModulationOut
|
||||
block_dict[key] = ModulationOut(
|
||||
shift=tensor[:, idx:idx+1, :],
|
||||
scale=tensor[:, idx+1:idx+2, :],
|
||||
gate=tensor[:, idx+2:idx+3, :]
|
||||
)
|
||||
idx += 3 # Advance by 3 vectors
|
||||
|
||||
elif "img_mod" in key:
|
||||
# Double block: List of 2 ModulationOut
|
||||
double_block = []
|
||||
for _ in range(2): # Create 2 ModulationOut objects
|
||||
double_block.append(
|
||||
ModulationOut(
|
||||
shift=tensor[:, idx:idx+1, :],
|
||||
scale=tensor[:, idx+1:idx+2, :],
|
||||
gate=tensor[:, idx+2:idx+3, :]
|
||||
)
|
||||
)
|
||||
idx += 3 # Advance by 3 vectors per ModulationOut
|
||||
block_dict[key] = double_block
|
||||
|
||||
elif "txt_mod" in key:
|
||||
# Double block: List of 2 ModulationOut
|
||||
double_block = []
|
||||
for _ in range(2): # Create 2 ModulationOut objects
|
||||
double_block.append(
|
||||
ModulationOut(
|
||||
shift=tensor[:, idx:idx+1, :],
|
||||
scale=tensor[:, idx+1:idx+2, :],
|
||||
gate=tensor[:, idx+2:idx+3, :]
|
||||
)
|
||||
)
|
||||
idx += 3 # Advance by 3 vectors per ModulationOut
|
||||
block_dict[key] = double_block
|
||||
|
||||
elif "final_layer" in key:
|
||||
# Final layer: 1 ModulationOut
|
||||
block_dict[key] = [
|
||||
tensor[:, idx:idx+1, :],
|
||||
tensor[:, idx+1:idx+2, :],
|
||||
]
|
||||
idx += 2 # Advance by 2 vectors
|
||||
|
||||
# elif "lite_double_blocks.4.img_mod" in key:
|
||||
# # Double block: List of 2 ModulationOut
|
||||
# double_block = []
|
||||
# for _ in range(2): # Create 2 ModulationOut objects
|
||||
# double_block.append(
|
||||
# ModulationOut(
|
||||
# shift=tensor[:, idx:idx+1, :],
|
||||
# scale=tensor[:, idx+1:idx+2, :],
|
||||
# gate=tensor[:, idx+2:idx+3, :]
|
||||
# )
|
||||
# )
|
||||
# idx += 3 # Advance by 3 vectors per ModulationOut
|
||||
# block_dict[key] = double_block
|
||||
|
||||
# elif "lite_double_blocks.4.txt_mod" in key:
|
||||
# # Double block: List of 2 ModulationOut
|
||||
# double_block = []
|
||||
# for _ in range(2): # Create 2 ModulationOut objects
|
||||
# double_block.append(
|
||||
# ModulationOut(
|
||||
# shift=tensor[:, idx:idx+1, :],
|
||||
# scale=tensor[:, idx+1:idx+2, :],
|
||||
# gate=tensor[:, idx+2:idx+3, :]
|
||||
# )
|
||||
# )
|
||||
# idx += 3 # Advance by 3 vectors per ModulationOut
|
||||
# block_dict[key] = double_block
|
||||
|
||||
return block_dict
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
@ -257,8 +183,6 @@ class Chroma(nn.Module):
|
||||
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec)
|
||||
|
||||
mod_vectors_dict = self.distribute_modulations(mod_vectors, 38, 19)
|
||||
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
@ -267,21 +191,10 @@ class Chroma(nn.Module):
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if i not in self.skip_mmdit:
|
||||
guidance_index = i
|
||||
# if lite we change block 4 guidance with lite guidance
|
||||
# and offset the guidance by 11 blocks after block 4
|
||||
# if self.lite and i == 4:
|
||||
# img_mod = mod_vectors_dict[f"lite_double_blocks.4.img_mod.lin"]
|
||||
# txt_mod = mod_vectors_dict[f"lite_double_blocks.4.txt_mod.lin"]
|
||||
# elif self.lite and i > 4:
|
||||
# guidance_index = i + 11
|
||||
# img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
|
||||
# txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
|
||||
# else:
|
||||
img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
|
||||
txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
|
||||
double_mod = [img_mod, txt_mod]
|
||||
|
||||
double_mod = (
|
||||
self.get_modulations(mod_vectors, "double_img", idx=i),
|
||||
self.get_modulations(mod_vectors, "double_txt", idx=i),
|
||||
)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -318,7 +231,7 @@ class Chroma(nn.Module):
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if i not in self.skip_dit:
|
||||
single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
|
||||
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -345,7 +258,7 @@ class Chroma(nn.Module):
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
|
||||
final_mod = self.get_modulations(mod_vectors, "final")
|
||||
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
|
@ -1049,8 +1049,6 @@ class Hunyuan3Dv2(BaseModel):
|
||||
return out
|
||||
|
||||
class Chroma(BaseModel):
|
||||
chroma_model_mode=False
|
||||
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
|
||||
|
||||
@ -1098,6 +1096,15 @@ class Chroma(BaseModel):
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
# upscale the attention mask, since now we
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
shape = kwargs["noise"].shape
|
||||
mask_ref_size = kwargs["attention_mask_img_shape"]
|
||||
# the model will pad to the patch size, and then divide
|
||||
# essentially dividing and rounding up
|
||||
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
|
||||
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
guidance = 0.0
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor((guidance,)))
|
||||
return out
|
||||
|
@ -1025,14 +1025,22 @@ class Chroma(supported_models_base.BASE):
|
||||
"multiplier": 1.0,
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
latent_format = comfy.latent_formats.Flux
|
||||
memory_usage_factor = 2.8
|
||||
|
||||
memory_usage_factor = 1.8
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.chroma.ChromaTokenizer, comfy.text_encoders.chroma.chroma_te(**t5_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
Loading…
Reference in New Issue
Block a user