From e5ea112a90ed5c22f0114dc29f5a9e7d8a897edf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Feb 2025 03:56:00 -0500 Subject: [PATCH] Support Lumina 2 model. --- comfy/ldm/lumina/model.py | 670 ++++++++++++++++++++ comfy/ldm/modules/diffusionmodules/mmdit.py | 2 +- comfy/model_base.py | 17 + comfy/model_detection.py | 17 +- comfy/sd.py | 11 +- comfy/sd1_clip.py | 17 +- comfy/supported_models.py | 32 +- comfy/text_encoders/llama.py | 132 +++- comfy/text_encoders/lumina2.py | 44 ++ comfy/text_encoders/spiece_tokenizer.py | 14 +- nodes.py | 4 +- 11 files changed, 921 insertions(+), 39 deletions(-) create mode 100644 comfy/ldm/lumina/model.py create mode 100644 comfy/text_encoders/lumina2.py diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py new file mode 100644 index 000000000..24c6d80f2 --- /dev/null +++ b/comfy/ldm/lumina/model.py @@ -0,0 +1,670 @@ +# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py + +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm +from comfy.ldm.modules.attention import optimized_attention_masked + + +def modulate(x, scale): + return x * (1 + scale.unsqueeze(1)) + +############################################################################# +# Core NextDiT Model # +############################################################################# + + +class JointAttention(nn.Module): + """Multi-head attention module.""" + + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: Optional[int], + qk_norm: bool, + operation_settings={}, + ): + """ + Initialize the Attention module. + + Args: + dim (int): Number of input dimensions. + n_heads (int): Number of heads. + n_kv_heads (Optional[int]): Number of kv heads, if using GQA. + + """ + super().__init__() + self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads + self.n_local_heads = n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = dim // n_heads + + self.qkv = operation_settings.get("operations").Linear( + dim, + (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim, + bias=False, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.out = operation_settings.get("operations").Linear( + n_heads * self.head_dim, + dim, + bias=False, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + + if qk_norm: + self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings) + self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings) + else: + self.q_norm = self.k_norm = nn.Identity() + + @staticmethod + def apply_rotary_emb( + x_in: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency + tensor. + + This function applies rotary embeddings to the given query 'xq' and + key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The + input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors + contain rotary embeddings and are returned as real tensors. + + Args: + x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex + exponentials. 
+ + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor + and key tensor with rotary embeddings. + """ + + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + """ + + Args: + x: + x_mask: + freqs_cis: + + Returns: + + """ + bsz, seqlen, _ = x.shape + + xq, xk, xv = torch.split( + self.qkv(x), + [ + self.n_local_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + ], + dim=-1, + ) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + xq = self.q_norm(xq) + xk = self.k_norm(xk) + xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis) + xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis) + + n_rep = self.n_local_heads // self.n_local_kv_heads + if n_rep >= 1: + xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True) + + return self.out(output) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + operation_settings={}, + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden + dimension. Defaults to None. + + """ + super().__init__() + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = operation_settings.get("operations").Linear( + dim, + hidden_dim, + bias=False, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.w2 = operation_settings.get("operations").Linear( + hidden_dim, + dim, + bias=False, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.w3 = operation_settings.get("operations").Linear( + dim, + hidden_dim, + bias=False, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + + # @torch.compile + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +class JointTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + qk_norm: bool, + modulation=True, + operation_settings={}, + ) -> None: + """ + Initialize a TransformerBlock. + + Args: + layer_id (int): Identifier for the layer. + dim (int): Embedding dimension of the input features. + n_heads (int): Number of attention heads. + n_kv_heads (Optional[int]): Number of attention heads in key and + value features (if using GQA), or set to None for the same as + query. 
+ multiple_of (int): + ffn_dim_multiplier (float): + norm_eps (float): + + """ + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings) + self.feed_forward = FeedForward( + dim=dim, + hidden_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + operation_settings=operation_settings, + ) + self.layer_id = layer_id + self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(dim, 1024), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor]=None, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and + feedforward layers. + + """ + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) + + x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( + self.attention( + modulate(self.attention_norm1(x), scale_msa), + x_mask, + freqs_cis, + ) + ) + x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( + self.feed_forward( + modulate(self.ffn_norm1(x), scale_mlp), + ) + ) + else: + assert adaln_input is None + x = x + self.attention_norm2( + self.attention( + self.attention_norm1(x), + x_mask, + freqs_cis, + ) + ) + x = x + self.ffn_norm2( + self.feed_forward( + self.ffn_norm1(x), + ) + ) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of NextDiT. 
+ """ + + def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}): + super().__init__() + self.norm_final = operation_settings.get("operations").LayerNorm( + hidden_size, + elementwise_affine=False, + eps=1e-6, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.linear = operation_settings.get("operations").Linear( + hidden_size, + patch_size * patch_size * out_channels, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(hidden_size, 1024), + hidden_size, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + + def forward(self, x, c): + scale = self.adaLN_modulation(c) + x = modulate(self.norm_final(x), scale) + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, theta: float = 10000.0, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512) + ): + super().__init__() + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + + def __call__(self, ids: torch.Tensor): + self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis] + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, :, i:i+1].repeat(1, 1, self.freqs_cis[i].shape[-1]).to(torch.int64) + result.append(torch.gather(self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) + return torch.cat(result, dim=-1) + + +class NextDiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + patch_size: int = 2, + in_channels: int = 4, + dim: int = 4096, + n_layers: int = 32, + n_refiner_layers: int = 2, + n_heads: int = 32, + n_kv_heads: Optional[int] = None, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + qk_norm: bool = False, + cap_feat_dim: int = 5120, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (1, 512, 512), + image_model=None, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + self.dtype = dtype + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + self.in_channels = in_channels + self.out_channels = in_channels + self.patch_size = patch_size + + self.x_embedder = operation_settings.get("operations").Linear( + in_features=patch_size * patch_size * in_channels, + out_features=dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + + self.noise_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=True, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=False, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings), + 
operation_settings.get("operations").Linear( + cap_feat_dim, + dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + + self.layers = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + operation_settings=operation_settings, + ) + for layer_id in range(n_layers) + ] + ) + self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) + self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings) + + assert (dim // n_heads) == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens) + self.dim = dim + self.n_heads = n_heads + + def unpatchify( + self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False + ) -> List[torch.Tensor]: + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + pH = pW = self.patch_size + imgs = [] + for i in range(x.size(0)): + H, W = img_size[i] + begin = cap_size[i] + end = begin + (H // pH) * (W // pW) + imgs.append( + x[i][begin:end] + .view(H // pH, W // pW, pH, pW, self.out_channels) + .permute(4, 0, 2, 1, 3) + .flatten(3, 4) + .flatten(1, 2) + ) + + if return_tensor: + imgs = torch.stack(imgs, dim=0) + return imgs + + def patchify_and_embed( + self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens + ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: + bsz = len(x) + pH = pW = self.patch_size + device = x[0].device + dtype = x[0].dtype + + if cap_mask is not None: + l_effective_cap_len = cap_mask.sum(dim=1).tolist() + else: + l_effective_cap_len = [num_tokens] * bsz + + if cap_mask is not None and not torch.is_floating_point(cap_mask): + cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max + + img_sizes = [(img.size(1), img.size(2)) for img in x] + l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + + max_seq_len = max( + (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)) + ) + max_cap_len = max(l_effective_cap_len) + max_img_len = max(l_effective_img_len) + + position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) + + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + H, W = img_sizes[i] + H_tokens, W_tokens = H // pH, W // pW + assert H_tokens * W_tokens == img_len + + position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) + position_ids[i, cap_len:cap_len+img_len, 0] = cap_len + row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() + col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + position_ids[i, cap_len:cap_len+img_len, 1] = row_ids + position_ids[i, cap_len:cap_len+img_len, 2] = col_ids + + freqs_cis = self.rope_embedder(position_ids) + + # build freqs_cis for cap and image individually + cap_freqs_cis_shape = list(freqs_cis.shape) + # cap_freqs_cis_shape[1] = max_cap_len + cap_freqs_cis_shape[1] = cap_feats.shape[1] + cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) + + img_freqs_cis_shape = list(freqs_cis.shape) + img_freqs_cis_shape[1] = max_img_len + img_freqs_cis = 
torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) + + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] + img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] + + # refine context + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + + # refine image + flat_x = [] + for i in range(bsz): + img = x[i] + C, H, W = img.size() + img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) + flat_x.append(img) + x = flat_x + padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype) + padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device) + for i in range(bsz): + padded_img_embed[i, :l_effective_img_len[i]] = x[i] + padded_img_mask[i, :l_effective_img_len[i]] = True + + padded_img_embed = self.x_embedder(padded_img_embed) + for layer in self.noise_refiner: + padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t) + + if cap_mask is not None: + mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device) + mask[:, :max_cap_len] = cap_mask[:, :max_cap_len] + else: + mask = None + + padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype) + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + + padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] + padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len] + + return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis + + + # def forward(self, x, t, cap_feats, cap_mask): + def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + t = 1.0 - timesteps + cap_feats = context + cap_mask = attention_mask + """ + Forward pass of NextDiT. + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of text tokens/features + """ + + t = self.t_embedder(t, dtype=x.dtype) # (N, D) + adaln_input = t + + cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute + + x_is_tensor = isinstance(x, torch.Tensor) + x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens) + freqs_cis = freqs_cis.to(x.device) + + for layer in self.layers: + x = layer(x, mask, freqs_cis, adaln_input) + + x = self.final_layer(x, adaln_input) + x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor) + + return -x + + @staticmethod + def precompute_freqs_cis( + dim: List[int], + end: List[int], + theta: float = 10000.0, + ): + """ + Precompute the frequency tensor for complex exponentials (cis) with + given dimensions. + + This function calculates a frequency tensor with complex exponentials + using the given dimension 'dim' and the end index 'end'. The 'theta' + parameter scales the frequencies. The returned tensor contains complex + values in complex64 data type. + + Args: + dim (list): Dimension of the frequency tensor. + end (list): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. + Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex + exponentials. 
+ """ + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index e70f4431f..eaf3e73a4 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -321,7 +321,7 @@ class SelfAttention(nn.Module): class RMSNorm(torch.nn.Module): def __init__( - self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None + self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None, **kwargs ): """ Initialize the RMSNorm normalization layer. diff --git a/comfy/model_base.py b/comfy/model_base.py index cd05bbdfe..4d1b83a4a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -34,6 +34,7 @@ import comfy.ldm.flux.model import comfy.ldm.lightricks.model import comfy.ldm.hunyuan_video.model import comfy.ldm.cosmos.model +import comfy.ldm.lumina.model import comfy.model_management import comfy.patcher_extension @@ -904,3 +905,19 @@ class CosmosVideo(BaseModel): latent_image = latent_image + noise latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image) return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5) + +class Lumina2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + if torch.numel(attention_mask) != attention_mask.sum(): + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item())) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ba96ebe85..2644dd0dc 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -239,7 +239,7 @@ def detect_unet_config(state_dict, key_prefix): dit_config["micro_condition"] = False return dit_config - if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys: + if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys: # Cosmos dit_config = {} dit_config["image_model"] = "cosmos" dit_config["max_img_h"] = 240 @@ -284,6 +284,21 @@ def detect_unet_config(state_dict, key_prefix): dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" return dit_config + if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 + dit_config = {} + dit_config["image_model"] = "lumina2" + dit_config["patch_size"] = 2 + dit_config["in_channels"] = 16 + dit_config["dim"] = 2304 + dit_config["cap_feat_dim"] = 2304 + dit_config["n_layers"] = 26 + dit_config["n_heads"] = 24 + dit_config["n_kv_heads"] = 8 + dit_config["qk_norm"] = True + 
dit_config["axes_dims"] = [32, 32, 32] + dit_config["axes_lens"] = [300, 512, 512] + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/sd.py b/comfy/sd.py index d7e89f726..eabf0bda0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -36,6 +36,7 @@ import comfy.text_encoders.genmo import comfy.text_encoders.lt import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos +import comfy.text_encoders.lumina2 import comfy.model_patcher import comfy.lora @@ -657,6 +658,7 @@ class CLIPType(Enum): HUNYUAN_VIDEO = 9 PIXART = 10 COSMOS = 11 + LUMINA2 = 12 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -675,6 +677,7 @@ class TEModel(Enum): T5_BASE = 6 LLAMA3_8 = 7 T5_XXL_OLD = 8 + GEMMA_2_2B = 9 def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -693,6 +696,8 @@ def detect_te_model(sd): return TEModel.T5_XXL_OLD if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd: return TEModel.T5_BASE + if 'model.layers.0.post_feedforward_layernorm.weight' in sd: + return TEModel.GEMMA_2_2B if "model.layers.0.post_attention_layernorm.weight" in sd: return TEModel.LLAMA3_8 return None @@ -730,6 +735,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip if "text_projection" in clip_data[i]: clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node + tokenizer_data = {} clip_target = EmptyClass() clip_target.params = {} if len(clip_data) == 1: @@ -769,6 +775,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.T5_BASE: clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer + elif te_model == TEModel.GEMMA_2_2B: + clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) else: if clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False) @@ -798,7 +808,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer parameters = 0 - tokenizer_data = {} for c in clip_data: parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 85518afd9..d2457731d 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -421,10 +421,10 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, 
tokenizer_data={}, tokenizer_args={}): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") - self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path) + self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) self.max_length = max_length self.min_length = min_length self.end_token = None @@ -585,9 +585,14 @@ class SDTokenizer: return {} class SD1Tokenizer: - def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer): - self.clip_name = clip_name - self.clip = "clip_{}".format(self.clip_name) + def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None): + if name is not None: + self.clip_name = name + self.clip = "{}".format(self.clip_name) + else: + self.clip_name = clip_name + self.clip = "clip_{}".format(self.clip_name) + tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer) setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)) @@ -600,7 +605,7 @@ class SD1Tokenizer: return getattr(self, self.clip).untokenize(token_weight_pair) def state_dict(self): - return {} + return getattr(self, self.clip).state_dict() class SD1CheckpointClipModel(SDClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index ff0bea418..7aa152480 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -15,6 +15,7 @@ import comfy.text_encoders.genmo import comfy.text_encoders.lt import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos +import comfy.text_encoders.lumina2 from . import supported_models_base from . 
import latent_formats @@ -865,6 +866,35 @@ class CosmosI2V(CosmosT2V): out = model_base.CosmosVideo(self, image_to_video=True, device=device) return out -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V] +class Lumina2(supported_models_base.BASE): + unet_config = { + "image_model": "lumina2", + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 6.0, + } + + memory_usage_factor = 1.2 + + unet_extra_config = {} + latent_format = latent_formats.Flux + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Lumina2(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect)) + +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2] models += [SVD_img2vid] diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index ad4b4623e..3f234015a 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -import torch.nn.functional as F from dataclasses import dataclass from typing import Optional, Any @@ -21,15 +20,41 @@ class Llama2Config: max_position_embeddings: int = 8192 rms_norm_eps: float = 1e-5 rope_theta: float = 500000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + +@dataclass +class Gemma2_2B_Config: + vocab_size: int = 256000 + hidden_size: int = 2304 + intermediate_size: int = 9216 + num_hidden_layers: int = 26 + num_attention_heads: int = 8 + num_key_value_heads: int = 4 + max_position_embeddings: int = 8192 + rms_norm_eps: float = 1e-6 + rope_theta: float = 10000.0 + transformer_type: str = "gemma2" + head_dim = 256 + rms_norm_add = True + mlp_activation = "gelu_pytorch_tanh" class RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-5, device=None, dtype=None): + def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + self.add = add def forward(self, x: torch.Tensor): - return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps) + w = self.weight + if self.add: + w = w + 1.0 + + return comfy.ldm.common_dit.rms_norm(x, w, self.eps) + def rotate_half(x): @@ -68,13 +93,15 @@ class Attention(nn.Module): self.num_heads = config.num_attention_heads self.num_kv_heads = config.num_key_value_heads self.hidden_size = 
config.hidden_size - self.head_dim = self.hidden_size // self.num_heads + + self.head_dim = config.head_dim + self.inner_size = self.num_heads * self.head_dim ops = ops or nn - self.q_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype) + self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype) self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype) self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype) - self.o_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype) + self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype) def forward( self, @@ -84,7 +111,6 @@ class Attention(nn.Module): optimized_attention=None, ): batch_size, seq_length, _ = hidden_states.shape - xq = self.q_proj(hidden_states) xk = self.k_proj(hidden_states) xv = self.v_proj(hidden_states) @@ -108,9 +134,13 @@ class MLP(nn.Module): self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype) + if config.mlp_activation == "silu": + self.activation = torch.nn.functional.silu + elif config.mlp_activation == "gelu_pytorch_tanh": + self.activation = lambda a: torch.nn.functional.gelu(a, approximate="tanh") def forward(self, x): - return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x)) class TransformerBlock(nn.Module): def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): @@ -146,6 +176,45 @@ class TransformerBlock(nn.Module): return x +class TransformerBlockGemma2(nn.Module): + def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): + super().__init__() + self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops) + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + optimized_attention=None, + ): + # Self Attention + residual = x + x = self.input_layernorm(x) + x = self.self_attn( + hidden_states=x, + attention_mask=attention_mask, + freqs_cis=freqs_cis, + optimized_attention=optimized_attention, + ) + + x = self.post_attention_layernorm(x) + x = residual + x + + # MLP + residual = x + x = self.pre_feedforward_layernorm(x) + x = self.mlp(x) + x = self.post_feedforward_layernorm(x) + x = residual + x + + return x + class Llama2_(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): 
super().__init__() @@ -158,17 +227,27 @@ class Llama2_(nn.Module): device=device, dtype=dtype ) + if self.config.transformer_type == "gemma2": + transformer = TransformerBlockGemma2 + self.normalize_in = True + else: + transformer = TransformerBlock + self.normalize_in = False + self.layers = nn.ModuleList([ - TransformerBlock(config, device=device, dtype=dtype, ops=ops) + transformer(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers) ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): x = self.embed_tokens(x, out_dtype=dtype) - freqs_cis = precompute_freqs_cis(self.config.hidden_size // self.config.num_attention_heads, + if self.normalize_in: + x *= self.config.hidden_size ** 0.5 + + freqs_cis = precompute_freqs_cis(self.config.head_dim, x.shape[1], self.config.rope_theta, device=x.device) @@ -206,16 +285,7 @@ class Llama2_(nn.Module): return x, intermediate - -class Llama2(torch.nn.Module): - def __init__(self, config_dict, dtype, device, operations): - super().__init__() - config = Llama2Config(**config_dict) - self.num_layers = config.num_hidden_layers - - self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) - self.dtype = dtype - +class BaseLlama: def get_input_embeddings(self): return self.model.embed_tokens @@ -224,3 +294,23 @@ class Llama2(torch.nn.Module): def forward(self, input_ids, *args, **kwargs): return self.model(input_ids, *args, **kwargs) + + +class Llama2(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Llama2Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + + +class Gemma2_2B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Gemma2_2B_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py new file mode 100644 index 000000000..166d13281 --- /dev/null +++ b/comfy/text_encoders/lumina2.py @@ -0,0 +1,44 @@ +from comfy import sd1_clip +from .spiece_tokenizer import SPieceTokenizer +import comfy.text_encoders.llama + + +class Gemma2BTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + + +class LuminaTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer) + + +class 
Gemma2_2BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None) + if llama_scaled_fp8 is not None: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class LuminaModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options) + + +def te(dtype_llama=None, llama_scaled_fp8=None): + class LuminaTEModel_(LuminaModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["llama_scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return LuminaTEModel_ diff --git a/comfy/text_encoders/spiece_tokenizer.py b/comfy/text_encoders/spiece_tokenizer.py index cbaa99ba5..21df4f863 100644 --- a/comfy/text_encoders/spiece_tokenizer.py +++ b/comfy/text_encoders/spiece_tokenizer.py @@ -1,21 +1,21 @@ import torch class SPieceTokenizer: - add_eos = True - @staticmethod - def from_pretrained(path): - return SPieceTokenizer(path) + def from_pretrained(path, **kwargs): + return SPieceTokenizer(path, **kwargs) - def __init__(self, tokenizer_path): + def __init__(self, tokenizer_path, add_bos=False, add_eos=True): + self.add_bos = add_bos + self.add_eos = add_eos import sentencepiece if torch.is_tensor(tokenizer_path): tokenizer_path = tokenizer_path.numpy().tobytes() if isinstance(tokenizer_path, bytes): - self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos) + self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos) else: - self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos) + self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos) def get_vocab(self): out = {} diff --git a/nodes.py b/nodes.py index 968f0f9ad..ba9c4e4bb 100644 --- a/nodes.py +++ b/nodes.py @@ -914,7 +914,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -941,6 +941,8 @@ class CLIPLoader: clip_type = comfy.sd.CLIPType.PIXART elif type == "cosmos": clip_type = comfy.sd.CLIPType.COSMOS + elif type == "lumina2": + clip_type = comfy.sd.CLIPType.LUMINA2 else: clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
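
A minimal usage sketch, not part of the patch: the new "lumina2" option in CLIPLoader maps to comfy.sd.CLIPType.LUMINA2, and load_text_encoder_state_dicts detects a Gemma 2 2B state dict via the 'model.layers.0.post_feedforward_layernorm.weight' key, selecting the LuminaTokenizer and Gemma 2 2B text encoder added above. Programmatically that path looks roughly like the following; the checkpoint filename is a hypothetical placeholder, and only calls already present in comfy.sd are used.

    import comfy.sd

    # Placeholder path; any Gemma 2 2B text-encoder checkpoint should take the
    # GEMMA_2_2B branch in detect_te_model() and build the Lumina tokenizer/encoder.
    clip = comfy.sd.load_clip(
        ckpt_paths=["models/text_encoders/gemma_2_2b.safetensors"],
        clip_type=comfy.sd.CLIPType.LUMINA2,
    )

    # Tokenize with the SentencePiece-based Gemma tokenizer (add_bos=True, no EOS)
    # and encode; the resulting conditioning feeds the NextDiT cap_feats path.
    tokens = clip.tokenize("a fluffy cat sitting on a windowsill")
    cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)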