From 0a4bc660d4e1d60bf027f6317b86cb482c720069 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Sat, 22 Mar 2025 15:27:11 +0100 Subject: [PATCH 01/17] Upload files for Chroma Implementation --- comfy/ldm/chroma/layers.py | 273 +++++++++++++++++++++++++ comfy/ldm/chroma/math.py | 44 ++++ comfy/ldm/chroma/model.py | 369 ++++++++++++++++++++++++++++++++++ comfy/model_base.py | 55 +++++ comfy/model_detection.py | 26 +++ comfy/sd.py | 5 + comfy/supported_models.py | 21 +- comfy/text_encoders/chroma.py | 44 ++++ nodes.py | 4 +- 9 files changed, 839 insertions(+), 2 deletions(-) create mode 100644 comfy/ldm/chroma/layers.py create mode 100644 comfy/ldm/chroma/math.py create mode 100644 comfy/ldm/chroma/model.py create mode 100644 comfy/text_encoders/chroma.py diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py new file mode 100644 index 000000000..606d9688c --- /dev/null +++ b/comfy/ldm/chroma/layers.py @@ -0,0 +1,273 @@ +import math +from dataclasses import dataclass + +import torch +from torch import Tensor, nn + +from .math import attention, rope +import comfy.ops +import comfy.ldm.common_dit + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None): + super().__init__() + self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.silu = nn.SiLU() + self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, dtype=None, device=None, operations=None): + super().__init__() + self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) + + def forward(self, x: Tensor): + return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6) + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int, dtype=None, device=None, operations=None): + super().__init__() + self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) + self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) + self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) + + def forward(self, vec: Tensor) -> tuple: + out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + + +class Approximator(nn.Module): + def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dtype=None, device=None, operations=None): + super().__init__() + self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) + self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) + self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def 
forward(self, x: Tensor) -> Tensor: + x = self.in_proj(x) + + for layer, norms in zip(self.layers, self.norms): + x = x + layer(norms(x)) + + x = self.out_proj(x) + + return x + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + + self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.img_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) + + self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + + self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.txt_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.flipped_img_txt = flipped_img_txt + + def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None): + (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + attn = attention(torch.cat((txt_q, img_q), dim=2), + torch.cat((txt_k, img_k), dim=2), + torch.cat((txt_v, img_v), dim=2), + pe=pe, mask=attn_mask) + + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt += txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + + if txt.dtype == torch.float16: + txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) + + return img, txt + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + 
https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float = None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) + # proj and mlp_out + self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) + + self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) + + self.hidden_size = hidden_size + self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + + self.mlp_act = nn.GELU(approximate="tanh") + + def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor: + mod = vec + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe, mask=attn_mask) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + x += mod.gate * output + if x.dtype == torch.float16: + x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = vec + shift = shift.squeeze(1) + scale = scale.squeeze(1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/comfy/ldm/chroma/math.py b/comfy/ldm/chroma/math.py new file mode 100644 index 000000000..36b67931c --- /dev/null +++ b/comfy/ldm/chroma/math.py @@ -0,0 +1,44 @@ +import torch +from einops import rearrange +from torch import Tensor + +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: + q_shape = q.shape + k_shape = k.shape + + q = q.float().reshape(*q.shape[:-1], -1, 1, 2) + k = k.float().reshape(*k.shape[:-1], -1, 1, 2) + q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) + k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) + + heads = q.shape[1] + x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask) + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): + device = torch.device("cpu") + else: + device = pos.device + + scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, 
dtype=torch.float64, device=device) + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.to(dtype=torch.float32, device=pos.device) + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py new file mode 100644 index 000000000..190e19008 --- /dev/null +++ b/comfy/ldm/chroma/model.py @@ -0,0 +1,369 @@ +#Original code can be found on: https://github.com/black-forest-labs/flux + +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange, repeat +import comfy.ldm.common_dit +from .common import pad_to_patch_size, rms_norm + +from .layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, + Approximator, + ModulationOut +) + + +@dataclass +class ChromaParams: + in_channels: int + out_channels: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list + theta: int + patch_size: int + qkv_bias: bool + in_dim: int + out_dim: int + hidden_dim: int + n_layers: int + + +class Chroma(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): + super().__init__() + self.dtype = dtype + params = ChromaParams(**kwargs) + self.params = params + self.patch_size = params.patch_size + self.in_channels = params.in_channels + self.out_channels = params.out_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.in_dim = params.in_dim + self.out_dim = params.out_dim + self.hidden_dim = params.hidden_dim + self.n_layers = params.n_layers + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) + self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + # set as nn identity for now, will overwrite it later. 
+ self.distilled_guidance_layer = Approximator( + in_dim=self.in_dim, + hidden_dim=self.hidden_dim, + out_dim=self.out_dim, + n_layers=self.n_layers, + dtype=dtype, device=device, operations=operations + ) + + self.guidance_in = ( + MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if self.distilled_guidance_layer else nn.Identity() + ) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + dtype=dtype, device=device, operations=operations + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) + for _ in range(params.depth_single_blocks) + ] + ) + + if final_layer: + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) + + self.skip_mmdit = [] + self.skip_dit = [] + self.lite = False + @staticmethod + def distribute_modulations(tensor: torch.Tensor, single_block_count: int = 38, double_blocks_count: int = 19): + """ + Distributes slices of the tensor into the block_dict as ModulationOut objects. + + Args: + tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim]. + """ + batch_size, vectors, dim = tensor.shape + + block_dict = {} + + # HARD CODED VALUES! lookup table for the generated vectors + # Add 38 single mod blocks + for i in range(single_block_count): + key = f"single_blocks.{i}.modulation.lin" + block_dict[key] = None + + # Add 19 image double blocks + for i in range(double_blocks_count): + key = f"double_blocks.{i}.img_mod.lin" + block_dict[key] = None + + # Add 19 text double blocks + for i in range(double_blocks_count): + key = f"double_blocks.{i}.txt_mod.lin" + block_dict[key] = None + + # Add the final layer + block_dict["final_layer.adaLN_modulation.1"] = None + # # 6.2b version + # block_dict["lite_double_blocks.4.img_mod.lin"] = None + # block_dict["lite_double_blocks.4.txt_mod.lin"] = None + + + idx = 0 # Index to keep track of the vector slices + + for key in block_dict.keys(): + if "single_blocks" in key: + # Single block: 1 ModulationOut + block_dict[key] = ModulationOut( + shift=tensor[:, idx:idx+1, :], + scale=tensor[:, idx+1:idx+2, :], + gate=tensor[:, idx+2:idx+3, :] + ) + idx += 3 # Advance by 3 vectors + + elif "img_mod" in key: + # Double block: List of 2 ModulationOut + double_block = [] + for _ in range(2): # Create 2 ModulationOut objects + double_block.append( + ModulationOut( + shift=tensor[:, idx:idx+1, :], + scale=tensor[:, idx+1:idx+2, :], + gate=tensor[:, idx+2:idx+3, :] + ) + ) + idx += 3 # Advance by 3 vectors per ModulationOut + block_dict[key] = double_block + + elif "txt_mod" in key: + # Double block: List of 2 ModulationOut + double_block = [] + for _ in range(2): # Create 2 ModulationOut objects + double_block.append( + ModulationOut( + shift=tensor[:, idx:idx+1, :], + scale=tensor[:, idx+1:idx+2, :], + gate=tensor[:, idx+2:idx+3, :] + ) + ) + idx += 3 # Advance by 3 vectors per ModulationOut + block_dict[key] = double_block + + elif "final_layer" in key: + # Final layer: 1 ModulationOut + block_dict[key] = [ + tensor[:, idx:idx+1, :], + tensor[:, idx+1:idx+2, :], + ] + idx += 2 # Advance by 2 vectors + + # elif "lite_double_blocks.4.img_mod" in key: + # # Double block: List of 2 ModulationOut + # double_block = [] + # for _ in range(2): # 
Create 2 ModulationOut objects + # double_block.append( + # ModulationOut( + # shift=tensor[:, idx:idx+1, :], + # scale=tensor[:, idx+1:idx+2, :], + # gate=tensor[:, idx+2:idx+3, :] + # ) + # ) + # idx += 3 # Advance by 3 vectors per ModulationOut + # block_dict[key] = double_block + + # elif "lite_double_blocks.4.txt_mod" in key: + # # Double block: List of 2 ModulationOut + # double_block = [] + # for _ in range(2): # Create 2 ModulationOut objects + # double_block.append( + # ModulationOut( + # shift=tensor[:, idx:idx+1, :], + # scale=tensor[:, idx+1:idx+2, :], + # gate=tensor[:, idx+2:idx+3, :] + # ) + # ) + # idx += 3 # Advance by 3 vectors per ModulationOut + # block_dict[key] = double_block + + return block_dict + + def forward_orig( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + guidance: Tensor = None, + control = None, + transformer_options={}, + attn_mask: Tensor = None, + ) -> Tensor: + patches_replace = transformer_options.get("patches_replace", {}) + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + + # distilled vector guidance + mod_index_length = 344 + distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype) + # guidance = guidance * + distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype) + + # get all modulation index + modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype) + # we need to broadcast the modulation index here so each batch has all of the index + modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype) + # and we need to broadcast timestep and guidance along too + timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype) + # then and only then we could concatenate it together + input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype) + + mod_vectors = self.distilled_guidance_layer(input_vec) + + mod_vectors_dict = self.distribute_modulations(mod_vectors, 38, 19) + + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.double_blocks): + if i not in self.skip_mmdit: + guidance_index = i + # if lite we change block 4 guidance with lite guidance + # and offset the guidance by 11 blocks after block 4 + # if self.lite and i == 4: + # img_mod = mod_vectors_dict[f"lite_double_blocks.4.img_mod.lin"] + # txt_mod = mod_vectors_dict[f"lite_double_blocks.4.txt_mod.lin"] + # elif self.lite and i > 4: + # guidance_index = i + 11 + # img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"] + # txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"] + # else: + img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"] + txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"] + double_mod = [img_mod, txt_mod] + + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"], out["txt"] = block(img=args["img"], + txt=args["txt"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("double_block", i)]({"img": img, + "txt": txt, + 
"vec": double_mod, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + txt = out["txt"] + img = out["img"] + else: + img, txt = block(img=img, + txt=txt, + vec=double_mod, + pe=pe, + attn_mask=attn_mask) + + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + img += add + + img = torch.cat((txt, img), 1) + + for i, block in enumerate(self.single_blocks): + if i not in self.skip_dit: + single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] + if ("single_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("single_block", i)]({"img": img, + "vec": single_mod, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + img = out["img"] + else: + img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask) + + if control is not None: # Controlnet + control_o = control.get("output") + if i < len(control_o): + add = control_o[i] + if add is not None: + img[:, txt.shape[1] :, ...] += add + + img = img[:, txt.shape[1] :, ...] + final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] + img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) + return img + + def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): + bs, c, h, w = x.shape + patch_size = 2 + x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) + + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w] diff --git a/comfy/model_base.py b/comfy/model_base.py index eec70d5de..05e242b88 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -37,6 +37,7 @@ import comfy.ldm.cosmos.model import comfy.ldm.lumina.model import comfy.ldm.wan.model import comfy.ldm.hunyuan3d.model +import comfy.ldm.chroma.model import comfy.model_management import comfy.patcher_extension @@ -1046,3 +1047,57 @@ class Hunyuan3Dv2(BaseModel): if guidance is not None: out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out + +class Chroma(BaseModel): + chroma_model_mode=False + + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma) + + def concat_cond(self, **kwargs): + try: + #Handle Flux control loras dynamically changing the img_in weight. 
+ num_channels = self.diffusion_model.img_in.weight.shape[1] + except: + #Some cases like tensorrt might not have the weights accessible + num_channels = self.model_config.unet_config["in_channels"] + + out_channels = self.model_config.unet_config["out_channels"] + + if num_channels <= out_channels: + return None + + image = kwargs.get("concat_latent_image", None) + noise = kwargs.get("noise", None) + device = kwargs["device"] + + if image is None: + image = torch.zeros_like(noise) + + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + image = utils.resize_to_batch_size(image, noise.shape[0]) + image = self.process_latent_in(image) + if num_channels <= out_channels * 2: + return image + + #inpaint model + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.ones_like(noise)[:, :1] + + mask = torch.mean(mask, dim=1, keepdim=True) + mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center") + mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + return torch.cat((image, mask), dim=1) + + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + # upscale the attention mask, since now we + guidance = 0.0 + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor((guidance,))) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 4217f5831..a3d366487 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -154,6 +154,32 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config + if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma + dit_config = {} + dit_config["image_model"] = "chroma" + dit_config["depth"] = 48 + dit_config["in_channels"] = 64 + patch_size = 2 + dit_config["patch_size"] = patch_size + in_key = "{}img_in.weight".format(key_prefix) + if in_key in state_dict_keys: + dit_config["in_channels"] = state_dict[in_key].shape[1] + dit_config["out_channels"] = 64 + dit_config["context_in_dim"] = 4096 + dit_config["hidden_size"] = 3072 + dit_config["mlp_ratio"] = 4.0 + dit_config["num_heads"] = 24 + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') + dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') + dit_config["axes_dim"] = [16, 56, 56] + dit_config["theta"] = 10000 + dit_config["qkv_bias"] = True + dit_config["in_dim"] = 64 + dit_config["out_dim"] = 3072 + dit_config["hidden_dim"] = 5120 + dit_config["n_layers"] = 5 + return dit_config + if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux dit_config = {} dit_config["image_model"] = "flux" diff --git a/comfy/sd.py b/comfy/sd.py index d096f496c..895f4a1a7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -41,6 +41,7 @@ import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import 
comfy.text_encoders.lumina2 import comfy.text_encoders.wan +import comfy.text_encoders.chroma import comfy.model_patcher import comfy.lora @@ -700,6 +701,7 @@ class CLIPType(Enum): COSMOS = 11 LUMINA2 = 12 WAN = 13 + CHROMA = 14 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -808,6 +810,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif clip_type == CLIPType.CHROMA: + clip_target.clip = comfy.text_encoders.chroma.chroma_te(**t5xxl_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.chroma.ChromaT5Tokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index fad00d35b..1520b58c7 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1013,6 +1013,25 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2): latent_format = latent_formats.Hunyuan3Dv2mini -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2] +class Chroma(supported_models_base.BASE): + unet_config = { + "image_model": "chroma", + } + + unet_extra_config = { + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 1.0, + } + latent_format = comfy.latent_formats.Flux + memory_usage_factor = 2.8 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device) + return out + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma] models += [SVD_img2vid] diff --git a/comfy/text_encoders/chroma.py b/comfy/text_encoders/chroma.py new file mode 100644 index 000000000..cf10bffcc --- /dev/null +++ b/comfy/text_encoders/chroma.py @@ -0,0 +1,44 @@ +from comfy import sd1_clip +import comfy.text_encoders.t5 +import os +from transformers import T5TokenizerFast + + +class T5XXLModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") + t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) + if t5xxl_scaled_fp8 is not None: + model_options = model_options.copy() + model_options["scaled_fp8"] = t5xxl_scaled_fp8 + attention_mask = True + + 
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class ChromaT5XXL(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options) + + +class T5XXLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + + +class ChromaT5Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) + + +def chroma_te(dtype_t5=None, t5xxl_scaled_fp8=None): + class ChromaTEModel_(ChromaT5XXL): + def __init__(self, device="cpu", dtype=None, model_options={}): + if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + if dtype is None: + dtype = dtype_t5 + super().__init__(device=device, dtype=dtype, model_options=model_options) + return ChromaTEModel_ diff --git a/nodes.py b/nodes.py index 27ef743b3..f506c8a95 100644 --- a/nodes.py +++ b/nodes.py @@ -915,7 +915,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "chroma"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -946,6 +946,8 @@ class CLIPLoader: clip_type = comfy.sd.CLIPType.LUMINA2 elif type == "wan": clip_type = comfy.sd.CLIPType.WAN + elif type == "chroma": + clip_type = comfy.sd.CLIPType.CHROMA else: clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION From 79f460150e253373c0d5cf5465e70ac0cba42d98 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Sat, 22 Mar 2025 16:16:23 +0100 Subject: [PATCH 02/17] Remove trailing whitespace --- comfy/ldm/chroma/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 190e19008..624f65f21 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -275,7 +275,7 @@ class Chroma(nn.Module): # img_mod = mod_vectors_dict[f"lite_double_blocks.4.img_mod.lin"] # txt_mod = mod_vectors_dict[f"lite_double_blocks.4.txt_mod.lin"] # elif self.lite and i > 4: - # guidance_index = i + 11 + # guidance_index = i + 11 # img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"] # txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"] # else: From 2710f77218af804449b26702abbfdff9550c293c Mon Sep 17 00:00:00 2001 From: Silver 
<65376327+silveroxides@users.noreply.github.com> Date: Sat, 22 Mar 2025 16:18:17 +0100 Subject: [PATCH 03/17] trim more trailing whitespace..oops --- comfy/ldm/chroma/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 606d9688c..8ad3c72d4 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -126,7 +126,7 @@ class Approximator(nn.Module): def device(self): # Get the device of the module (assumes all parameters are on the same device) return next(self.parameters()).device - + def forward(self, x: Tensor) -> Tensor: x = self.in_proj(x) From ca73500269b7e7bd5ae2e301386f110bb608f215 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Sat, 22 Mar 2025 16:20:08 +0100 Subject: [PATCH 04/17] remove unused imports --- comfy/ldm/chroma/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 624f65f21..b3b03dcd8 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -6,7 +6,6 @@ import torch from torch import Tensor, nn from einops import rearrange, repeat import comfy.ldm.common_dit -from .common import pad_to_patch_size, rms_norm from .layers import ( DoubleStreamBlock, From fda35f37e5b55f616c13167e481183bf64237af7 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Sat, 22 Mar 2025 16:49:16 +0100 Subject: [PATCH 05/17] Add supported_inference_dtypes --- comfy/supported_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1520b58c7..da5f3abc9 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1027,6 +1027,7 @@ class Chroma(supported_models_base.BASE): } latent_format = comfy.latent_formats.Flux memory_usage_factor = 2.8 + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] def get_model(self, state_dict, prefix="", device=None): out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device) From 9fa34e72163af560c76850ba38ee54229a9ddf55 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Sun, 23 Mar 2025 06:14:50 +0100 Subject: [PATCH 06/17] Set min_length to 0 and remove attention_mask=True --- comfy/text_encoders/chroma.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy/text_encoders/chroma.py b/comfy/text_encoders/chroma.py index cf10bffcc..30b5c11cf 100644 --- a/comfy/text_encoders/chroma.py +++ b/comfy/text_encoders/chroma.py @@ -11,7 +11,6 @@ class T5XXLModel(sd1_clip.SDClipModel): if t5xxl_scaled_fp8 is not None: model_options = model_options.copy() model_options["scaled_fp8"] = t5xxl_scaled_fp8 - attention_mask = True super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -24,7 +23,7 @@ class ChromaT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', 
tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=0) class ChromaT5Tokenizer(sd1_clip.SD1Tokenizer): From fe6e1fa44f792750705e3c42fb0a4d5ed0f5c5a2 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Tue, 25 Mar 2025 07:21:08 +0100 Subject: [PATCH 07/17] Set min_length to 1 --- comfy/text_encoders/chroma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/text_encoders/chroma.py b/comfy/text_encoders/chroma.py index 30b5c11cf..5ee5d57b9 100644 --- a/comfy/text_encoders/chroma.py +++ b/comfy/text_encoders/chroma.py @@ -23,7 +23,7 @@ class ChromaT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=0) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) class ChromaT5Tokenizer(sd1_clip.SD1Tokenizer): From f04b502ab616318934de6b3a0d95b65974538978 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Tue, 25 Mar 2025 21:38:11 +0100 Subject: [PATCH 08/17] get_mdulations added from blepping and minor changes --- comfy/ldm/chroma/model.py | 171 ++++++++++---------------------------- comfy/model_base.py | 11 ++- comfy/supported_models.py | 10 ++- 3 files changed, 60 insertions(+), 132 deletions(-) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index b3b03dcd8..a956ad054 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -39,6 +39,16 @@ class ChromaParams: n_layers: int +class ChromaModulationOut(ModulationOut): + @classmethod + def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut: + return cls( + shift=tensor[:, offset : offset + 1, :], + scale=tensor[:, offset + 1 : offset + 2, :], + gate=tensor[:, offset + 2 : offset + 3, :], + ) + + class Chroma(nn.Module): """ Transformer model for flow matching on sequences. @@ -108,118 +118,34 @@ class Chroma(nn.Module): self.skip_mmdit = [] self.skip_dit = [] self.lite = False - @staticmethod - def distribute_modulations(tensor: torch.Tensor, single_block_count: int = 38, double_blocks_count: int = 19): - """ - Distributes slices of the tensor into the block_dict as ModulationOut objects. - Args: - tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim]. 
- """ - batch_size, vectors, dim = tensor.shape + def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0): + # This function slices up the modulations tensor which has the following layout: + # single : num_single_blocks * 3 elements + # double_img : num_double_blocks * 6 elements + # double_txt : num_double_blocks * 6 elements + # final : 2 elements + if block_type == "final": + return (tensor[:, -2:-1, :], tensor[:, -1:, :]) + single_block_count = self.params.depth_single_blocks + double_block_count = self.params.depth + offset = 3 * idx + if block_type == "single": + return ChromaModulationOut.from_offset(tensor, offset) + # Double block modulations are 6 elements so we double 3 * idx. + offset *= 2 + if block_type in {"double_img", "double_txt"}: + # Advance past the single block modulations. + offset += 3 * single_block_count + if block_type == "double_txt": + # Advance past the double block img modulations. + offset += 6 * double_block_count + return ( + ChromaModulationOut.from_offset(tensor, offset), + ChromaModulationOut.from_offset(tensor, offset + 3), + ) + raise ValueError("Bad block_type") - block_dict = {} - - # HARD CODED VALUES! lookup table for the generated vectors - # Add 38 single mod blocks - for i in range(single_block_count): - key = f"single_blocks.{i}.modulation.lin" - block_dict[key] = None - - # Add 19 image double blocks - for i in range(double_blocks_count): - key = f"double_blocks.{i}.img_mod.lin" - block_dict[key] = None - - # Add 19 text double blocks - for i in range(double_blocks_count): - key = f"double_blocks.{i}.txt_mod.lin" - block_dict[key] = None - - # Add the final layer - block_dict["final_layer.adaLN_modulation.1"] = None - # # 6.2b version - # block_dict["lite_double_blocks.4.img_mod.lin"] = None - # block_dict["lite_double_blocks.4.txt_mod.lin"] = None - - - idx = 0 # Index to keep track of the vector slices - - for key in block_dict.keys(): - if "single_blocks" in key: - # Single block: 1 ModulationOut - block_dict[key] = ModulationOut( - shift=tensor[:, idx:idx+1, :], - scale=tensor[:, idx+1:idx+2, :], - gate=tensor[:, idx+2:idx+3, :] - ) - idx += 3 # Advance by 3 vectors - - elif "img_mod" in key: - # Double block: List of 2 ModulationOut - double_block = [] - for _ in range(2): # Create 2 ModulationOut objects - double_block.append( - ModulationOut( - shift=tensor[:, idx:idx+1, :], - scale=tensor[:, idx+1:idx+2, :], - gate=tensor[:, idx+2:idx+3, :] - ) - ) - idx += 3 # Advance by 3 vectors per ModulationOut - block_dict[key] = double_block - - elif "txt_mod" in key: - # Double block: List of 2 ModulationOut - double_block = [] - for _ in range(2): # Create 2 ModulationOut objects - double_block.append( - ModulationOut( - shift=tensor[:, idx:idx+1, :], - scale=tensor[:, idx+1:idx+2, :], - gate=tensor[:, idx+2:idx+3, :] - ) - ) - idx += 3 # Advance by 3 vectors per ModulationOut - block_dict[key] = double_block - - elif "final_layer" in key: - # Final layer: 1 ModulationOut - block_dict[key] = [ - tensor[:, idx:idx+1, :], - tensor[:, idx+1:idx+2, :], - ] - idx += 2 # Advance by 2 vectors - - # elif "lite_double_blocks.4.img_mod" in key: - # # Double block: List of 2 ModulationOut - # double_block = [] - # for _ in range(2): # Create 2 ModulationOut objects - # double_block.append( - # ModulationOut( - # shift=tensor[:, idx:idx+1, :], - # scale=tensor[:, idx+1:idx+2, :], - # gate=tensor[:, idx+2:idx+3, :] - # ) - # ) - # idx += 3 # Advance by 3 vectors per ModulationOut - # block_dict[key] = double_block - - # elif 
"lite_double_blocks.4.txt_mod" in key: - # # Double block: List of 2 ModulationOut - # double_block = [] - # for _ in range(2): # Create 2 ModulationOut objects - # double_block.append( - # ModulationOut( - # shift=tensor[:, idx:idx+1, :], - # scale=tensor[:, idx+1:idx+2, :], - # gate=tensor[:, idx+2:idx+3, :] - # ) - # ) - # idx += 3 # Advance by 3 vectors per ModulationOut - # block_dict[key] = double_block - - return block_dict def forward_orig( self, @@ -257,8 +183,6 @@ class Chroma(nn.Module): mod_vectors = self.distilled_guidance_layer(input_vec) - mod_vectors_dict = self.distribute_modulations(mod_vectors, 38, 19) - txt = self.txt_in(txt) ids = torch.cat((txt_ids, img_ids), dim=1) @@ -267,21 +191,10 @@ class Chroma(nn.Module): blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.double_blocks): if i not in self.skip_mmdit: - guidance_index = i - # if lite we change block 4 guidance with lite guidance - # and offset the guidance by 11 blocks after block 4 - # if self.lite and i == 4: - # img_mod = mod_vectors_dict[f"lite_double_blocks.4.img_mod.lin"] - # txt_mod = mod_vectors_dict[f"lite_double_blocks.4.txt_mod.lin"] - # elif self.lite and i > 4: - # guidance_index = i + 11 - # img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"] - # txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"] - # else: - img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"] - txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"] - double_mod = [img_mod, txt_mod] - + double_mod = ( + self.get_modulations(mod_vectors, "double_img", idx=i), + self.get_modulations(mod_vectors, "double_txt", idx=i), + ) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -318,7 +231,7 @@ class Chroma(nn.Module): for i, block in enumerate(self.single_blocks): if i not in self.skip_dit: - single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] + single_mod = self.get_modulations(mod_vectors, "single", idx=i) if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -345,7 +258,7 @@ class Chroma(nn.Module): img[:, txt.shape[1] :, ...] += add img = img[:, txt.shape[1] :, ...] 
- final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] + final_mod = self.get_modulations(mod_vectors, "final") img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) return img diff --git a/comfy/model_base.py b/comfy/model_base.py index 05e242b88..13349f714 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1049,8 +1049,6 @@ class Hunyuan3Dv2(BaseModel): return out class Chroma(BaseModel): - chroma_model_mode=False - def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma) @@ -1098,6 +1096,15 @@ class Chroma(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) # upscale the attention mask, since now we + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + shape = kwargs["noise"].shape + mask_ref_size = kwargs["attention_mask_img_shape"] + # the model will pad to the patch size, and then divide + # essentially dividing and rounding up + (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) + attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) guidance = 0.0 out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor((guidance,))) return out diff --git a/comfy/supported_models.py b/comfy/supported_models.py index da5f3abc9..c113a023b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1025,14 +1025,22 @@ class Chroma(supported_models_base.BASE): "multiplier": 1.0, "shift": 1.0, } + latent_format = comfy.latent_formats.Flux - memory_usage_factor = 2.8 + + memory_usage_factor = 1.8 + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] def get_model(self, state_dict, prefix="", device=None): out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device) return out + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.chroma.ChromaTokenizer, comfy.text_encoders.chroma.chroma_te(**t5_detect)) + models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma] models += [SVD_img2vid] From b7da8e2bc103e6cf32ba056e1f3fb94bb7f97508 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Thu, 27 Mar 2025 03:08:24 +0100 Subject: [PATCH 09/17] Add lora conversion if statement in lora.py --- comfy/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/lora.py b/comfy/lora.py index bc9f3022a..f466a5ae9 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -378,7 +378,7 @@ def model_lora_keys_unet(model, key_map={}): key_lora = k[len("diffusion_model."):-len(".weight")] key_map["base_model.model.{}".format(key_lora)] = k #official 
hunyuan lora format - if isinstance(model, comfy.model_base.Flux): #Diffusers lora Flux + if isinstance(model, comfy.model_base.Flux) or isinstance(model, comfy.model_base.Chroma): #Diffusers lora Flux or a diffusers lora Chroma diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") for k in diffusers_keys: if k.endswith(".weight"): From bf339e826559d4e2c38399d29d3759af4da3696d Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Thu, 27 Mar 2025 18:29:51 +0100 Subject: [PATCH 10/17] Update supported_models.py --- comfy/supported_models.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c113a023b..0fe97cb43 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -969,12 +969,24 @@ class WAN21_I2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", "model_type": "i2v", + "in_dim": 36, } def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21(self, image_to_video=True, device=device) return out +class WAN21_FunControl2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "i2v", + "in_dim": 48, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21(self, image_to_video=False, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1041,6 +1053,6 @@ class Chroma(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.chroma.ChromaTokenizer, comfy.text_encoders.chroma.chroma_te(**t5_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma] models += [SVD_img2vid] From 2378c63ba9c8a900bdf70f3ec4b4275a800393a5 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Thu, 27 Mar 2025 18:43:58 +0100 Subject: [PATCH 11/17] update model_base.py --- comfy/model_base.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 13349f714..9d73fa460 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -993,7 +993,8 @@ class WAN21(BaseModel): def concat_cond(self, **kwargs): noise = kwargs.get("noise", None) - if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]: + extra_channels = 
self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1] + if extra_channels == 0: return None image = kwargs.get("concat_latent_image", None) @@ -1001,23 +1002,30 @@ class WAN21(BaseModel): if image is None: image = torch.zeros_like(noise) + shape_image = list(noise.shape) + shape_image[1] = extra_channels + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + for i in range(0, image.shape[1], 16): + image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16]) + image = utils.resize_to_batch_size(image, noise.shape[0]) - image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") - image = self.process_latent_in(image) - image = utils.resize_to_batch_size(image, noise.shape[0]) - - if not self.image_to_video: + if not self.image_to_video or extra_channels == image.shape[1]: return image mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) if mask is None: mask = torch.zeros_like(noise)[:, :4] else: - mask = 1.0 - torch.mean(mask, dim=1, keepdim=True) + if mask.shape[1] != 4: + mask = torch.mean(mask, dim=1, keepdim=True) + mask = 1.0 - mask mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") if mask.shape[-3] < noise.shape[-3]: mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) - mask = mask.repeat(1, 4, 1, 1, 1) + if mask.shape[1] == 1: + mask = mask.repeat(1, 4, 1, 1, 1) mask = utils.resize_to_batch_size(mask, noise.shape[0]) return torch.cat((mask, image), dim=1) From 73f10eaf7c4bbcbab7dab9435e3bd61d61ba0c2d Mon Sep 17 00:00:00 2001 From: silveroxides Date: Thu, 27 Mar 2025 18:55:13 +0100 Subject: [PATCH 12/17] add uptream commits --- comfy_extras/nodes_cfg.py | 45 ++++++++++++++++ comfy_extras/nodes_wan.py | 105 ++++++++++++++++++++++++++++++++++++++ nodes.py | 1 + 3 files changed, 151 insertions(+) create mode 100644 comfy_extras/nodes_cfg.py diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py new file mode 100644 index 000000000..1fb686644 --- /dev/null +++ b/comfy_extras/nodes_cfg.py @@ -0,0 +1,45 @@ +import torch + +# https://github.com/WeichenFan/CFG-Zero-star +def optimized_scale(positive, negative): + positive_flat = positive.reshape(positive.shape[0], -1) + negative_flat = negative.reshape(negative.shape[0], -1) + + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1)) + +class CFGZeroStar: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("patched_model",) + FUNCTION = "patch" + CATEGORY = "advanced/guidance" + + def patch(self, model): + m = model.clone() + def cfg_zero_star(args): + guidance_scale = args['cond_scale'] + x = args['input'] + cond_p = args['cond_denoised'] + uncond_p = args['uncond_denoised'] + out = args["denoised"] + alpha = optimized_scale(x - cond_p, x - uncond_p) + + return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha) + m.set_model_sampler_post_cfg_function(cfg_zero_star) 
+ return (m, ) + +NODE_CLASS_MAPPINGS = { + "CFGZeroStar": CFGZeroStar +} diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index dc30eb546..2d0f31ac8 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -3,6 +3,7 @@ import node_helpers import torch import comfy.model_management import comfy.utils +import comfy.latent_formats class WanImageToVideo: @@ -49,6 +50,110 @@ class WanImageToVideo: return (positive, negative, out_latent) +class WanFunControlToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "control_video": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + concat_latent = concat_latent.repeat(1, 2, 1, 1, 1) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(control_video[:, :, :, :3]) + concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + +class WanFunInpaintToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 
1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + if end_image is not None: + end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + image = torch.ones((length, height, width, 3)) * 0.5 + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + + if start_image is not None: + image[:start_image.shape[0]] = start_image + mask[:, :, :start_image.shape[0] + 3] = 0.0 + + if end_image is not None: + image[-end_image.shape[0]:] = end_image + mask[:, :, -end_image.shape[0]:] = 0.0 + + concat_latent_image = vae.encode(image[:, :, :, :3]) + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, + "WanFunControlToVideo": WanFunControlToVideo, + "WanFunInpaintToVideo": WanFunInpaintToVideo, } diff --git a/nodes.py b/nodes.py index f506c8a95..42e7805ec 100644 --- a/nodes.py +++ b/nodes.py @@ -2269,6 +2269,7 @@ def init_builtin_extra_nodes(): "nodes_lotus.py", "nodes_hunyuan3d.py", "nodes_primitive.py", + "nodes_cfg.py", ] import_failed = [] From cb6ece9a1857bbbc8dbc0341d403e8d8950ce78e Mon Sep 17 00:00:00 2001 From: silveroxides Date: Sun, 6 Apr 2025 09:52:32 +0200 Subject: [PATCH 13/17] set modelType.FLOW, will cause beta scheduler to work properly --- comfy/model_base.py | 2 +- comfy/supported_models.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 2a4303620..756a479cc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1059,7 +1059,7 @@ class Hunyuan3Dv2(BaseModel): return out class Chroma(BaseModel): - def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma) def concat_cond(self, **kwargs): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 0fe97cb43..9b9c45a28 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1045,7 +1045,7 @@ class Chroma(supported_models_base.BASE): 
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] def get_model(self, state_dict, prefix="", device=None): - out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device) + out = model_base.Chroma(self, model_type=model_base.ModelType.FLOW, device=device) return out def clip_target(self, state_dict={}): From e1af4137224fd79f12b9eebf68943bd307276a9d Mon Sep 17 00:00:00 2001 From: silveroxides Date: Tue, 8 Apr 2025 12:48:25 +0200 Subject: [PATCH 14/17] Adjust memory usage factor and remove unnecessary code --- comfy/supported_models.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 9b9c45a28..c913545fd 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1034,18 +1034,16 @@ class Chroma(supported_models_base.BASE): } sampling_settings = { - "multiplier": 1.0, - "shift": 1.0, } latent_format = comfy.latent_formats.Flux - memory_usage_factor = 1.8 + memory_usage_factor = 3.2 supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] def get_model(self, state_dict, prefix="", device=None): - out = model_base.Chroma(self, model_type=model_base.ModelType.FLOW, device=device) + out = model_base.Chroma(self, device=device) return out def clip_target(self, state_dict={}): From 3d375a153e6d49c7d51986716ef0daa5235d7803 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Tue, 8 Apr 2025 12:56:02 +0200 Subject: [PATCH 15/17] fix mistake --- comfy/supported_models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c913545fd..f5c79f702 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -17,6 +17,7 @@ import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import comfy.text_encoders.lumina2 import comfy.text_encoders.wan +import comfy.text_encoders.chroma from . import supported_models_base from . 
import latent_formats @@ -1034,6 +1035,7 @@ class Chroma(supported_models_base.BASE): } sampling_settings = { + "multiplier": 1.0, } latent_format = comfy.latent_formats.Flux @@ -1042,6 +1044,7 @@ class Chroma(supported_models_base.BASE): supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + def get_model(self, state_dict, prefix="", device=None): out = model_base.Chroma(self, device=device) return out From 2c2481955d10cf791504cbb425a363bd5c7904e0 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Fri, 11 Apr 2025 16:18:57 +0200 Subject: [PATCH 16/17] reduce code duplication --- comfy/ldm/chroma/layers.py | 114 +++++-------------------------------- comfy/ldm/chroma/model.py | 22 ++----- 2 files changed, 20 insertions(+), 116 deletions(-) diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 8ad3c72d4..410b7121f 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -5,112 +5,26 @@ import torch from torch import Tensor, nn from .math import attention, rope -import comfy.ops +from comfy.ldm.flux.layers import ( + MLPEmbedder, + RMSNorm, + QKNorm, + SelfAttention, + ModulationOut, +) import comfy.ldm.common_dit -class EmbedND(nn.Module): - def __init__(self, dim: int, theta: int, axes_dim: list): - super().__init__() - self.dim = dim - self.theta = theta - self.axes_dim = axes_dim - def forward(self, ids: Tensor) -> Tensor: - n_axes = ids.shape[-1] - emb = torch.cat( - [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], - dim=-3, +class ChromaModulationOut(ModulationOut): + @classmethod + def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut: + return cls( + shift=tensor[:, offset : offset + 1, :], + scale=tensor[:, offset + 1 : offset + 2, :], + gate=tensor[:, offset + 2 : offset + 3, :], ) - return emb.unsqueeze(1) - - -def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): - """ - Create sinusoidal timestep embeddings. - :param t: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an (N, D) Tensor of positional embeddings. 
- """ - t = time_factor * t - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half) - - args = t[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - if torch.is_floating_point(t): - embedding = embedding.to(t) - return embedding - -class MLPEmbedder(nn.Module): - def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None): - super().__init__() - self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) - self.silu = nn.SiLU() - self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device) - - def forward(self, x: Tensor) -> Tensor: - return self.out_layer(self.silu(self.in_layer(x))) - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, dtype=None, device=None, operations=None): - super().__init__() - self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) - - def forward(self, x: Tensor): - return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6) - - -class QKNorm(torch.nn.Module): - def __init__(self, dim: int, dtype=None, device=None, operations=None): - super().__init__() - self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) - self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) - - def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple: - q = self.query_norm(q) - k = self.key_norm(k) - return q.to(v), k.to(v) - - -class SelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - - self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) - self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) - self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) - - -@dataclass -class ModulationOut: - shift: Tensor - scale: Tensor - gate: Tensor - - -class Modulation(nn.Module): - def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None): - super().__init__() - self.is_double = double - self.multiplier = 6 if double else 3 - self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) - - def forward(self, vec: Tensor) -> tuple: - out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1) - - return ( - ModulationOut(*out[:3]), - ModulationOut(*out[3:]) if self.is_double else None, - ) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index a956ad054..636748fc5 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -7,15 +7,17 @@ from torch import Tensor, nn from einops import rearrange, repeat import comfy.ldm.common_dit +from comfy.ldm.flux.layers import ( + EmbedND, + timestep_embedding, +) + from .layers import ( DoubleStreamBlock, - EmbedND, LastLayer, - MLPEmbedder, SingleStreamBlock, - timestep_embedding, Approximator, - ModulationOut + ChromaModulationOut, ) @@ -39,14 +41,6 @@ class ChromaParams: n_layers: int -class ChromaModulationOut(ModulationOut): - @classmethod - def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut: - return cls( - shift=tensor[:, offset : offset + 1, :], - scale=tensor[:, offset + 1 : 
offset + 2, :], - gate=tensor[:, offset + 2 : offset + 3, :], - ) class Chroma(nn.Module): @@ -77,7 +71,6 @@ class Chroma(nn.Module): self.n_layers = params.n_layers self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) - self.time_in = MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) # set as nn identity for now, will overwrite it later. self.distilled_guidance_layer = Approximator( @@ -88,9 +81,6 @@ class Chroma(nn.Module): dtype=dtype, device=device, operations=operations ) - self.guidance_in = ( - MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if self.distilled_guidance_layer else nn.Identity() - ) self.double_blocks = nn.ModuleList( [ From cd3d2d5c62a8ee5d2e0009b67ddff661726c61e7 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Fri, 11 Apr 2025 16:22:04 +0200 Subject: [PATCH 17/17] remove unused imports --- comfy/ldm/chroma/layers.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 410b7121f..dd0b72f70 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -1,10 +1,7 @@ -import math -from dataclasses import dataclass - import torch from torch import Tensor, nn -from .math import attention, rope +from .math import attention from comfy.ldm.flux.layers import ( MLPEmbedder, RMSNorm, @@ -12,7 +9,6 @@ from comfy.ldm.flux.layers import ( SelfAttention, ModulationOut, ) -import comfy.ldm.common_dit