From 0a4bc660d4e1d60bf027f6317b86cb482c720069 Mon Sep 17 00:00:00 2001
From: silveroxides <ishimarukaito@gmail.com>
Date: Sat, 22 Mar 2025 15:27:11 +0100
Subject: [PATCH 01/17] Upload files for Chroma Implementation

---
 comfy/ldm/chroma/layers.py    | 273 +++++++++++++++++++++++++
 comfy/ldm/chroma/math.py      |  44 ++++
 comfy/ldm/chroma/model.py     | 369 ++++++++++++++++++++++++++++++++++
 comfy/model_base.py           |  55 +++++
 comfy/model_detection.py      |  26 +++
 comfy/sd.py                   |   5 +
 comfy/supported_models.py     |  21 +-
 comfy/text_encoders/chroma.py |  44 ++++
 nodes.py                      |   4 +-
 9 files changed, 839 insertions(+), 2 deletions(-)
 create mode 100644 comfy/ldm/chroma/layers.py
 create mode 100644 comfy/ldm/chroma/math.py
 create mode 100644 comfy/ldm/chroma/model.py
 create mode 100644 comfy/text_encoders/chroma.py

diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
new file mode 100644
index 00000000..606d9688
--- /dev/null
+++ b/comfy/ldm/chroma/layers.py
@@ -0,0 +1,273 @@
+import math
+from dataclasses import dataclass
+
+import torch
+from torch import Tensor, nn
+
+from .math import attention, rope
+import comfy.ops
+import comfy.ldm.common_dit
+
+
+class EmbedND(nn.Module):
+    def __init__(self, dim: int, theta: int, axes_dim: list):
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: Tensor) -> Tensor:
+        n_axes = ids.shape[-1]
+        emb = torch.cat(
+            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+            dim=-3,
+        )
+
+        return emb.unsqueeze(1)
+
+
+def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
+    """
+    Create sinusoidal timestep embeddings.
+    :param t: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
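+    :param time_factor: multiplier applied to t before computing the embedding (default 1000).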
+    :return: an (N, D) Tensor of positional embeddings.
+    """
+    t = time_factor * t
+    half = dim // 2
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
+
+    args = t[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    if torch.is_floating_point(t):
+        embedding = embedding.to(t)
+    return embedding
+
+class MLPEmbedder(nn.Module):
+    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.silu = nn.SiLU()
+        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
+
+    def forward(self, x: Tensor):
+        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
+
+
+class QKNorm(torch.nn.Module):
+    def __init__(self, dim: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
+        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
+
+    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
+        q = self.query_norm(q)
+        k = self.key_norm(k)
+        return q.to(v), k.to(v)
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+
+        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
+        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
+        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
+
+
+@dataclass
+class ModulationOut:
+    shift: Tensor
+    scale: Tensor
+    gate: Tensor
+
+
+class Modulation(nn.Module):
+    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.is_double = double
+        self.multiplier = 6 if double else 3
+        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
+
+    def forward(self, vec: Tensor) -> tuple:
+        out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
+
+        return (
+            ModulationOut(*out[:3]),
+            ModulationOut(*out[3:]) if self.is_double else None,
+        )
+
+
+
+class Approximator(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
+        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
+        self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
+
+    @property
+    def device(self):
+        # Get the device of the module (assumes all parameters are on the same device)
+        return next(self.parameters()).device
+    
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.in_proj(x)
+
+        for layer, norms in zip(self.layers, self.norms):
+            x = x + layer(norms(x))
+
+        x = self.out_proj(x)
+
+        return x
+
+
+class DoubleStreamBlock(nn.Module):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
+        super().__init__()
+
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+
+        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.img_mlp = nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+            nn.GELU(approximate="tanh"),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+        )
+
+        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+
+        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.txt_mlp = nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+            nn.GELU(approximate="tanh"),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+        )
+        self.flipped_img_txt = flipped_img_txt
+
+    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
+        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
+
+        # prepare image for attention
+        img_modulated = self.img_norm1(img)
+        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_qkv = self.img_attn.qkv(img_modulated)
+        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
+
+        # prepare txt for attention
+        txt_modulated = self.txt_norm1(txt)
+        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_qkv = self.txt_attn.qkv(txt_modulated)
+        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
+
+        # run actual attention
+        attn = attention(torch.cat((txt_q, img_q), dim=2),
+                         torch.cat((txt_k, img_k), dim=2),
+                         torch.cat((txt_v, img_v), dim=2),
+                         pe=pe, mask=attn_mask)
+
+        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+
+        # calculate the img blocks
+        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+
+        # calculate the txt blocks
+        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+
+        if txt.dtype == torch.float16:
+            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
+
+        return img, txt
+
+
+class SingleStreamBlock(nn.Module):
+    """
+    A DiT block with parallel linear layers as described in
+    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qk_scale: float = None,
+        dtype=None,
+        device=None,
+        operations=None
+    ):
+        super().__init__()
+        self.hidden_dim = hidden_size
+        self.num_heads = num_heads
+        head_dim = hidden_size // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        # qkv and mlp_in
+        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
+        # proj and mlp_out
+        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
+
+        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
+
+        self.hidden_size = hidden_size
+        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+
+        self.mlp_act = nn.GELU(approximate="tanh")
+
+    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
+        mod = vec
+        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+
+        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k = self.norm(q, k, v)
+
+        # compute attention
+        attn = attention(q, k, v, pe=pe, mask=attn_mask)
+        # compute activation in mlp stream, cat again and run second linear layer
+        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        x += mod.gate * output
+        if x.dtype == torch.float16:
+            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+        return x
+
+
+class LastLayer(nn.Module):
+    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
+
+    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
+        shift, scale = vec
+        shift = shift.squeeze(1)
+        scale = scale.squeeze(1)
+        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+        x = self.linear(x)
+        return x
diff --git a/comfy/ldm/chroma/math.py b/comfy/ldm/chroma/math.py
new file mode 100644
index 00000000..36b67931
--- /dev/null
+++ b/comfy/ldm/chroma/math.py
@@ -0,0 +1,44 @@
+import torch
+from einops import rearrange
+from torch import Tensor
+
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.model_management
+
+
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
+    q_shape = q.shape
+    k_shape = k.shape
+
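+    # pe holds the 2x2 rotary rotation matrices built by rope(); the reshapes below split
+    # q/k into (even, odd) channel pairs and rotate each pair by its positional angle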
+    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
+    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
+    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
+    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
+
+    heads = q.shape[1]
+    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
+    return x
+
+
+def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
+    assert dim % 2 == 0
+    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
+        device = torch.device("cpu")
+    else:
+        device = pos.device
+
+    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
+    omega = 1.0 / (theta**scale)
+    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
+    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+    return out.to(dtype=torch.float32, device=pos.device)
+
+
+def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
+    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
new file mode 100644
index 00000000..190e1900
--- /dev/null
+++ b/comfy/ldm/chroma/model.py
@@ -0,0 +1,369 @@
+#Original code can be found on: https://github.com/black-forest-labs/flux
+
+from dataclasses import dataclass
+
+import torch
+from torch import Tensor, nn
+from einops import rearrange, repeat
+import comfy.ldm.common_dit
+from .common import pad_to_patch_size, rms_norm
+
+from .layers import (
+    DoubleStreamBlock,
+    EmbedND,
+    LastLayer,
+    MLPEmbedder,
+    SingleStreamBlock,
+    timestep_embedding,
+    Approximator,
+    ModulationOut
+)
+
+
+@dataclass
+class ChromaParams:
+    in_channels: int
+    out_channels: int
+    context_in_dim: int
+    hidden_size: int
+    mlp_ratio: float
+    num_heads: int
+    depth: int
+    depth_single_blocks: int
+    axes_dim: list
+    theta: int
+    patch_size: int
+    qkv_bias: bool
+    in_dim: int
+    out_dim: int
+    hidden_dim: int
+    n_layers: int
+
+
+class Chroma(nn.Module):
+    """
+    Transformer model for flow matching on sequences.
+    """
+
+    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
+        super().__init__()
+        self.dtype = dtype
+        params = ChromaParams(**kwargs)
+        self.params = params
+        self.patch_size = params.patch_size
+        self.in_channels = params.in_channels
+        self.out_channels = params.out_channels
+        if params.hidden_size % params.num_heads != 0:
+            raise ValueError(
+                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+            )
+        pe_dim = params.hidden_size // params.num_heads
+        if sum(params.axes_dim) != pe_dim:
+            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+        self.hidden_size = params.hidden_size
+        self.num_heads = params.num_heads
+        self.in_dim = params.in_dim
+        self.out_dim = params.out_dim
+        self.hidden_dim = params.hidden_dim
+        self.n_layers = params.n_layers
+        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
+        self.time_in = MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
+        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
+        # approximator MLP that generates the per-block modulation vectors (distilled guidance)
+        self.distilled_guidance_layer = Approximator(
+                    in_dim=self.in_dim,
+                    hidden_dim=self.hidden_dim,
+                    out_dim=self.out_dim,
+                    n_layers=self.n_layers,
+                    dtype=dtype, device=device, operations=operations
+                )
+
+        self.guidance_in = (
+            MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if self.distilled_guidance_layer else nn.Identity()
+        )
+
+        self.double_blocks = nn.ModuleList(
+            [
+                DoubleStreamBlock(
+                    self.hidden_size,
+                    self.num_heads,
+                    mlp_ratio=params.mlp_ratio,
+                    qkv_bias=params.qkv_bias,
+                    dtype=dtype, device=device, operations=operations
+                )
+                for _ in range(params.depth)
+            ]
+        )
+
+        self.single_blocks = nn.ModuleList(
+            [
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+                for _ in range(params.depth_single_blocks)
+            ]
+        )
+
+        if final_layer:
+            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
+
+        self.skip_mmdit = []
+        self.skip_dit = []
+        self.lite = False
+    @staticmethod
+    def distribute_modulations(tensor: torch.Tensor, single_block_count: int = 38, double_blocks_count: int = 19):
+        """
+        Distributes slices of the tensor into the block_dict as ModulationOut objects.
+
+        Args:
+            tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim].
+        """
+        batch_size, vectors, dim = tensor.shape
+
+        block_dict = {}
+
+        # HARD CODED VALUES! lookup table for the generated vectors
+        # Add 38 single mod blocks
+        for i in range(single_block_count):
+            key = f"single_blocks.{i}.modulation.lin"
+            block_dict[key] = None
+
+        # Add 19 image double blocks
+        for i in range(double_blocks_count):
+            key = f"double_blocks.{i}.img_mod.lin"
+            block_dict[key] = None
+
+        # Add 19 text double blocks
+        for i in range(double_blocks_count):
+            key = f"double_blocks.{i}.txt_mod.lin"
+            block_dict[key] = None
+
+        # Add the final layer
+        block_dict["final_layer.adaLN_modulation.1"] = None
+        # # 6.2b version
+        # block_dict["lite_double_blocks.4.img_mod.lin"] = None
+        # block_dict["lite_double_blocks.4.txt_mod.lin"] = None
+
+
+        idx = 0  # Index to keep track of the vector slices
+
+        for key in block_dict.keys():
+            if "single_blocks" in key:
+                # Single block: 1 ModulationOut
+                block_dict[key] = ModulationOut(
+                    shift=tensor[:, idx:idx+1, :],
+                    scale=tensor[:, idx+1:idx+2, :],
+                    gate=tensor[:, idx+2:idx+3, :]
+                )
+                idx += 3  # Advance by 3 vectors
+
+            elif "img_mod" in key:
+                # Double block: List of 2 ModulationOut
+                double_block = []
+                for _ in range(2):  # Create 2 ModulationOut objects
+                    double_block.append(
+                        ModulationOut(
+                            shift=tensor[:, idx:idx+1, :],
+                            scale=tensor[:, idx+1:idx+2, :],
+                            gate=tensor[:, idx+2:idx+3, :]
+                        )
+                    )
+                    idx += 3  # Advance by 3 vectors per ModulationOut
+                block_dict[key] = double_block
+
+            elif "txt_mod" in key:
+                # Double block: List of 2 ModulationOut
+                double_block = []
+                for _ in range(2):  # Create 2 ModulationOut objects
+                    double_block.append(
+                        ModulationOut(
+                            shift=tensor[:, idx:idx+1, :],
+                            scale=tensor[:, idx+1:idx+2, :],
+                            gate=tensor[:, idx+2:idx+3, :]
+                        )
+                    )
+                    idx += 3  # Advance by 3 vectors per ModulationOut
+                block_dict[key] = double_block
+
+            elif "final_layer" in key:
+                # Final layer: shift and scale vectors (2 slices)
+                block_dict[key] = [
+                    tensor[:, idx:idx+1, :],
+                    tensor[:, idx+1:idx+2, :],
+                ]
+                idx += 2  # Advance by 2 vectors
+
+            # elif "lite_double_blocks.4.img_mod" in key:
+            #     # Double block: List of 2 ModulationOut
+            #     double_block = []
+            #     for _ in range(2):  # Create 2 ModulationOut objects
+            #         double_block.append(
+            #             ModulationOut(
+            #                 shift=tensor[:, idx:idx+1, :],
+            #                 scale=tensor[:, idx+1:idx+2, :],
+            #                 gate=tensor[:, idx+2:idx+3, :]
+            #             )
+            #         )
+            #         idx += 3  # Advance by 3 vectors per ModulationOut
+            #     block_dict[key] = double_block
+
+            # elif "lite_double_blocks.4.txt_mod" in key:
+            #     # Double block: List of 2 ModulationOut
+            #     double_block = []
+            #     for _ in range(2):  # Create 2 ModulationOut objects
+            #         double_block.append(
+            #             ModulationOut(
+            #                 shift=tensor[:, idx:idx+1, :],
+            #                 scale=tensor[:, idx+1:idx+2, :],
+            #                 gate=tensor[:, idx+2:idx+3, :]
+            #             )
+            #         )
+            #         idx += 3  # Advance by 3 vectors per ModulationOut
+            #     block_dict[key] = double_block
+
+        return block_dict
+
+    def forward_orig(
+        self,
+        img: Tensor,
+        img_ids: Tensor,
+        txt: Tensor,
+        txt_ids: Tensor,
+        timesteps: Tensor,
+        guidance: Tensor = None,
+        control = None,
+        transformer_options={},
+        attn_mask: Tensor = None,
+    ) -> Tensor:
+        patches_replace = transformer_options.get("patches_replace", {})
+        if img.ndim != 3 or txt.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+        # running on sequences img
+        img = self.img_in(img)
+
+        # distilled vector guidance
+        mod_index_length = 344
+        distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
+        # guidance = guidance *
+        distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
+
+        # get all modulation indices
+        modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype)
+        # broadcast the modulation index so each batch element gets the full set of indices
+        modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
+        # broadcast timestep and guidance along the index dimension as well
+        timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.device, img.dtype)
+        # only then can the two be concatenated together
+        input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)
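+        # input_vec: [batch, mod_index_length, 64] = 16 timestep dims + 16 guidance dims + 32 index dims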
+
+        mod_vectors = self.distilled_guidance_layer(input_vec)
+
+        mod_vectors_dict = self.distribute_modulations(mod_vectors, 38, 19)
+
+        txt = self.txt_in(txt)
+
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        pe = self.pe_embedder(ids)
+
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.double_blocks):
+            if i not in self.skip_mmdit:
+                guidance_index = i
+                # if lite we change block 4 guidance with lite guidance
+                # and offset the guidance by 11 blocks after block 4
+                # if self.lite and i == 4:
+                #     img_mod = mod_vectors_dict[f"lite_double_blocks.4.img_mod.lin"]
+                #     txt_mod = mod_vectors_dict[f"lite_double_blocks.4.txt_mod.lin"]
+                # elif self.lite and i > 4:
+                #     guidance_index = i + 11 
+                #     img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
+                #     txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
+                # else:
+                img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
+                txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
+                double_mod = [img_mod, txt_mod]
+
+                if ("double_block", i) in blocks_replace:
+                    def block_wrap(args):
+                        out = {}
+                        out["img"], out["txt"] = block(img=args["img"],
+                                                       txt=args["txt"],
+                                                       vec=args["vec"],
+                                                       pe=args["pe"],
+                                                       attn_mask=args.get("attn_mask"))
+                        return out
+
+                    out = blocks_replace[("double_block", i)]({"img": img,
+                                                               "txt": txt,
+                                                               "vec": double_mod,
+                                                               "pe": pe,
+                                                               "attn_mask": attn_mask},
+                                                              {"original_block": block_wrap})
+                    txt = out["txt"]
+                    img = out["img"]
+                else:
+                    img, txt = block(img=img,
+                                     txt=txt,
+                                     vec=double_mod,
+                                     pe=pe,
+                                     attn_mask=attn_mask)
+
+                if control is not None: # Controlnet
+                    control_i = control.get("input")
+                    if i < len(control_i):
+                        add = control_i[i]
+                        if add is not None:
+                            img += add
+
+        img = torch.cat((txt, img), 1)
+
+        for i, block in enumerate(self.single_blocks):
+            if i not in self.skip_dit:
+                single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
+                if ("single_block", i) in blocks_replace:
+                    def block_wrap(args):
+                        out = {}
+                        out["img"] = block(args["img"],
+                                           vec=args["vec"],
+                                           pe=args["pe"],
+                                           attn_mask=args.get("attn_mask"))
+                        return out
+
+                    out = blocks_replace[("single_block", i)]({"img": img,
+                                                               "vec": single_mod,
+                                                               "pe": pe,
+                                                               "attn_mask": attn_mask},
+                                                              {"original_block": block_wrap})
+                    img = out["img"]
+                else:
+                    img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
+
+                if control is not None: # Controlnet
+                    control_o = control.get("output")
+                    if i < len(control_o):
+                        add = control_o[i]
+                        if add is not None:
+                            img[:, txt.shape[1] :, ...] += add
+
+        img = img[:, txt.shape[1] :, ...]
+        final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
+        img = self.final_layer(img, vec=final_mod)  # (N, T, patch_size ** 2 * out_channels)
+        return img
+
+    def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
+        bs, c, h, w = x.shape
+        patch_size = 2
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+
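+        # patchify: every 2x2 latent patch becomes one sequence token of dimension c*4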
+        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+
+        h_len = ((h + (patch_size // 2)) // patch_size)
+        w_len = ((w + (patch_size // 2)) // patch_size)
+        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
+        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
diff --git a/comfy/model_base.py b/comfy/model_base.py
index eec70d5d..05e242b8 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -37,6 +37,7 @@ import comfy.ldm.cosmos.model
 import comfy.ldm.lumina.model
 import comfy.ldm.wan.model
 import comfy.ldm.hunyuan3d.model
+import comfy.ldm.chroma.model
 
 import comfy.model_management
 import comfy.patcher_extension
@@ -1046,3 +1047,57 @@ class Hunyuan3Dv2(BaseModel):
         if guidance is not None:
             out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
         return out
+
+class Chroma(BaseModel):
+    chroma_model_mode=False
+
+    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
+
+    def concat_cond(self, **kwargs):
+        try:
+            #Handle Flux control loras dynamically changing the img_in weight.
+            num_channels = self.diffusion_model.img_in.weight.shape[1]
+        except:
+            #Some cases like tensorrt might not have the weights accessible
+            num_channels = self.model_config.unet_config["in_channels"]
+
+        out_channels = self.model_config.unet_config["out_channels"]
+
+        if num_channels <= out_channels:
+            return None
+
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if image is None:
+            image = torch.zeros_like(noise)
+
+        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+        image = self.process_latent_in(image)
+        if num_channels <= out_channels * 2:
+            return image
+
+        #inpaint model
+        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+        if mask is None:
+            mask = torch.ones_like(noise)[:, :1]
+
+        mask = torch.mean(mask, dim=1, keepdim=True)
+        mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
+        mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
+        mask = utils.resize_to_batch_size(mask, noise.shape[0])
+        return torch.cat((image, mask), dim=1)
+
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        # upscale the attention mask, since now we
+        guidance = 0.0
+        out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor((guidance,)))
+        return out
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 4217f583..a3d36648 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -154,6 +154,32 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["guidance_embed"] = len(guidance_keys) > 0
         return dit_config
 
+    if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
+        dit_config = {}
+        dit_config["image_model"] = "chroma"
+        dit_config["depth"] = 48
+        dit_config["in_channels"] = 64
+        patch_size = 2
+        dit_config["patch_size"] = patch_size
+        in_key = "{}img_in.weight".format(key_prefix)
+        if in_key in state_dict_keys:
+            dit_config["in_channels"] = state_dict[in_key].shape[1]
+        dit_config["out_channels"] = 64
+        dit_config["context_in_dim"] = 4096
+        dit_config["hidden_size"] = 3072
+        dit_config["mlp_ratio"] = 4.0
+        dit_config["num_heads"] = 24
+        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
+        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
+        dit_config["axes_dim"] = [16, 56, 56]
+        dit_config["theta"] = 10000
+        dit_config["qkv_bias"] = True
+        dit_config["in_dim"] = 64
+        dit_config["out_dim"] = 3072
+        dit_config["hidden_dim"] = 5120
+        dit_config["n_layers"] = 5
+        return dit_config
+
     if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
         dit_config = {}
         dit_config["image_model"] = "flux"
diff --git a/comfy/sd.py b/comfy/sd.py
index d096f496..895f4a1a 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -41,6 +41,7 @@ import comfy.text_encoders.hunyuan_video
 import comfy.text_encoders.cosmos
 import comfy.text_encoders.lumina2
 import comfy.text_encoders.wan
+import comfy.text_encoders.chroma
 
 import comfy.model_patcher
 import comfy.lora
@@ -700,6 +701,7 @@ class CLIPType(Enum):
     COSMOS = 11
     LUMINA2 = 12
     WAN = 13
+    CHROMA = 14
 
 
 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -808,6 +810,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
                 clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
                 tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+            elif clip_type == CLIPType.CHROMA:
+                clip_target.clip = comfy.text_encoders.chroma.chroma_te(**t5xxl_detect(clip_data))
+                clip_target.tokenizer = comfy.text_encoders.chroma.ChromaT5Tokenizer
             else: #CLIPType.MOCHI
                 clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index fad00d35..1520b58c 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1013,6 +1013,25 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
 
     latent_format = latent_formats.Hunyuan3Dv2mini
 
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
+class Chroma(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "chroma",
+    }
+
+    unet_extra_config = {
+    }
+
+    sampling_settings = {
+        "multiplier": 1.0,
+        "shift": 1.0,
+    }
+    latent_format = comfy.latent_formats.Flux
+    memory_usage_factor = 2.8
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device)
+        return out
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma]
 
 models += [SVD_img2vid]
diff --git a/comfy/text_encoders/chroma.py b/comfy/text_encoders/chroma.py
new file mode 100644
index 00000000..cf10bffc
--- /dev/null
+++ b/comfy/text_encoders/chroma.py
@@ -0,0 +1,44 @@
+from comfy import sd1_clip
+import comfy.text_encoders.t5
+import os
+from transformers import T5TokenizerFast
+
+
+class T5XXLModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
+        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
+        t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
+        if t5xxl_scaled_fp8 is not None:
+            model_options = model_options.copy()
+            model_options["scaled_fp8"] = t5xxl_scaled_fp8
+        attention_mask = True
+
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+
+class ChromaT5XXL(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
+
+
+class T5XXLTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
+
+
+class ChromaT5Tokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
+
+
+def chroma_te(dtype_t5=None, t5xxl_scaled_fp8=None):
+    class ChromaTEModel_(ChromaT5XXL):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+            if dtype is None:
+                dtype = dtype_t5
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return ChromaTEModel_
diff --git a/nodes.py b/nodes.py
index 27ef743b..f506c8a9 100644
--- a/nodes.py
+++ b/nodes.py
@@ -915,7 +915,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "chroma"], ),
                               },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),
@@ -946,6 +946,8 @@ class CLIPLoader:
             clip_type = comfy.sd.CLIPType.LUMINA2
         elif type == "wan":
             clip_type = comfy.sd.CLIPType.WAN
+        elif type == "chroma":
+            clip_type = comfy.sd.CLIPType.CHROMA
         else:
             clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
 

From 79f460150e253373c0d5cf5465e70ac0cba42d98 Mon Sep 17 00:00:00 2001
From: Silver <65376327+silveroxides@users.noreply.github.com>
Date: Sat, 22 Mar 2025 16:16:23 +0100
Subject: [PATCH 02/17] Remove trailing whitespace

---
 comfy/ldm/chroma/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index 190e1900..624f65f2 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -275,7 +275,7 @@ class Chroma(nn.Module):
                 #     img_mod = mod_vectors_dict[f"lite_double_blocks.4.img_mod.lin"]
                 #     txt_mod = mod_vectors_dict[f"lite_double_blocks.4.txt_mod.lin"]
                 # elif self.lite and i > 4:
-                #     guidance_index = i + 11 
+                #     guidance_index = i + 11
                 #     img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
                 #     txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
                 # else:

From 2710f77218af804449b26702abbfdff9550c293c Mon Sep 17 00:00:00 2001
From: Silver <65376327+silveroxides@users.noreply.github.com>
Date: Sat, 22 Mar 2025 16:18:17 +0100
Subject: [PATCH 03/17] trim more trailing whitespace..oops

---
 comfy/ldm/chroma/layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
index 606d9688..8ad3c72d 100644
--- a/comfy/ldm/chroma/layers.py
+++ b/comfy/ldm/chroma/layers.py
@@ -126,7 +126,7 @@ class Approximator(nn.Module):
     def device(self):
         # Get the device of the module (assumes all parameters are on the same device)
         return next(self.parameters()).device
-    
+
     def forward(self, x: Tensor) -> Tensor:
         x = self.in_proj(x)
 

From ca73500269b7e7bd5ae2e301386f110bb608f215 Mon Sep 17 00:00:00 2001
From: Silver <65376327+silveroxides@users.noreply.github.com>
Date: Sat, 22 Mar 2025 16:20:08 +0100
Subject: [PATCH 04/17] remove unused imports

---
 comfy/ldm/chroma/model.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index 624f65f2..b3b03dcd 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -6,7 +6,6 @@ import torch
 from torch import Tensor, nn
 from einops import rearrange, repeat
 import comfy.ldm.common_dit
-from .common import pad_to_patch_size, rms_norm
 
 from .layers import (
     DoubleStreamBlock,

From fda35f37e5b55f616c13167e481183bf64237af7 Mon Sep 17 00:00:00 2001
From: Silver <65376327+silveroxides@users.noreply.github.com>
Date: Sat, 22 Mar 2025 16:49:16 +0100
Subject: [PATCH 05/17] Add supported_inference_dtypes

---
 comfy/supported_models.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 1520b58c..da5f3abc 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1027,6 +1027,7 @@ class Chroma(supported_models_base.BASE):
     }
     latent_format = comfy.latent_formats.Flux
     memory_usage_factor = 2.8
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device)

From 9fa34e72163af560c76850ba38ee54229a9ddf55 Mon Sep 17 00:00:00 2001
From: Silver <65376327+silveroxides@users.noreply.github.com>
Date: Sun, 23 Mar 2025 06:14:50 +0100
Subject: [PATCH 06/17] Set min_length to 0 and remove attention_mask=True

---
 comfy/text_encoders/chroma.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/comfy/text_encoders/chroma.py b/comfy/text_encoders/chroma.py
index cf10bffc..30b5c11c 100644
--- a/comfy/text_encoders/chroma.py
+++ b/comfy/text_encoders/chroma.py
@@ -11,7 +11,6 @@ class T5XXLModel(sd1_clip.SDClipModel):
         if t5xxl_scaled_fp8 is not None:
             model_options = model_options.copy()
             model_options["scaled_fp8"] = t5xxl_scaled_fp8
-        attention_mask = True
 
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
 
@@ -24,7 +23,7 @@ class ChromaT5XXL(sd1_clip.SD1ClipModel):
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=0)
 
 
 class ChromaT5Tokenizer(sd1_clip.SD1Tokenizer):

From fe6e1fa44f792750705e3c42fb0a4d5ed0f5c5a2 Mon Sep 17 00:00:00 2001
From: Silver <65376327+silveroxides@users.noreply.github.com>
Date: Tue, 25 Mar 2025 07:21:08 +0100
Subject: [PATCH 07/17] Set min_length to 1

---
 comfy/text_encoders/chroma.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/text_encoders/chroma.py b/comfy/text_encoders/chroma.py
index 30b5c11c..5ee5d57b 100644
--- a/comfy/text_encoders/chroma.py
+++ b/comfy/text_encoders/chroma.py
@@ -23,7 +23,7 @@ class ChromaT5XXL(sd1_clip.SD1ClipModel):
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=0)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1)
 
 
 class ChromaT5Tokenizer(sd1_clip.SD1Tokenizer):

From f04b502ab616318934de6b3a0d95b65974538978 Mon Sep 17 00:00:00 2001
From: silveroxides <ishimarukaito@gmail.com>
Date: Tue, 25 Mar 2025 21:38:11 +0100
Subject: [PATCH 08/17] get_modulations added from blepping and minor changes

---
 comfy/ldm/chroma/model.py | 171 ++++++++++----------------------------
 comfy/model_base.py       |  11 ++-
 comfy/supported_models.py |  10 ++-
 3 files changed, 60 insertions(+), 132 deletions(-)

diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index b3b03dcd..a956ad05 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -39,6 +39,16 @@ class ChromaParams:
     n_layers: int
 
 
+class ChromaModulationOut(ModulationOut):
+    @classmethod
+    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
+        return cls(
+            shift=tensor[:, offset : offset + 1, :],
+            scale=tensor[:, offset + 1 : offset + 2, :],
+            gate=tensor[:, offset + 2 : offset + 3, :],
+        )
+
+
 class Chroma(nn.Module):
     """
     Transformer model for flow matching on sequences.
@@ -108,118 +118,34 @@ class Chroma(nn.Module):
         self.skip_mmdit = []
         self.skip_dit = []
         self.lite = False
-    @staticmethod
-    def distribute_modulations(tensor: torch.Tensor, single_block_count: int = 38, double_blocks_count: int = 19):
-        """
-        Distributes slices of the tensor into the block_dict as ModulationOut objects.
 
-        Args:
-            tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim].
-        """
-        batch_size, vectors, dim = tensor.shape
+    def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
+        # This function slices up the modulations tensor which has the following layout:
+        #   single     : num_single_blocks * 3 elements
+        #   double_img : num_double_blocks * 6 elements
+        #   double_txt : num_double_blocks * 6 elements
+        #   final      : 2 elements
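+        #   e.g. with 38 single and 19 double blocks: 3*38 + 6*19 + 6*19 + 2 = 344 vectors,
+        #   which matches mod_index_length in forward_orig.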
+        if block_type == "final":
+            return (tensor[:, -2:-1, :], tensor[:, -1:, :])
+        single_block_count = self.params.depth_single_blocks
+        double_block_count = self.params.depth
+        offset = 3 * idx
+        if block_type == "single":
+            return ChromaModulationOut.from_offset(tensor, offset)
+        # Double block modulations are 6 elements so we double 3 * idx.
+        offset *= 2
+        if block_type in {"double_img", "double_txt"}:
+            # Advance past the single block modulations.
+            offset += 3 * single_block_count
+            if block_type == "double_txt":
+                # Advance past the double block img modulations.
+                offset += 6 * double_block_count
+            return (
+                ChromaModulationOut.from_offset(tensor, offset),
+                ChromaModulationOut.from_offset(tensor, offset + 3),
+            )
+        raise ValueError("Bad block_type")
 
-        block_dict = {}
-
-        # HARD CODED VALUES! lookup table for the generated vectors
-        # Add 38 single mod blocks
-        for i in range(single_block_count):
-            key = f"single_blocks.{i}.modulation.lin"
-            block_dict[key] = None
-
-        # Add 19 image double blocks
-        for i in range(double_blocks_count):
-            key = f"double_blocks.{i}.img_mod.lin"
-            block_dict[key] = None
-
-        # Add 19 text double blocks
-        for i in range(double_blocks_count):
-            key = f"double_blocks.{i}.txt_mod.lin"
-            block_dict[key] = None
-
-        # Add the final layer
-        block_dict["final_layer.adaLN_modulation.1"] = None
-        # # 6.2b version
-        # block_dict["lite_double_blocks.4.img_mod.lin"] = None
-        # block_dict["lite_double_blocks.4.txt_mod.lin"] = None
-
-
-        idx = 0  # Index to keep track of the vector slices
-
-        for key in block_dict.keys():
-            if "single_blocks" in key:
-                # Single block: 1 ModulationOut
-                block_dict[key] = ModulationOut(
-                    shift=tensor[:, idx:idx+1, :],
-                    scale=tensor[:, idx+1:idx+2, :],
-                    gate=tensor[:, idx+2:idx+3, :]
-                )
-                idx += 3  # Advance by 3 vectors
-
-            elif "img_mod" in key:
-                # Double block: List of 2 ModulationOut
-                double_block = []
-                for _ in range(2):  # Create 2 ModulationOut objects
-                    double_block.append(
-                        ModulationOut(
-                            shift=tensor[:, idx:idx+1, :],
-                            scale=tensor[:, idx+1:idx+2, :],
-                            gate=tensor[:, idx+2:idx+3, :]
-                        )
-                    )
-                    idx += 3  # Advance by 3 vectors per ModulationOut
-                block_dict[key] = double_block
-
-            elif "txt_mod" in key:
-                # Double block: List of 2 ModulationOut
-                double_block = []
-                for _ in range(2):  # Create 2 ModulationOut objects
-                    double_block.append(
-                        ModulationOut(
-                            shift=tensor[:, idx:idx+1, :],
-                            scale=tensor[:, idx+1:idx+2, :],
-                            gate=tensor[:, idx+2:idx+3, :]
-                        )
-                    )
-                    idx += 3  # Advance by 3 vectors per ModulationOut
-                block_dict[key] = double_block
-
-            elif "final_layer" in key:
-                # Final layer: 1 ModulationOut
-                block_dict[key] = [
-                    tensor[:, idx:idx+1, :],
-                    tensor[:, idx+1:idx+2, :],
-                ]
-                idx += 2  # Advance by 2 vectors
-
-            # elif "lite_double_blocks.4.img_mod" in key:
-            #     # Double block: List of 2 ModulationOut
-            #     double_block = []
-            #     for _ in range(2):  # Create 2 ModulationOut objects
-            #         double_block.append(
-            #             ModulationOut(
-            #                 shift=tensor[:, idx:idx+1, :],
-            #                 scale=tensor[:, idx+1:idx+2, :],
-            #                 gate=tensor[:, idx+2:idx+3, :]
-            #             )
-            #         )
-            #         idx += 3  # Advance by 3 vectors per ModulationOut
-            #     block_dict[key] = double_block
-
-            # elif "lite_double_blocks.4.txt_mod" in key:
-            #     # Double block: List of 2 ModulationOut
-            #     double_block = []
-            #     for _ in range(2):  # Create 2 ModulationOut objects
-            #         double_block.append(
-            #             ModulationOut(
-            #                 shift=tensor[:, idx:idx+1, :],
-            #                 scale=tensor[:, idx+1:idx+2, :],
-            #                 gate=tensor[:, idx+2:idx+3, :]
-            #             )
-            #         )
-            #         idx += 3  # Advance by 3 vectors per ModulationOut
-            #     block_dict[key] = double_block
-
-        return block_dict
 
     def forward_orig(
         self,
@@ -257,8 +183,6 @@ class Chroma(nn.Module):
 
         mod_vectors = self.distilled_guidance_layer(input_vec)
 
-        mod_vectors_dict = self.distribute_modulations(mod_vectors, 38, 19)
-
         txt = self.txt_in(txt)
 
         ids = torch.cat((txt_ids, img_ids), dim=1)
@@ -267,21 +191,10 @@ class Chroma(nn.Module):
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.double_blocks):
             if i not in self.skip_mmdit:
-                guidance_index = i
-                # if lite we change block 4 guidance with lite guidance
-                # and offset the guidance by 11 blocks after block 4
-                # if self.lite and i == 4:
-                #     img_mod = mod_vectors_dict[f"lite_double_blocks.4.img_mod.lin"]
-                #     txt_mod = mod_vectors_dict[f"lite_double_blocks.4.txt_mod.lin"]
-                # elif self.lite and i > 4:
-                #     guidance_index = i + 11
-                #     img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
-                #     txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
-                # else:
-                img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
-                txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
-                double_mod = [img_mod, txt_mod]
-
+                double_mod = (
+                    self.get_modulations(mod_vectors, "double_img", idx=i),
+                    self.get_modulations(mod_vectors, "double_txt", idx=i),
+                )
                 if ("double_block", i) in blocks_replace:
                     def block_wrap(args):
                         out = {}
@@ -318,7 +231,7 @@ class Chroma(nn.Module):
 
         for i, block in enumerate(self.single_blocks):
             if i not in self.skip_dit:
-                single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
+                single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                 if ("single_block", i) in blocks_replace:
                     def block_wrap(args):
                         out = {}
@@ -345,7 +258,7 @@ class Chroma(nn.Module):
                             img[:, txt.shape[1] :, ...] += add
 
         img = img[:, txt.shape[1] :, ...]
-        final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
+        final_mod = self.get_modulations(mod_vectors, "final")
         img = self.final_layer(img, vec=final_mod)  # (N, T, patch_size ** 2 * out_channels)
         return img
 
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 05e242b8..13349f71 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1049,8 +1049,6 @@ class Hunyuan3Dv2(BaseModel):
         return out
 
 class Chroma(BaseModel):
-    chroma_model_mode=False
-
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
 
@@ -1098,6 +1096,15 @@ class Chroma(BaseModel):
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         # upscale the attention mask, since now we
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            shape = kwargs["noise"].shape
+            mask_ref_size = kwargs["attention_mask_img_shape"]
+            # the model pads to the patch size and then divides,
+            # which amounts to dividing and rounding up
+            (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
+            attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
+            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
         guidance = 0.0
         out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor((guidance,)))
         return out
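
The attention-mask handling above maps the mask from its reference image grid onto the latent token grid: the latent is padded to the patch size and then divided, i.e. a ceiling division. A minimal sketch of that arithmetic follows; the layout handled by utils.upscale_dit_mask is not reproduced, only the ceil-divide-then-resize idea is shown with a plain 2-D mask.

import math
import torch
import torch.nn.functional as F

def token_grid(latent_h, latent_w, patch_size=2):
    # pad-to-patch-size followed by division is a ceiling division
    return math.ceil(latent_h / patch_size), math.ceil(latent_w / patch_size)

# e.g. a 1024x1024 image -> 128x128 latent -> 64x64 tokens with patch_size 2
h_tok, w_tok = token_grid(128, 128)

# resize a mask that was built for a smaller reference grid to the token grid
mask = torch.ones(1, 1, 32, 32)
mask = F.interpolate(mask, size=(h_tok, w_tok), mode="nearest")
print(mask.shape)  # torch.Size([1, 1, 64, 64])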
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index da5f3abc..c113a023 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1025,14 +1025,22 @@ class Chroma(supported_models_base.BASE):
         "multiplier": 1.0,
         "shift": 1.0,
     }
+
     latent_format = comfy.latent_formats.Flux
-    memory_usage_factor = 2.8
+
+    memory_usage_factor = 1.8
+
     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device)
         return out
 
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.chroma.ChromaTokenizer, comfy.text_encoders.chroma.chroma_te(**t5_detect))
+
 models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma]
 
 models += [SVD_img2vid]
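
The forward_orig changes above replace the removed distribute_modulations dictionary with targeted get_modulations lookups into the single (batch, n_vectors, hidden) tensor produced by the distilled guidance layer. The helper itself is defined elsewhere in the series; the sketch below only illustrates the bookkeeping, under the assumption that each single block consumes one (shift, scale, gate) triple, each double block two triples per stream, and the final layer two bare vectors, packed in that order. The real helper may pack the vectors differently.

import torch

def slice3(tensor, offset):
    # one (shift, scale, gate) triple, the same slicing the ModulationOut fields use
    return (tensor[:, offset:offset + 1],
            tensor[:, offset + 1:offset + 2],
            tensor[:, offset + 2:offset + 3])

def get_modulations(tensor, kind, idx=0, n_single=38):
    # assumed packing order: single blocks, then double blocks, then the final layer
    if kind == "single":
        return slice3(tensor, 3 * idx)
    base = 3 * n_single
    if kind == "double_img":   # two triples per stream per double block
        off = base + 12 * idx
        return (slice3(tensor, off), slice3(tensor, off + 3))
    if kind == "double_txt":
        off = base + 12 * idx + 6
        return (slice3(tensor, off), slice3(tensor, off + 3))
    if kind == "final":        # the final layer only needs shift and scale
        return (tensor[:, -2:-1], tensor[:, -1:])
    raise ValueError(kind)

n_vectors = 38 * 3 + 19 * 12 + 2   # counts taken from the removed bookkeeping
mod_vectors = torch.randn(1, n_vectors, 3072)   # hidden size illustrative
img_mod = get_modulations(mod_vectors, "double_img", idx=0)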

From b7da8e2bc103e6cf32ba056e1f3fb94bb7f97508 Mon Sep 17 00:00:00 2001
From: Silver <65376327+silveroxides@users.noreply.github.com>
Date: Thu, 27 Mar 2025 03:08:24 +0100
Subject: [PATCH 09/17] Add lora conversion if statement in lora.py

---
 comfy/lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/lora.py b/comfy/lora.py
index bc9f3022..f466a5ae 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -378,7 +378,7 @@ def model_lora_keys_unet(model, key_map={}):
                 key_lora = k[len("diffusion_model."):-len(".weight")]
                 key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format
 
-    if isinstance(model, comfy.model_base.Flux): #Diffusers lora Flux
+    if isinstance(model, comfy.model_base.Flux) or isinstance(model, comfy.model_base.Chroma): #Diffusers lora Flux or Chroma
         diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
         for k in diffusers_keys:
             if k.endswith(".weight"):
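
Extending the Flux branch to Chroma means Diffusers-format Chroma LoRAs get the same key translation as Flux ones: flux_to_diffusers builds a diffusers-name to comfy-name table, and each ".weight" entry becomes a LoRA lookup key. A hypothetical illustration of that loop follows; the key names are invented, and the real table comes from comfy.utils.flux_to_diffusers.

# invented example entry of a diffusers->comfy weight-name table
diffusers_keys = {
    "single_transformer_blocks.0.attn.to_q.weight": "diffusion_model.single_blocks.0.linear1.weight",
}

key_map = {}
for k, target in diffusers_keys.items():
    if k.endswith(".weight"):
        stem = k[:-len(".weight")]
        # diffusers/peft LoRA files address weights by this stem; map one
        # possible pair of lookup spellings onto the comfy parameter name
        key_map["transformer.{}".format(stem)] = target
        key_map[stem] = target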

From bf339e826559d4e2c38399d29d3759af4da3696d Mon Sep 17 00:00:00 2001
From: Silver <65376327+silveroxides@users.noreply.github.com>
Date: Thu, 27 Mar 2025 18:29:51 +0100
Subject: [PATCH 10/17] Update supported_models.py

---
 comfy/supported_models.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index c113a023..0fe97cb4 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -969,12 +969,24 @@ class WAN21_I2V(WAN21_T2V):
     unet_config = {
         "image_model": "wan2.1",
         "model_type": "i2v",
+        "in_dim": 36,
     }
 
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.WAN21(self, image_to_video=True, device=device)
         return out
 
+class WAN21_FunControl2V(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "i2v",
+        "in_dim": 48,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21(self, image_to_video=False, device=device)
+        return out
+
 class Hunyuan3Dv2(supported_models_base.BASE):
     unet_config = {
         "image_model": "hunyuan3d2",
@@ -1041,6 +1053,6 @@ class Chroma(supported_models_base.BASE):
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.chroma.ChromaTokenizer, comfy.text_encoders.chroma.chroma_te(**t5_detect))
 
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma]
 
 models += [SVD_img2vid]
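
WAN21_FunControl2V shares image_model and model_type with WAN21_I2V and is told apart only by in_dim. A simplified sketch of how such unet_config matching can work: the candidate whose config entries are all present and equal in the detected dict wins, so the in_dim value alone selects the class. This is only an illustration of the idea, not ComfyUI's exact detection code.

def matches(class_config: dict, detected: dict) -> bool:
    return all(detected.get(k) == v for k, v in class_config.items())

detected = {"image_model": "wan2.1", "model_type": "i2v", "in_dim": 48}
candidates = {
    "WAN21_I2V": {"image_model": "wan2.1", "model_type": "i2v", "in_dim": 36},
    "WAN21_FunControl2V": {"image_model": "wan2.1", "model_type": "i2v", "in_dim": 48},
}
print([name for name, cfg in candidates.items() if matches(cfg, detected)])
# ['WAN21_FunControl2V']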

From 2378c63ba9c8a900bdf70f3ec4b4275a800393a5 Mon Sep 17 00:00:00 2001
From: silveroxides <ishimarukaito@gmail.com>
Date: Thu, 27 Mar 2025 18:43:58 +0100
Subject: [PATCH 11/17] update model_base.py

---
 comfy/model_base.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index 13349f71..9d73fa46 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -993,7 +993,8 @@ class WAN21(BaseModel):
 
     def concat_cond(self, **kwargs):
         noise = kwargs.get("noise", None)
-        if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
+        extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
+        if extra_channels == 0:
             return None
 
         image = kwargs.get("concat_latent_image", None)
@@ -1001,23 +1002,30 @@ class WAN21(BaseModel):
 
         if image is None:
             image = torch.zeros_like(noise)
+            shape_image = list(noise.shape)
+            shape_image[1] = extra_channels
+            image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
+        else:
+            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            for i in range(0, image.shape[1], 16):
+                image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
+            image = utils.resize_to_batch_size(image, noise.shape[0])
 
-        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
-        image = self.process_latent_in(image)
-        image = utils.resize_to_batch_size(image, noise.shape[0])
-
-        if not self.image_to_video:
+        if not self.image_to_video or extra_channels == image.shape[1]:
             return image
 
         mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
         if mask is None:
             mask = torch.zeros_like(noise)[:, :4]
         else:
-            mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
+            if mask.shape[1] != 4:
+                mask = torch.mean(mask, dim=1, keepdim=True)
+            mask = 1.0 - mask
             mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
             if mask.shape[-3] < noise.shape[-3]:
                 mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
-            mask = mask.repeat(1, 4, 1, 1, 1)
+            if mask.shape[1] == 1:
+                mask = mask.repeat(1, 4, 1, 1, 1)
             mask = utils.resize_to_batch_size(mask, noise.shape[0])
 
         return torch.cat((mask, image), dim=1)
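
The reworked concat_cond derives the number of extra conditioning channels from the patch embedding weight instead of assuming a fixed layout, so one code path serves both the 36-channel I2V and the 48-channel FunControl checkpoints. A shape-only sketch of that channel budget, ignoring the VAE encode and mask-resize branches:

import torch

noise = torch.zeros(1, 16, 21, 60, 104)   # (B, C, T, H, W) latent noise

for in_dim, name in ((36, "I2V"), (48, "FunControl")):
    extra_channels = in_dim - noise.shape[1]        # 20 or 32
    if name == "I2V":
        # I2V spends 4 of its extra channels on a mask and 16 on the image latent
        mask = torch.zeros(1, 1, *noise.shape[2:]).repeat(1, 4, 1, 1, 1)
        image = torch.zeros(1, extra_channels - 4, *noise.shape[2:])
        cond = torch.cat((mask, image), dim=1)
    else:
        # FunControl packs control + start-image latents into all 32 extra
        # channels, so no mask is prepended
        cond = torch.zeros(1, extra_channels, *noise.shape[2:])
    assert cond.shape[1] == extra_channels
    print(name, cond.shape)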

From 73f10eaf7c4bbcbab7dab9435e3bd61d61ba0c2d Mon Sep 17 00:00:00 2001
From: silveroxides <ishimarukaito@gmail.com>
Date: Thu, 27 Mar 2025 18:55:13 +0100
Subject: [PATCH 12/17] add upstream commits

---
 comfy_extras/nodes_cfg.py |  45 ++++++++++++++++
 comfy_extras/nodes_wan.py | 105 ++++++++++++++++++++++++++++++++++++++
 nodes.py                  |   1 +
 3 files changed, 151 insertions(+)
 create mode 100644 comfy_extras/nodes_cfg.py

diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py
new file mode 100644
index 00000000..1fb68664
--- /dev/null
+++ b/comfy_extras/nodes_cfg.py
@@ -0,0 +1,45 @@
+import torch
+
+# https://github.com/WeichenFan/CFG-Zero-star
+def optimized_scale(positive, negative):
+    positive_flat = positive.reshape(positive.shape[0], -1)
+    negative_flat = negative.reshape(negative.shape[0], -1)
+
+    # Calculate dot product
+    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
+
+    # Squared norm of the unconditional prediction
+    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
+
+    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+    st_star = dot_product / squared_norm
+
+    return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
+
+class CFGZeroStar:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model": ("MODEL",),
+                            }}
+    RETURN_TYPES = ("MODEL",)
+    RETURN_NAMES = ("patched_model",)
+    FUNCTION = "patch"
+    CATEGORY = "advanced/guidance"
+
+    def patch(self, model):
+        m = model.clone()
+        def cfg_zero_star(args):
+            guidance_scale = args['cond_scale']
+            x = args['input']
+            cond_p = args['cond_denoised']
+            uncond_p = args['uncond_denoised']
+            out = args["denoised"]
+            alpha = optimized_scale(x - cond_p, x - uncond_p)
+
+            return out + uncond_p * (alpha - 1.0)  + guidance_scale * uncond_p * (1.0 - alpha)
+        m.set_model_sampler_post_cfg_function(cfg_zero_star)
+        return (m, )
+
+NODE_CLASS_MAPPINGS = {
+    "CFGZeroStar": CFGZeroStar
+}
diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index dc30eb54..2d0f31ac 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -3,6 +3,7 @@ import node_helpers
 import torch
 import comfy.model_management
 import comfy.utils
+import comfy.latent_formats
 
 
 class WanImageToVideo:
@@ -49,6 +50,110 @@ class WanImageToVideo:
         return (positive, negative, out_latent)
 
 
+class WanFunControlToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                },
+                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                             "start_image": ("IMAGE", ),
+                             "control_video": ("IMAGE", ),
+                }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
+        concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
+
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            concat_latent_image = vae.encode(start_image[:, :, :, :3])
+            concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
+
+        if control_video is not None:
+            control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            concat_latent_image = vae.encode(control_video[:, :, :, :3])
+            concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
+
+        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
+        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
+class WanFunInpaintToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                },
+                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                             "start_image": ("IMAGE", ),
+                             "end_image": ("IMAGE", ),
+                }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+        if end_image is not None:
+            end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+
+        image = torch.ones((length, height, width, 3)) * 0.5
+        mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
+
+        if start_image is not None:
+            image[:start_image.shape[0]] = start_image
+            mask[:, :, :start_image.shape[0] + 3] = 0.0
+
+        if end_image is not None:
+            image[-end_image.shape[0]:] = end_image
+            mask[:, :, -end_image.shape[0]:] = 0.0
+
+        concat_latent_image = vae.encode(image[:, :, :, :3])
+        mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
+        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
 NODE_CLASS_MAPPINGS = {
     "WanImageToVideo": WanImageToVideo,
+    "WanFunControlToVideo": WanFunControlToVideo,
+    "WanFunInpaintToVideo": WanFunInpaintToVideo,
 }
diff --git a/nodes.py b/nodes.py
index f506c8a9..42e7805e 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2269,6 +2269,7 @@ def init_builtin_extra_nodes():
         "nodes_lotus.py",
         "nodes_hunyuan3d.py",
         "nodes_primitive.py",
+        "nodes_cfg.py",
     ]
 
     import_failed = []
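
nodes_cfg.py implements CFG-Zero*: the post-CFG hook shifts the result so the effective baseline is the unconditional prediction scaled by the projection coefficient alpha = dot(cond, uncond) / ||uncond||^2. A quick numeric check that the correction term added in cfg_zero_star is algebraically the same as recomputing CFG with uncond replaced by alpha * uncond. For brevity alpha is computed here directly from the two predictions, whereas the node computes it on x - cond_denoised and x - uncond_denoised; the identity holds for any alpha.

import torch

torch.manual_seed(0)
b = 2
cond = torch.randn(b, 4, 8, 8)
uncond = torch.randn(b, 4, 8, 8)
scale = 7.0

# same projection formula as optimized_scale above
pos = cond.reshape(b, -1)
neg = uncond.reshape(b, -1)
alpha = (pos * neg).sum(dim=1, keepdim=True) / (neg.pow(2).sum(dim=1, keepdim=True) + 1e-8)
alpha = alpha.reshape(b, 1, 1, 1)

plain_cfg = uncond + scale * (cond - uncond)
patched = plain_cfg + uncond * (alpha - 1.0) + scale * uncond * (1.0 - alpha)
zero_star = alpha * uncond + scale * (cond - alpha * uncond)

print(torch.allclose(patched, zero_star, atol=1e-5))  # True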

From cb6ece9a1857bbbc8dbc0341d403e8d8950ce78e Mon Sep 17 00:00:00 2001
From: silveroxides <ishimarukaito@gmail.com>
Date: Sun, 6 Apr 2025 09:52:32 +0200
Subject: [PATCH 13/17] set ModelType.FLOW so the beta scheduler works
 properly

---
 comfy/model_base.py       | 2 +-
 comfy/supported_models.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index 2a430362..756a479c 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1059,7 +1059,7 @@ class Hunyuan3Dv2(BaseModel):
         return out
 
 class Chroma(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
 
     def concat_cond(self, **kwargs):
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 0fe97cb4..9b9c45a2 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1045,7 +1045,7 @@ class Chroma(supported_models_base.BASE):
     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
     def get_model(self, state_dict, prefix="", device=None):
-        out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device)
+        out = model_base.Chroma(self, model_type=model_base.ModelType.FLOW, device=device)
         return out
 
     def clip_target(self, state_dict={}):
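
Switching Chroma from ModelType.FLUX to ModelType.FLOW changes the sampling parameterization the schedulers see. A generic flow-matching sketch of what that parameterization implies (the textbook relation, not ComfyUI's exact ModelSampling code): the noised latent is a linear interpolation between data and noise, and the denoised sample is recovered from the predicted velocity.

import torch

def noise_latent(x0, eps, sigma):
    # linear-interpolation forward process used by flow/rectified-flow models
    return (1.0 - sigma) * x0 + sigma * eps

def denoised_from_v(x_t, v_pred, sigma):
    # the network predicts a velocity v ~ eps - x0, so x0 = x_t - sigma * v
    return x_t - sigma * v_pred

x0 = torch.randn(1, 16, 32, 32)
eps = torch.randn_like(x0)
sigma = 0.7
x_t = noise_latent(x0, eps, sigma)
v_true = eps - x0
print(torch.allclose(denoised_from_v(x_t, v_true, sigma), x0, atol=1e-5))  # True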

From e1af4137224fd79f12b9eebf68943bd307276a9d Mon Sep 17 00:00:00 2001
From: silveroxides <ishimarukaito@gmail.com>
Date: Tue, 8 Apr 2025 12:48:25 +0200
Subject: [PATCH 14/17] Adjust memory usage factor and remove unnecessary code

---
 comfy/supported_models.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 9b9c45a2..c913545f 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1034,18 +1034,16 @@ class Chroma(supported_models_base.BASE):
     }
 
     sampling_settings = {
-        "multiplier": 1.0,
-        "shift": 1.0,
     }
 
     latent_format = comfy.latent_formats.Flux
 
-    memory_usage_factor = 1.8
+    memory_usage_factor = 3.2
 
     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
     def get_model(self, state_dict, prefix="", device=None):
-        out = model_base.Chroma(self, model_type=model_base.ModelType.FLOW, device=device)
+        out = model_base.Chroma(self, device=device)
         return out
 
     def clip_target(self, state_dict={}):
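
memory_usage_factor goes from 1.8 to 3.2 so VRAM estimates leave more headroom for Chroma. Purely as an illustration of how such a factor scales a latent-area-based estimate (this is not ComfyUI's actual memory_required formula, and the constant below is invented):

def estimated_inference_memory(latent_shape, dtype_bytes=2, usage_factor=3.2):
    # rough heuristic: bytes scale with latent elements, activation dtype size,
    # and a model-specific fudge factor
    batch, channels, height, width = latent_shape
    elements = batch * channels * height * width
    return elements * dtype_bytes * usage_factor * 1024  # illustrative constant

# a 1024x1024 image -> 1x16x128x128 latent
print(estimated_inference_memory((1, 16, 128, 128)) / (1024 ** 3), "GiB")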

From 3d375a153e6d49c7d51986716ef0daa5235d7803 Mon Sep 17 00:00:00 2001
From: silveroxides <ishimarukaito@gmail.com>
Date: Tue, 8 Apr 2025 12:56:02 +0200
Subject: [PATCH 15/17] fix mistake: add missing chroma text encoder import and restore sampling multiplier

---
 comfy/supported_models.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index c913545f..f5c79f70 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -17,6 +17,7 @@ import comfy.text_encoders.hunyuan_video
 import comfy.text_encoders.cosmos
 import comfy.text_encoders.lumina2
 import comfy.text_encoders.wan
+import comfy.text_encoders.chroma
 
 from . import supported_models_base
 from . import latent_formats
@@ -1034,6 +1035,7 @@ class Chroma(supported_models_base.BASE):
     }
 
     sampling_settings = {
+        "multiplier": 1.0,
     }
 
     latent_format = comfy.latent_formats.Flux
@@ -1042,6 +1044,7 @@ class Chroma(supported_models_base.BASE):
 
     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
+
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Chroma(self, device=device)
         return out

From 2c2481955d10cf791504cbb425a363bd5c7904e0 Mon Sep 17 00:00:00 2001
From: silveroxides <ishimarukaito@gmail.com>
Date: Fri, 11 Apr 2025 16:18:57 +0200
Subject: [PATCH 16/17] reduce code duplication

---
 comfy/ldm/chroma/layers.py | 114 +++++--------------------------------
 comfy/ldm/chroma/model.py  |  22 ++-----
 2 files changed, 20 insertions(+), 116 deletions(-)

diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
index 8ad3c72d..410b7121 100644
--- a/comfy/ldm/chroma/layers.py
+++ b/comfy/ldm/chroma/layers.py
@@ -5,112 +5,26 @@ import torch
 from torch import Tensor, nn
 
 from .math import attention, rope
-import comfy.ops
+from comfy.ldm.flux.layers import (
+    MLPEmbedder,
+    RMSNorm,
+    QKNorm,
+    SelfAttention,
+    ModulationOut,
+)
 import comfy.ldm.common_dit
 
 
-class EmbedND(nn.Module):
-    def __init__(self, dim: int, theta: int, axes_dim: list):
-        super().__init__()
-        self.dim = dim
-        self.theta = theta
-        self.axes_dim = axes_dim
 
-    def forward(self, ids: Tensor) -> Tensor:
-        n_axes = ids.shape[-1]
-        emb = torch.cat(
-            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-            dim=-3,
+class ChromaModulationOut(ModulationOut):
+    @classmethod
+    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
+        return cls(
+            shift=tensor[:, offset : offset + 1, :],
+            scale=tensor[:, offset + 1 : offset + 2, :],
+            gate=tensor[:, offset + 2 : offset + 3, :],
         )
 
-        return emb.unsqueeze(1)
-
-
-def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
-    """
-    Create sinusoidal timestep embeddings.
-    :param t: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an (N, D) Tensor of positional embeddings.
-    """
-    t = time_factor * t
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
-
-    args = t[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    if torch.is_floating_point(t):
-        embedding = embedding.to(t)
-    return embedding
-
-class MLPEmbedder(nn.Module):
-    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
-        self.silu = nn.SiLU()
-        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.out_layer(self.silu(self.in_layer(x)))
-
-
-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
-
-    def forward(self, x: Tensor):
-        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
-
-
-class QKNorm(torch.nn.Module):
-    def __init__(self, dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
-        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
-
-    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
-        q = self.query_norm(q)
-        k = self.key_norm(k)
-        return q.to(v), k.to(v)
-
-
-class SelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-
-        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
-        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
-
-
-@dataclass
-class ModulationOut:
-    shift: Tensor
-    scale: Tensor
-    gate: Tensor
-
-
-class Modulation(nn.Module):
-    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.is_double = double
-        self.multiplier = 6 if double else 3
-        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
-
-    def forward(self, vec: Tensor) -> tuple:
-        out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
-
-        return (
-            ModulationOut(*out[:3]),
-            ModulationOut(*out[3:]) if self.is_double else None,
-        )
 
 
 
diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index a956ad05..636748fc 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -7,15 +7,17 @@ from torch import Tensor, nn
 from einops import rearrange, repeat
 import comfy.ldm.common_dit
 
+from comfy.ldm.flux.layers import (
+    EmbedND,
+    timestep_embedding,
+)
+
 from .layers import (
     DoubleStreamBlock,
-    EmbedND,
     LastLayer,
-    MLPEmbedder,
     SingleStreamBlock,
-    timestep_embedding,
     Approximator,
-    ModulationOut
+    ChromaModulationOut,
 )
 
 
@@ -39,14 +41,6 @@ class ChromaParams:
     n_layers: int
 
 
-class ChromaModulationOut(ModulationOut):
-    @classmethod
-    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
-        return cls(
-            shift=tensor[:, offset : offset + 1, :],
-            scale=tensor[:, offset + 1 : offset + 2, :],
-            gate=tensor[:, offset + 2 : offset + 3, :],
-        )
 
 
 class Chroma(nn.Module):
@@ -77,7 +71,6 @@ class Chroma(nn.Module):
         self.n_layers = params.n_layers
         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
         self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
-        self.time_in = MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
         # set as nn identity for now, will overwrite it later.
         self.distilled_guidance_layer = Approximator(
@@ -88,9 +81,6 @@ class Chroma(nn.Module):
                     dtype=dtype, device=device, operations=operations
                 )
 
-        self.guidance_in = (
-            MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if self.distilled_guidance_layer else nn.Identity()
-        )
 
         self.double_blocks = nn.ModuleList(
             [

From cd3d2d5c62a8ee5d2e0009b67ddff661726c61e7 Mon Sep 17 00:00:00 2001
From: silveroxides <ishimarukaito@gmail.com>
Date: Fri, 11 Apr 2025 16:22:04 +0200
Subject: [PATCH 17/17] remove unused imports

---
 comfy/ldm/chroma/layers.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
index 410b7121..dd0b72f7 100644
--- a/comfy/ldm/chroma/layers.py
+++ b/comfy/ldm/chroma/layers.py
@@ -1,10 +1,7 @@
-import math
-from dataclasses import dataclass
-
 import torch
 from torch import Tensor, nn
 
-from .math import attention, rope
+from .math import attention
 from comfy.ldm.flux.layers import (
     MLPEmbedder,
     RMSNorm,
@@ -12,7 +9,6 @@ from comfy.ldm.flux.layers import (
     SelfAttention,
     ModulationOut,
 )
-import comfy.ldm.common_dit
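
Patch 16 folds the repeated shift/scale/gate slicing into ChromaModulationOut.from_offset, a thin subclass of flux's ModulationOut. A small usage sketch with a stand-in dataclass so it runs without the comfy package; the slicing itself mirrors the classmethod above, while the vector count and hidden size are illustrative (344 follows from the removed distribute_modulations bookkeeping: 19 double blocks * 12 + 38 single blocks * 3 + 2 final vectors).

import torch
from dataclasses import dataclass

@dataclass
class ModulationOut:  # stand-in for comfy.ldm.flux.layers.ModulationOut
    shift: torch.Tensor
    scale: torch.Tensor
    gate: torch.Tensor

class ChromaModulationOut(ModulationOut):
    @classmethod
    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> "ChromaModulationOut":
        return cls(
            shift=tensor[:, offset : offset + 1, :],
            scale=tensor[:, offset + 1 : offset + 2, :],
            gate=tensor[:, offset + 2 : offset + 3, :],
        )

# one (batch, n_vectors, hidden) tensor from the approximator feeds every block
mod_vectors = torch.randn(1, 344, 3072)
m = ChromaModulationOut.from_offset(mod_vectors, offset=6)
print(m.shift.shape, m.scale.shape, m.gate.shape)  # three (1, 1, 3072) slices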