From 251f54a2ad5a2676481337b930e7eceb735b132c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 13 Jun 2025 04:05:23 -0700 Subject: [PATCH] Basic initial support for cosmos predict2 text to image 2B and 14B models. (#8517) --- comfy/ldm/cosmos/blocks.py | 10 - comfy/ldm/cosmos/position_embedding.py | 4 +- comfy/ldm/cosmos/predict2.py | 868 +++++++++++++++++++++++++ comfy/model_base.py | 47 +- comfy/model_detection.py | 47 ++ comfy/model_sampling.py | 31 + comfy/supported_models.py | 40 +- 7 files changed, 1021 insertions(+), 26 deletions(-) create mode 100644 comfy/ldm/cosmos/predict2.py diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index a12f892d2..5c4356a3f 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -26,16 +26,6 @@ from torch import nn from comfy.ldm.modules.attention import optimized_attention -def apply_rotary_pos_emb( - t: torch.Tensor, - freqs: torch.Tensor, -) -> torch.Tensor: - t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() - t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] - t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) - return t_out - - def get_normalization(name: str, channels: int, weight_args={}, operations=None): if name == "I": return nn.Identity() diff --git a/comfy/ldm/cosmos/position_embedding.py b/comfy/ldm/cosmos/position_embedding.py index 4d6a58dba..44197a597 100644 --- a/comfy/ldm/cosmos/position_embedding.py +++ b/comfy/ldm/cosmos/position_embedding.py @@ -66,6 +66,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb): h_extrapolation_ratio: float = 1.0, w_extrapolation_ratio: float = 1.0, t_extrapolation_ratio: float = 1.0, + enable_fps_modulation: bool = True, device=None, **kwargs, # used for compatibility with other positional embeddings; unused in this class ): @@ -75,6 +76,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb): self.base_fps = base_fps self.max_h = len_h self.max_w = len_w + self.enable_fps_modulation = enable_fps_modulation dim = head_dim dim_h = dim // 6 * 2 @@ -143,7 +145,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb): half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs) # apply sequence scaling in temporal dimension - if fps is None: # image case + if fps is None or self.enable_fps_modulation is False: # image case half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs) else: half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py new file mode 100644 index 000000000..3b91b3f6e --- /dev/null +++ b/comfy/ldm/cosmos/predict2.py @@ -0,0 +1,868 @@ +# original code from: https://github.com/nvidia-cosmos/cosmos-predict2 + +import torch +from torch import nn +from einops import rearrange +from einops.layers.torch import Rearrange +import logging +from typing import Callable, Optional, Tuple +import math + +from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis +from torchvision import transforms + +from comfy.ldm.modules.attention import optimized_attention + +def apply_rotary_pos_emb( + t: torch.Tensor, + freqs: torch.Tensor, +) -> torch.Tensor: + t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() + t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] + t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) + return t_out + + +# ---------------------- Feed Forward 
Network ----------------------- +class GPT2FeedForward(nn.Module): + def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None: + super().__init__() + self.activation = nn.GELU() + self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype) + self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype) + + self._layer_id = None + self._dim = d_model + self._hidden_dim = d_ff + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer1(x) + + x = self.activation(x) + x = self.layer2(x) + return x + + +def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor: + """Computes multi-head attention using PyTorch's native implementation. + + This function provides a PyTorch backend alternative to Transformer Engine's attention operation. + It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product + attention, and rearranges the output back to the original format. + + The input tensor names use the following dimension conventions: + + - B: batch size + - S: sequence length + - H: number of attention heads + - D: head dimension + + Args: + q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim) + k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim) + v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim) + + Returns: + Attention output tensor with shape (batch, seq_len, n_heads * head_dim) + """ + in_q_shape = q_B_S_H_D.shape + in_k_shape = k_B_S_H_D.shape + q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) + k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + result_B_S_HD = rearrange( + optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, skip_output_reshape=True), "b h ... l -> b ... (h l)" + ) + + return result_B_S_HD + + +class Attention(nn.Module): + """ + A flexible attention module supporting both self-attention and cross-attention mechanisms. + + This module implements a multi-head attention layer that can operate in either self-attention + or cross-attention mode. The mode is determined by whether a context dimension is provided. + The implementation uses scaled dot-product attention and supports optional bias terms and + dropout regularization. + + Args: + query_dim (int): The dimensionality of the query vectors. + context_dim (int, optional): The dimensionality of the context (key/value) vectors. + If None, the module operates in self-attention mode using query_dim. Default: None + n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8 + head_dim (int, optional): The dimension of each attention head. Default: 64 + dropout (float, optional): Dropout probability applied to the output. Default: 0.0 + qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd" + backend (str, optional): Backend to use for the attention operation. 
Default: "transformer_engine" + + Examples: + >>> # Self-attention with 512 dimensions and 8 heads + >>> self_attn = Attention(query_dim=512) + >>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim) + >>> out = self_attn(x) # (32, 16, 512) + + >>> # Cross-attention + >>> cross_attn = Attention(query_dim=512, context_dim=256) + >>> query = torch.randn(32, 16, 512) + >>> context = torch.randn(32, 8, 256) + >>> out = cross_attn(query, context) # (32, 16, 512) + """ + + def __init__( + self, + query_dim: int, + context_dim: Optional[int] = None, + n_heads: int = 8, + head_dim: int = 64, + dropout: float = 0.0, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + logging.debug( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{n_heads} heads with a dimension of {head_dim}." + ) + self.is_selfattn = context_dim is None # self attention + + context_dim = query_dim if context_dim is None else context_dim + inner_dim = head_dim * n_heads + + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.v_norm = nn.Identity() + + self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity() + + self.attn_op = torch_attention_op + + self._query_dim = query_dim + self._context_dim = context_dim + self._inner_dim = inner_dim + + def compute_qkv( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + rope_emb: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_proj(x) + context = x if context is None else context + k = self.k_proj(context) + v = self.v_proj(context) + q, k, v = map( + lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim), + (q, k, v), + ) + + def apply_norm_and_rotary_pos_emb( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_norm(q) + k = self.k_norm(k) + v = self.v_norm(v) + if self.is_selfattn and rope_emb is not None: # only apply to self-attention! 
+ q = apply_rotary_pos_emb(q, rope_emb) + k = apply_rotary_pos_emb(k, rope_emb) + return q, k, v + + q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb) + + return q, k, v + + def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + result = self.attn_op(q, k, v) # [B, S, H, D] + return self.output_dropout(self.output_proj(result)) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + rope_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x (Tensor): The query tensor of shape [B, Mq, K] + context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None + """ + q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) + return self.compute_attention(q, k, v) + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor: + assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}" + timesteps = timesteps_B_T.flatten().float() + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1]) + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None): + super().__init__() + logging.debug( + f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." + ) + self.in_dim = in_features + self.out_dim = out_features + self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype) + self.activation = nn.SiLU() + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype) + else: + self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype) + + def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + emb = self.linear_1(sample) + emb = self.activation(emb) + emb = self.linear_2(emb) + + if self.use_adaln_lora: + adaln_lora_B_T_3D = emb + emb_B_T_D = sample + else: + adaln_lora_B_T_3D = None + emb_B_T_D = emb + + return emb_B_T_D, adaln_lora_B_T_3D + + +class PatchEmbed(nn.Module): + """ + PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers, + depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions, + making it suitable for video and image processing tasks. It supports dividing the input into patches + and embedding each patch into a vector of size `out_channels`. + + Parameters: + - spatial_patch_size (int): The size of each spatial patch. + - temporal_patch_size (int): The size of each temporal patch. + - in_channels (int): Number of input channels. Default: 3. + - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. 
+ - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. + """ + + def __init__( + self, + spatial_patch_size: int, + temporal_patch_size: int, + in_channels: int = 3, + out_channels: int = 768, + device=None, dtype=None, operations=None + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + operations.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype + ), + ) + self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the PatchEmbed module. + + Parameters: + - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where + B is the batch size, + C is the number of channels, + T is the temporal dimension, + H is the height, and + W is the width of the input. + + Returns: + - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. + """ + assert x.dim() == 5 + _, _, T, H, W = x.shape + assert ( + H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 + ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}" + assert T % self.temporal_patch_size == 0 + x = self.proj(x) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of video DiT. + """ + + def __init__( + self, + hidden_size: int, + spatial_patch_size: int, + temporal_patch_size: int, + out_channels: int, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + device=None, dtype=None, operations=None + ): + super().__init__() + self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = operations.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype + ) + self.hidden_size = hidden_size + self.n_adaln_chunks = 2 + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + if use_adaln_lora: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype), + ) + else: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype) + ) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_T_3D is not None + shift_B_T_D, scale_B_T_D = ( + self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size] + ).chunk(2, dim=-1) + else: + shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) + + shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange( + scale_B_T_D, "b t d -> b t 1 1 d" + ) + + def _fn( + _x_B_T_H_W_D: torch.Tensor, + _norm_layer: nn.Module, + _scale_B_T_1_1_D: torch.Tensor, + _shift_B_T_1_1_D: torch.Tensor, + ) -> torch.Tensor: + return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + x_B_T_H_W_D = _fn(x_B_T_H_W_D, 
self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D) + x_B_T_H_W_O = self.linear(x_B_T_H_W_D) + return x_B_T_H_W_O + + +class Block(nn.Module): + """ + A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation. + Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation. + + Parameters: + x_dim (int): Dimension of input features + context_dim (int): Dimension of context features for cross-attention + num_heads (int): Number of attention heads + mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0 + use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False + adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256 + + The block applies the following sequence: + 1. Self-attention with AdaLN modulation + 2. Cross-attention with AdaLN modulation + 3. MLP with AdaLN modulation + + Each component uses skip connections and layer normalization. + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + self.x_dim = x_dim + self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations) + + self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = Attention( + x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations + ) + + self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations) + + self.use_adaln_lora = use_adaln_lora + if self.use_adaln_lora: + self.adaln_modulation_self_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_cross_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_mlp = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + else: + self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + crossattn_emb: torch.Tensor, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if extra_per_block_pos_emb is not None: + x_B_T_H_W_D = 
x_B_T_H_W_D + extra_per_block_pos_emb + + if self.use_adaln_lora: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = ( + self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( + self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = ( + self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + else: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1) + + # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting + shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d") + scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d") + gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d") + scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d") + gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d") + scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d") + gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d") + + B, T, H, W, D = x_B_T_H_W_D.shape + + def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D): + return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_self_attn, + scale_self_attn_B_T_1_1_D, + shift_self_attn_B_T_1_1_D, + ) + result_B_T_H_W_D = rearrange( + self.self_attn( + # normalized_x_B_T_HW_D, + rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + None, + rope_emb=rope_emb_L_1_1_D, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D + + def _x_fn( + _x_B_T_H_W_D: torch.Tensor, + layer_norm_cross_attn: Callable, + _scale_cross_attn_B_T_1_1_D: torch.Tensor, + _shift_cross_attn_B_T_1_1_D: torch.Tensor, + ) -> torch.Tensor: + _normalized_x_B_T_H_W_D = _fn( + _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D + ) + _result_B_T_H_W_D = rearrange( + self.cross_attn( + rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + crossattn_emb, + rope_emb=rope_emb_L_1_1_D, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + return _result_B_T_H_W_D + + result_B_T_H_W_D = _x_fn( + x_B_T_H_W_D, + self.layer_norm_cross_attn, + scale_cross_attn_B_T_1_1_D, + shift_cross_attn_B_T_1_1_D, + ) + x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_mlp, + scale_mlp_B_T_1_1_D, + shift_mlp_B_T_1_1_D, + ) + result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D) + x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D + return x_B_T_H_W_D + + +class MiniTrainDIT(nn.Module): + """ + A clean impl of DIT that can load and reproduce the training results of the original DIT 
model in~(cosmos 1) + A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. + + Args: + max_img_h (int): Maximum height of the input images. + max_img_w (int): Maximum width of the input images. + max_frames (int): Maximum number of frames in the video sequence. + in_channels (int): Number of input channels (e.g., RGB channels for color images). + out_channels (int): Number of output channels. + patch_spatial (tuple): Spatial resolution of patches for input processing. + patch_temporal (int): Temporal resolution of patches for input processing. + concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. + model_channels (int): Base number of channels used throughout the model. + num_blocks (int): Number of transformer blocks. + num_heads (int): Number of heads in the multi-head attention layers. + mlp_ratio (float): Expansion ratio for MLP blocks. + crossattn_emb_channels (int): Number of embedding channels for cross-attention. + pos_emb_cls (str): Type of positional embeddings. + pos_emb_learnable (bool): Whether positional embeddings are learnable. + pos_emb_interpolation (str): Method for interpolating positional embeddings. + min_fps (int): Minimum frames per second. + max_fps (int): Maximum frames per second. + use_adaln_lora (bool): Whether to use AdaLN-LoRA. + adaln_lora_dim (int): Dimension for AdaLN-LoRA. + rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. + rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. + rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. + extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. + extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. + extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. + extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. 
+ """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: int, # tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + # cross attention settings + crossattn_emb_channels: int = 1024, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + min_fps: int = 1, + max_fps: int = 30, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = False, + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + rope_enable_fps_modulation: bool = True, + image_model=None, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + self.dtype = dtype + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.min_fps = min_fps + self.max_fps = max_fps + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + self.rope_enable_fps_modulation = rope_enable_fps_modulation + + self.build_pos_embed(device=device, dtype=dtype) + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.Sequential( + Timesteps(model_channels), + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,), + ) + + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + device=device, dtype=dtype, operations=operations, + ) + + self.blocks = nn.ModuleList( + [ + Block( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + for _ in range(num_blocks) + ] + ) + + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + + self.t_embedding_norm = 
operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype) + + def build_pos_embed(self, device=None, dtype=None) -> None: + if self.pos_emb_cls == "rope3d": + cls_type = VideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + max_fps=self.max_fps, + min_fps=self.min_fps, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + enable_fps_modulation=self.rope_enable_fps_modulation, + device=device, + ) + self.pos_embedder = cls_type( + **kwargs, # type: ignore + ) + + if self.extra_per_block_abs_pos_emb: + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + kwargs["device"] = device + kwargs["dtype"] = dtype + self.extra_pos_embedder = LearnablePosEmbAxis( + **kwargs, # type: ignore + ) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the + `self.pos_embedder` with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. 
+ """ + if self.concat_padding_mask: + if padding_mask is None: + padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device) + else: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D] + + return x_B_T_H_W_D, None, extra_pos_emb + + def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor: + x_B_C_Tt_Hp_Wp = rearrange( + x_B_T_H_W_M, + "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + t=self.patch_temporal, + ) + return x_B_C_Tt_Hp_Wp + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + x_B_C_T_H_W = x + timesteps_B_T = timesteps + crossattn_emb = context + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + """ + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x_B_C_T_H_W, + fps=fps, + padding_mask=padding_mask, + ) + + if timesteps_B_T.ndim == 1: + timesteps_B_T = timesteps_B_T.unsqueeze(1) + t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype)) + t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D) + + # for logging purpose + affline_scale_log_info = {} + affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach() + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = t_embedding_B_T_D + self.crossattn_emb = crossattn_emb + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}" + + block_kwargs = { + "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), + "adaln_lora_B_T_3D": adaln_lora_B_T_3D, + "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + for block in self.blocks: + x_B_T_H_W_D = block( + x_B_T_H_W_D, + t_embedding_B_T_D, + crossattn_emb, + **block_kwargs, + ) + + x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) + x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O) + return x_B_C_Tt_Hp_Wp diff --git a/comfy/model_base.py b/comfy/model_base.py index e0c2bcaa8..be72ddd17 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -34,6 +34,7 @@ import comfy.ldm.flux.model import comfy.ldm.lightricks.model import comfy.ldm.hunyuan_video.model import comfy.ldm.cosmos.model +import comfy.ldm.cosmos.predict2 import comfy.ldm.lumina.model import comfy.ldm.wan.model import comfy.ldm.hunyuan3d.model @@ -48,6 
+49,7 @@ import comfy.ops from enum import Enum from . import utils import comfy.latent_formats +import comfy.model_sampling import math from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -63,38 +65,39 @@ class ModelType(Enum): V_PREDICTION_CONTINUOUS = 7 FLUX = 8 IMG_TO_IMG = 9 - - -from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV + FLOW_COSMOS = 10 def model_sampling(model_config, model_type): - s = ModelSamplingDiscrete + s = comfy.model_sampling.ModelSamplingDiscrete if model_type == ModelType.EPS: - c = EPS + c = comfy.model_sampling.EPS elif model_type == ModelType.V_PREDICTION: - c = V_PREDICTION + c = comfy.model_sampling.V_PREDICTION elif model_type == ModelType.V_PREDICTION_EDM: - c = V_PREDICTION - s = ModelSamplingContinuousEDM + c = comfy.model_sampling.V_PREDICTION + s = comfy.model_sampling.ModelSamplingContinuousEDM elif model_type == ModelType.FLOW: c = comfy.model_sampling.CONST s = comfy.model_sampling.ModelSamplingDiscreteFlow elif model_type == ModelType.STABLE_CASCADE: - c = EPS - s = StableCascadeSampling + c = comfy.model_sampling.EPS + s = comfy.model_sampling.StableCascadeSampling elif model_type == ModelType.EDM: - c = EDM - s = ModelSamplingContinuousEDM + c = comfy.model_sampling.EDM + s = comfy.model_sampling.ModelSamplingContinuousEDM elif model_type == ModelType.V_PREDICTION_CONTINUOUS: - c = V_PREDICTION - s = ModelSamplingContinuousV + c = comfy.model_sampling.V_PREDICTION + s = comfy.model_sampling.ModelSamplingContinuousV elif model_type == ModelType.FLUX: c = comfy.model_sampling.CONST s = comfy.model_sampling.ModelSamplingFlux elif model_type == ModelType.IMG_TO_IMG: c = comfy.model_sampling.IMG_TO_IMG + elif model_type == ModelType.FLOW_COSMOS: + c = comfy.model_sampling.COSMOS_RFLOW + s = comfy.model_sampling.ModelSamplingCosmosRFlow class ModelSampling(s, c): pass @@ -998,6 +1001,22 @@ class CosmosVideo(BaseModel): latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image) return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5) +class CosmosPredict2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.predict2.MiniTrainDIT) + self.image_to_video = image_to_video + if self.image_to_video: + self.concat_keys = ("mask_inverted",) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None)) + return out + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 74f539598..b8fef17ad 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -407,6 +407,53 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["text_emb_dim"] = 2048 return dit_config + if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 + dit_config = {} + dit_config["image_model"] = "cosmos_predict2" + 
dit_config["max_img_h"] = 240 + dit_config["max_img_w"] = 240 + dit_config["max_frames"] = 128 + concat_padding_mask = True + dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask) + dit_config["out_channels"] = 16 + dit_config["patch_spatial"] = 2 + dit_config["patch_temporal"] = 1 + dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0] + dit_config["concat_padding_mask"] = concat_padding_mask + dit_config["crossattn_emb_channels"] = 1024 + dit_config["pos_emb_cls"] = "rope3d" + dit_config["pos_emb_learnable"] = True + dit_config["pos_emb_interpolation"] = "crop" + dit_config["min_fps"] = 1 + dit_config["max_fps"] = 30 + + dit_config["use_adaln_lora"] = True + dit_config["adaln_lora_dim"] = 256 + if dit_config["model_channels"] == 2048: + dit_config["num_blocks"] = 28 + dit_config["num_heads"] = 16 + elif dit_config["model_channels"] == 5120: + dit_config["num_blocks"] = 36 + dit_config["num_heads"] = 40 + + if dit_config["in_channels"] == 16: + dit_config["extra_per_block_abs_pos_emb"] = False + dit_config["rope_h_extrapolation_ratio"] = 4.0 + dit_config["rope_w_extrapolation_ratio"] = 4.0 + dit_config["rope_t_extrapolation_ratio"] = 1.0 + elif dit_config["in_channels"] == 17: + dit_config["extra_per_block_abs_pos_emb"] = False + dit_config["rope_h_extrapolation_ratio"] = 3.0 + dit_config["rope_w_extrapolation_ratio"] = 3.0 + dit_config["rope_t_extrapolation_ratio"] = 1.0 + + dit_config["extra_h_extrapolation_ratio"] = 1.0 + dit_config["extra_w_extrapolation_ratio"] = 1.0 + dit_config["extra_t_extrapolation_ratio"] = 1.0 + dit_config["rope_enable_fps_modulation"] = False + + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 7e7291476..b240b7f29 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -77,6 +77,25 @@ class IMG_TO_IMG(X0): def calculate_input(self, sigma, noise): return noise +class COSMOS_RFLOW: + def calculate_input(self, sigma, noise): + sigma = (sigma / (sigma + 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + return noise * (1.0 - sigma) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = (sigma / (sigma + 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input * (1.0 - sigma) - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + noise = noise * sigma + noise += latent_image + return noise + + def inverse_noise_scaling(self, sigma, latent): + return latent class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None, zsnr=None): @@ -350,3 +369,15 @@ class ModelSamplingFlux(torch.nn.Module): if percent >= 1.0: return 0.0 return flux_time_shift(self.shift, 1.0, 1.0 - percent) + + +class ModelSamplingCosmosRFlow(ModelSamplingContinuousEDM): + def timestep(self, sigma): + return sigma / (sigma + 1) + + def sigma(self, timestep): + sigma_max = self.sigma_max + if timestep >= (sigma_max / (sigma_max + 1)): + return sigma_max + + return timestep / (1 - timestep) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index efe2e6b8f..2a213f24a 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -908,6 +908,44 @@ class CosmosI2V(CosmosT2V): out = 
model_base.CosmosVideo(self, image_to_video=True, device=device) return out +class CosmosT2IPredict2(supported_models_base.BASE): + unet_config = { + "image_model": "cosmos_predict2", + "in_channels": 16, + } + + sampling_settings = { + "sigma_data": 1.0, + "sigma_max": 80.0, + "sigma_min": 0.002, + } + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + memory_usage_factor = 1.6 #TODO + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CosmosPredict2(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect)) + +class CosmosI2VPredict2(CosmosT2IPredict2): + unet_config = { + "image_model": "cosmos_predict2", + "in_channels": 17, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CosmosPredict2(self, image_to_video=True, device=device) + return out + class Lumina2(supported_models_base.BASE): unet_config = { "image_model": "lumina2", @@ -1139,6 +1177,6 @@ class ACEStep(supported_models_base.BASE): def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] models += [SVD_img2vid]
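
Note on the new sampling math (not part of the patch itself): the COSMOS_RFLOW object and ModelSamplingCosmosRFlow class added in comfy/model_sampling.py map the EDM-style sigma onto a rectified-flow timestep t = sigma / (sigma + 1) and scale the model input by (1 - t). The sketch below assumes only torch and uses hypothetical helper names (timestep_from_sigma, sigma_from_timestep, scale_model_input); it restates that mapping in isolation so the round trip and the sigma_max clamp are easy to verify.

import torch

def timestep_from_sigma(sigma: torch.Tensor) -> torch.Tensor:
    # t = sigma / (sigma + 1): maps sigma in [0, inf) onto t in [0, 1)
    return sigma / (sigma + 1)

def sigma_from_timestep(t: torch.Tensor, sigma_max: float = 80.0) -> torch.Tensor:
    # Inverse mapping; timesteps at or beyond sigma_max / (sigma_max + 1)
    # are clamped back to sigma_max, matching ModelSamplingCosmosRFlow.sigma().
    t_max = sigma_max / (sigma_max + 1)
    return torch.where(t >= t_max, torch.full_like(t, sigma_max), t / (1 - t))

def scale_model_input(sigma: torch.Tensor, noisy_latent: torch.Tensor) -> torch.Tensor:
    # Mirrors COSMOS_RFLOW.calculate_input: the network sees the noisy latent
    # scaled by (1 - t), with t broadcast over all non-batch dimensions.
    t = timestep_from_sigma(sigma).view(sigma.shape[:1] + (1,) * (noisy_latent.ndim - 1))
    return noisy_latent * (1.0 - t)

# Round trip over the sigma range declared in CosmosT2IPredict2.sampling_settings.
sigmas = torch.tensor([0.002, 1.0, 80.0])
assert torch.allclose(sigma_from_timestep(timestep_from_sigma(sigmas)), sigmas, atol=1e-4)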