From 2c2481955d10cf791504cbb425a363bd5c7904e0 Mon Sep 17 00:00:00 2001
From: silveroxides
Date: Fri, 11 Apr 2025 16:18:57 +0200
Subject: [PATCH] reduce code duplication

---
 comfy/ldm/chroma/layers.py | 114 +++++--------------------------------
 comfy/ldm/chroma/model.py  |  22 ++-----
 2 files changed, 20 insertions(+), 116 deletions(-)

diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
index 8ad3c72d4..410b7121f 100644
--- a/comfy/ldm/chroma/layers.py
+++ b/comfy/ldm/chroma/layers.py
@@ -5,112 +5,26 @@ import torch
 from torch import Tensor, nn
 
 from .math import attention, rope
-import comfy.ops
+from comfy.ldm.flux.layers import (
+    MLPEmbedder,
+    RMSNorm,
+    QKNorm,
+    SelfAttention,
+    ModulationOut,
+)
 import comfy.ldm.common_dit
 
 
-class EmbedND(nn.Module):
-    def __init__(self, dim: int, theta: int, axes_dim: list):
-        super().__init__()
-        self.dim = dim
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: Tensor) -> Tensor:
-        n_axes = ids.shape[-1]
-        emb = torch.cat(
-            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-            dim=-3,
+class ChromaModulationOut(ModulationOut):
+    @classmethod
+    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
+        return cls(
+            shift=tensor[:, offset : offset + 1, :],
+            scale=tensor[:, offset + 1 : offset + 2, :],
+            gate=tensor[:, offset + 2 : offset + 3, :],
         )
-        return emb.unsqueeze(1)
-
-
-def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
-    """
-    Create sinusoidal timestep embeddings.
-    :param t: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an (N, D) Tensor of positional embeddings.
-    """
-    t = time_factor * t
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
-
-    args = t[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    if torch.is_floating_point(t):
-        embedding = embedding.to(t)
-    return embedding
-
-class MLPEmbedder(nn.Module):
-    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
-        self.silu = nn.SiLU()
-        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.out_layer(self.silu(self.in_layer(x)))
-
-
-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
-
-    def forward(self, x: Tensor):
-        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
-
-
-class QKNorm(torch.nn.Module):
-    def __init__(self, dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
-        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
-
-    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
-        q = self.query_norm(q)
-        k = self.key_norm(k)
-        return q.to(v), k.to(v)
-
-
-class SelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-
-        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
-        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
-
-
-@dataclass
-class ModulationOut:
-    shift: Tensor
-    scale: Tensor
-    gate: Tensor
-
-
-class Modulation(nn.Module):
-    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.is_double = double
-        self.multiplier = 6 if double else 3
-        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
-
-    def forward(self, vec: Tensor) -> tuple:
-        out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
-
-        return (
-            ModulationOut(*out[:3]),
-            ModulationOut(*out[3:]) if self.is_double else None,
-        )
 
 
diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index a956ad054..636748fc5 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -7,15 +7,17 @@ from torch import Tensor, nn
 from einops import rearrange, repeat
 
 import comfy.ldm.common_dit
+from comfy.ldm.flux.layers import (
+    EmbedND,
+    timestep_embedding,
+)
+
 from .layers import (
     DoubleStreamBlock,
-    EmbedND,
     LastLayer,
-    MLPEmbedder,
     SingleStreamBlock,
-    timestep_embedding,
     Approximator,
-    ModulationOut
+    ChromaModulationOut,
 )
 
 
@@ -39,14 +41,6 @@ class ChromaParams:
     n_layers: int
 
 
-class ChromaModulationOut(ModulationOut):
-    @classmethod
-    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
-        return cls(
-            shift=tensor[:, offset : offset + 1, :],
-            scale=tensor[:, offset + 1 : offset + 2, :],
-            gate=tensor[:, offset + 2 : offset + 3, :],
-        )
 
 
 class Chroma(nn.Module):
@@ -77,7 +71,6 @@ class Chroma(nn.Module):
         self.n_layers = params.n_layers
         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
         self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
-        self.time_in = MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
         # set as nn identity for now, will overwrite it later.
         self.distilled_guidance_layer = Approximator(
@@ -88,9 +81,6 @@ class Chroma(nn.Module):
                     dtype=dtype, device=device, operations=operations
                 )
 
-        self.guidance_in = (
-            MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if self.distilled_guidance_layer else nn.Identity()
-        )
 
         self.double_blocks = nn.ModuleList(
             [
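
Note: after this patch, ChromaModulationOut is kept in comfy/ldm/chroma/layers.py and reuses the shared flux ModulationOut container; from_offset simply slices one (shift, scale, gate) triple out of a stacked modulation tensor at a given row offset. Below is a minimal standalone sketch of that slicing pattern; the local ModulationOut dataclass and the example tensor shapes are illustrative stand-ins, not the comfy/flux implementation.

# Standalone sketch of the from_offset slicing shown in the patch.
# The real class subclasses comfy.ldm.flux.layers.ModulationOut instead of
# defining its own dataclass; shapes below are assumed for illustration.
from dataclasses import dataclass

import torch
from torch import Tensor


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class ChromaModulationOut(ModulationOut):
    @classmethod
    def from_offset(cls, tensor: Tensor, offset: int = 0) -> "ChromaModulationOut":
        # Rows [offset, offset + 3) along dim 1 hold one (shift, scale, gate) group.
        return cls(
            shift=tensor[:, offset : offset + 1, :],
            scale=tensor[:, offset + 1 : offset + 2, :],
            gate=tensor[:, offset + 2 : offset + 3, :],
        )


# Hypothetical example: batch 2, two stacked modulation groups, hidden size 16.
mod = torch.randn(2, 6, 16)
first = ChromaModulationOut.from_offset(mod, offset=0)   # rows 0..2
second = ChromaModulationOut.from_offset(mod, offset=3)  # rows 3..5
assert first.shift.shape == (2, 1, 16)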