Basic initial support for cosmos predict2 text to image 2B and 14B models. (#8517)

This commit is contained in:
comfyanonymous 2025-06-13 04:05:23 -07:00 committed by GitHub
parent c6529c0d77
commit 251f54a2ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 1021 additions and 26 deletions

View File

@ -26,16 +26,6 @@ from torch import nn
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
if name == "I":
return nn.Identity()

View File

@ -66,6 +66,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
h_extrapolation_ratio: float = 1.0,
w_extrapolation_ratio: float = 1.0,
t_extrapolation_ratio: float = 1.0,
enable_fps_modulation: bool = True,
device=None,
**kwargs, # used for compatibility with other positional embeddings; unused in this class
):
@ -75,6 +76,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
self.base_fps = base_fps
self.max_h = len_h
self.max_w = len_w
self.enable_fps_modulation = enable_fps_modulation
dim = head_dim
dim_h = dim // 6 * 2
@ -143,7 +145,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
# apply sequence scaling in temporal dimension
if fps is None: # image case
if fps is None or self.enable_fps_modulation is False: # image case
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
else:
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)

View File

@ -0,0 +1,868 @@
# original code from: https://github.com/nvidia-cosmos/cosmos-predict2
import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange
import logging
from typing import Callable, Optional, Tuple
import math
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
from torchvision import transforms
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
# ---------------------- Feed Forward Network -----------------------
class GPT2FeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
super().__init__()
self.activation = nn.GELU()
self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
self._layer_id = None
self._dim = d_model
self._hidden_dim = d_ff
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
return x
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native implementation.
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product
attention, and rearranges the output back to the original format.
The input tensor names use the following dimension conventions:
- B: batch size
- S: sequence length
- H: number of attention heads
- D: head dimension
Args:
q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)
Returns:
Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
"""
in_q_shape = q_B_S_H_D.shape
in_k_shape = k_B_S_H_D.shape
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
result_B_S_HD = rearrange(
optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, skip_output_reshape=True), "b h ... l -> b ... (h l)"
)
return result_B_S_HD
class Attention(nn.Module):
"""
A flexible attention module supporting both self-attention and cross-attention mechanisms.
This module implements a multi-head attention layer that can operate in either self-attention
or cross-attention mode. The mode is determined by whether a context dimension is provided.
The implementation uses scaled dot-product attention and supports optional bias terms and
dropout regularization.
Args:
query_dim (int): The dimensionality of the query vectors.
context_dim (int, optional): The dimensionality of the context (key/value) vectors.
If None, the module operates in self-attention mode using query_dim. Default: None
n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
head_dim (int, optional): The dimension of each attention head. Default: 64
dropout (float, optional): Dropout probability applied to the output. Default: 0.0
qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd"
backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine"
Examples:
>>> # Self-attention with 512 dimensions and 8 heads
>>> self_attn = Attention(query_dim=512)
>>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim)
>>> out = self_attn(x) # (32, 16, 512)
>>> # Cross-attention
>>> cross_attn = Attention(query_dim=512, context_dim=256)
>>> query = torch.randn(32, 16, 512)
>>> context = torch.randn(32, 8, 256)
>>> out = cross_attn(query, context) # (32, 16, 512)
"""
def __init__(
self,
query_dim: int,
context_dim: Optional[int] = None,
n_heads: int = 8,
head_dim: int = 64,
dropout: float = 0.0,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
logging.debug(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{n_heads} heads with a dimension of {head_dim}."
)
self.is_selfattn = context_dim is None # self attention
context_dim = query_dim if context_dim is None else context_dim
inner_dim = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.query_dim = query_dim
self.context_dim = context_dim
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.v_norm = nn.Identity()
self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
self.attn_op = torch_attention_op
self._query_dim = query_dim
self._context_dim = context_dim
self._inner_dim = inner_dim
def compute_qkv(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = self.q_proj(x)
context = x if context is None else context
k = self.k_proj(context)
v = self.v_proj(context)
q, k, v = map(
lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
(q, k, v),
)
def apply_norm_and_rotary_pos_emb(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = self.q_norm(q)
k = self.k_norm(k)
v = self.v_norm(v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
q = apply_rotary_pos_emb(q, rope_emb)
k = apply_rotary_pos_emb(k, rope_emb)
return q, k, v
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
x (Tensor): The query tensor of shape [B, Mq, K]
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v)
class Timesteps(nn.Module):
def __init__(self, num_channels: int):
super().__init__()
self.num_channels = num_channels
def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
timesteps = timesteps_B_T.flatten().float()
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - 0.0)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
emb = torch.cat([cos_emb, sin_emb], dim=-1)
return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
class TimestepEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
super().__init__()
logging.debug(
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
)
self.in_dim = in_features
self.out_dim = out_features
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
self.activation = nn.SiLU()
self.use_adaln_lora = use_adaln_lora
if use_adaln_lora:
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
else:
self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)
def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
emb = self.linear_1(sample)
emb = self.activation(emb)
emb = self.linear_2(emb)
if self.use_adaln_lora:
adaln_lora_B_T_3D = emb
emb_B_T_D = sample
else:
adaln_lora_B_T_3D = None
emb_B_T_D = emb
return emb_B_T_D, adaln_lora_B_T_3D
class PatchEmbed(nn.Module):
"""
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
making it suitable for video and image processing tasks. It supports dividing the input into patches
and embedding each patch into a vector of size `out_channels`.
Parameters:
- spatial_patch_size (int): The size of each spatial patch.
- temporal_patch_size (int): The size of each temporal patch.
- in_channels (int): Number of input channels. Default: 3.
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
"""
def __init__(
self,
spatial_patch_size: int,
temporal_patch_size: int,
in_channels: int = 3,
out_channels: int = 768,
device=None, dtype=None, operations=None
):
super().__init__()
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.proj = nn.Sequential(
Rearrange(
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
r=temporal_patch_size,
m=spatial_patch_size,
n=spatial_patch_size,
),
operations.Linear(
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
),
)
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the PatchEmbed module.
Parameters:
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
B is the batch size,
C is the number of channels,
T is the temporal dimension,
H is the height, and
W is the width of the input.
Returns:
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
"""
assert x.dim() == 5
_, _, T, H, W = x.shape
assert (
H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
assert T % self.temporal_patch_size == 0
x = self.proj(x)
return x
class FinalLayer(nn.Module):
"""
The final layer of video DiT.
"""
def __init__(
self,
hidden_size: int,
spatial_patch_size: int,
temporal_patch_size: int,
out_channels: int,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
device=None, dtype=None, operations=None
):
super().__init__()
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = operations.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
)
self.hidden_size = hidden_size
self.n_adaln_chunks = 2
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
if use_adaln_lora:
self.adaln_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
)
else:
self.adaln_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
)
def forward(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
shift_B_T_D, scale_B_T_D = (
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
).chunk(2, dim=-1)
else:
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
scale_B_T_D, "b t d -> b t 1 1 d"
)
def _fn(
_x_B_T_H_W_D: torch.Tensor,
_norm_layer: nn.Module,
_scale_B_T_1_1_D: torch.Tensor,
_shift_B_T_1_1_D: torch.Tensor,
) -> torch.Tensor:
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)
x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
return x_B_T_H_W_O
class Block(nn.Module):
"""
A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.
Parameters:
x_dim (int): Dimension of input features
context_dim (int): Dimension of context features for cross-attention
num_heads (int): Number of attention heads
mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256
The block applies the following sequence:
1. Self-attention with AdaLN modulation
2. Cross-attention with AdaLN modulation
3. MLP with AdaLN modulation
Each component uses skip connections and layer normalization.
"""
def __init__(
self,
x_dim: int,
context_dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.x_dim = x_dim
self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)
self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.cross_attn = Attention(
x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
)
self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
self.use_adaln_lora = use_adaln_lora
if self.use_adaln_lora:
self.adaln_modulation_self_attn = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
)
self.adaln_modulation_cross_attn = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
)
self.adaln_modulation_mlp = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
)
else:
self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
def forward(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
if self.use_adaln_lora:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
else:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
# Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")
shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")
shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")
B, T, H, W, D = x_B_T_H_W_D.shape
def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
self.layer_norm_self_attn,
scale_self_attn_B_T_1_1_D,
shift_self_attn_B_T_1_1_D,
)
result_B_T_H_W_D = rearrange(
self.self_attn(
# normalized_x_B_T_HW_D,
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T,
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
def _x_fn(
_x_B_T_H_W_D: torch.Tensor,
layer_norm_cross_attn: Callable,
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
) -> torch.Tensor:
_normalized_x_B_T_H_W_D = _fn(
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
)
_result_B_T_H_W_D = rearrange(
self.cross_attn(
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T,
h=H,
w=W,
)
return _result_B_T_H_W_D
result_B_T_H_W_D = _x_fn(
x_B_T_H_W_D,
self.layer_norm_cross_attn,
scale_cross_attn_B_T_1_1_D,
shift_cross_attn_B_T_1_1_D,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
self.layer_norm_mlp,
scale_mlp_B_T_1_1_D,
shift_mlp_B_T_1_1_D,
)
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
return x_B_T_H_W_D
class MiniTrainDIT(nn.Module):
"""
A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1)
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
Args:
max_img_h (int): Maximum height of the input images.
max_img_w (int): Maximum width of the input images.
max_frames (int): Maximum number of frames in the video sequence.
in_channels (int): Number of input channels (e.g., RGB channels for color images).
out_channels (int): Number of output channels.
patch_spatial (tuple): Spatial resolution of patches for input processing.
patch_temporal (int): Temporal resolution of patches for input processing.
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
model_channels (int): Base number of channels used throughout the model.
num_blocks (int): Number of transformer blocks.
num_heads (int): Number of heads in the multi-head attention layers.
mlp_ratio (float): Expansion ratio for MLP blocks.
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
pos_emb_cls (str): Type of positional embeddings.
pos_emb_learnable (bool): Whether positional embeddings are learnable.
pos_emb_interpolation (str): Method for interpolating positional embeddings.
min_fps (int): Minimum frames per second.
max_fps (int): Maximum frames per second.
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
"""
def __init__(
self,
max_img_h: int,
max_img_w: int,
max_frames: int,
in_channels: int,
out_channels: int,
patch_spatial: int, # tuple,
patch_temporal: int,
concat_padding_mask: bool = True,
# attention settings
model_channels: int = 768,
num_blocks: int = 10,
num_heads: int = 16,
mlp_ratio: float = 4.0,
# cross attention settings
crossattn_emb_channels: int = 1024,
# positional embedding settings
pos_emb_cls: str = "sincos",
pos_emb_learnable: bool = False,
pos_emb_interpolation: str = "crop",
min_fps: int = 1,
max_fps: int = 30,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
rope_h_extrapolation_ratio: float = 1.0,
rope_w_extrapolation_ratio: float = 1.0,
rope_t_extrapolation_ratio: float = 1.0,
extra_per_block_abs_pos_emb: bool = False,
extra_h_extrapolation_ratio: float = 1.0,
extra_w_extrapolation_ratio: float = 1.0,
extra_t_extrapolation_ratio: float = 1.0,
rope_enable_fps_modulation: bool = True,
image_model=None,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
self.dtype = dtype
self.max_img_h = max_img_h
self.max_img_w = max_img_w
self.max_frames = max_frames
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_spatial = patch_spatial
self.patch_temporal = patch_temporal
self.num_heads = num_heads
self.num_blocks = num_blocks
self.model_channels = model_channels
self.concat_padding_mask = concat_padding_mask
# positional embedding settings
self.pos_emb_cls = pos_emb_cls
self.pos_emb_learnable = pos_emb_learnable
self.pos_emb_interpolation = pos_emb_interpolation
self.min_fps = min_fps
self.max_fps = max_fps
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
self.rope_enable_fps_modulation = rope_enable_fps_modulation
self.build_pos_embed(device=device, dtype=dtype)
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
self.t_embedder = nn.Sequential(
Timesteps(model_channels),
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
)
in_channels = in_channels + 1 if concat_padding_mask else in_channels
self.x_embedder = PatchEmbed(
spatial_patch_size=patch_spatial,
temporal_patch_size=patch_temporal,
in_channels=in_channels,
out_channels=model_channels,
device=device, dtype=dtype, operations=operations,
)
self.blocks = nn.ModuleList(
[
Block(
x_dim=model_channels,
context_dim=crossattn_emb_channels,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
use_adaln_lora=use_adaln_lora,
adaln_lora_dim=adaln_lora_dim,
device=device, dtype=dtype, operations=operations,
)
for _ in range(num_blocks)
]
)
self.final_layer = FinalLayer(
hidden_size=self.model_channels,
spatial_patch_size=self.patch_spatial,
temporal_patch_size=self.patch_temporal,
out_channels=self.out_channels,
use_adaln_lora=self.use_adaln_lora,
adaln_lora_dim=self.adaln_lora_dim,
device=device, dtype=dtype, operations=operations,
)
self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
def build_pos_embed(self, device=None, dtype=None) -> None:
if self.pos_emb_cls == "rope3d":
cls_type = VideoRopePosition3DEmb
else:
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
kwargs = dict(
model_channels=self.model_channels,
len_h=self.max_img_h // self.patch_spatial,
len_w=self.max_img_w // self.patch_spatial,
len_t=self.max_frames // self.patch_temporal,
max_fps=self.max_fps,
min_fps=self.min_fps,
is_learnable=self.pos_emb_learnable,
interpolation=self.pos_emb_interpolation,
head_dim=self.model_channels // self.num_heads,
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
enable_fps_modulation=self.rope_enable_fps_modulation,
device=device,
)
self.pos_embedder = cls_type(
**kwargs, # type: ignore
)
if self.extra_per_block_abs_pos_emb:
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
kwargs["device"] = device
kwargs["dtype"] = dtype
self.extra_pos_embedder = LearnablePosEmbAxis(
**kwargs, # type: ignore
)
def prepare_embedded_sequence(
self,
x_B_C_T_H_W: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
Args:
x_B_C_T_H_W (torch.Tensor): video
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
If None, a default value (`self.base_fps`) will be used.
padding_mask (Optional[torch.Tensor]): current it is not used
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
- An optional positional embedding tensor, returned only if the positional embedding class
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
Notes:
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
the `self.pos_embedder` with the shape [T, H, W].
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
`self.pos_embedder` with the fps tensor.
- Otherwise, the positional embeddings are generated without considering fps.
"""
if self.concat_padding_mask:
if padding_mask is None:
padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
else:
padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
x_B_C_T_H_W = torch.cat(
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
else:
extra_pos_emb = None
if "rope" in self.pos_emb_cls.lower():
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
return x_B_T_H_W_D, None, extra_pos_emb
def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
x_B_C_Tt_Hp_Wp = rearrange(
x_B_T_H_W_M,
"B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
p1=self.patch_spatial,
p2=self.patch_spatial,
t=self.patch_temporal,
)
return x_B_C_Tt_Hp_Wp
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
**kwargs,
):
x_B_C_T_H_W = x
timesteps_B_T = timesteps
crossattn_emb = context
"""
Args:
x: (B, C, T, H, W) tensor of spatial-temp inputs
timesteps: (B, ) tensor of timesteps
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
"""
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
x_B_C_T_H_W,
fps=fps,
padding_mask=padding_mask,
)
if timesteps_B_T.ndim == 1:
timesteps_B_T = timesteps_B_T.unsqueeze(1)
t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)
# for logging purpose
affline_scale_log_info = {}
affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
self.affline_scale_log_info = affline_scale_log_info
self.affline_emb = t_embedding_B_T_D
self.crossattn_emb = crossattn_emb
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
assert (
x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"
block_kwargs = {
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
}
for block in self.blocks:
x_B_T_H_W_D = block(
x_B_T_H_W_D,
t_embedding_B_T_D,
crossattn_emb,
**block_kwargs,
)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp

View File

@ -34,6 +34,7 @@ import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
import comfy.ldm.cosmos.predict2
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
@ -48,6 +49,7 @@ import comfy.ops
from enum import Enum
from . import utils
import comfy.latent_formats
import comfy.model_sampling
import math
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@ -63,38 +65,39 @@ class ModelType(Enum):
V_PREDICTION_CONTINUOUS = 7
FLUX = 8
IMG_TO_IMG = 9
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
FLOW_COSMOS = 10
def model_sampling(model_config, model_type):
s = ModelSamplingDiscrete
s = comfy.model_sampling.ModelSamplingDiscrete
if model_type == ModelType.EPS:
c = EPS
c = comfy.model_sampling.EPS
elif model_type == ModelType.V_PREDICTION:
c = V_PREDICTION
c = comfy.model_sampling.V_PREDICTION
elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION
s = ModelSamplingContinuousEDM
c = comfy.model_sampling.V_PREDICTION
s = comfy.model_sampling.ModelSamplingContinuousEDM
elif model_type == ModelType.FLOW:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingDiscreteFlow
elif model_type == ModelType.STABLE_CASCADE:
c = EPS
s = StableCascadeSampling
c = comfy.model_sampling.EPS
s = comfy.model_sampling.StableCascadeSampling
elif model_type == ModelType.EDM:
c = EDM
s = ModelSamplingContinuousEDM
c = comfy.model_sampling.EDM
s = comfy.model_sampling.ModelSamplingContinuousEDM
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
c = V_PREDICTION
s = ModelSamplingContinuousV
c = comfy.model_sampling.V_PREDICTION
s = comfy.model_sampling.ModelSamplingContinuousV
elif model_type == ModelType.FLUX:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingFlux
elif model_type == ModelType.IMG_TO_IMG:
c = comfy.model_sampling.IMG_TO_IMG
elif model_type == ModelType.FLOW_COSMOS:
c = comfy.model_sampling.COSMOS_RFLOW
s = comfy.model_sampling.ModelSamplingCosmosRFlow
class ModelSampling(s, c):
pass
@ -998,6 +1001,22 @@ class CosmosVideo(BaseModel):
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
class CosmosPredict2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.predict2.MiniTrainDIT)
self.image_to_video = image_to_video
if self.image_to_video:
self.concat_keys = ("mask_inverted",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
return out
class Lumina2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)

View File

@ -407,6 +407,53 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["text_emb_dim"] = 2048
return dit_config
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
dit_config = {}
dit_config["image_model"] = "cosmos_predict2"
dit_config["max_img_h"] = 240
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["crossattn_emb_channels"] = 1024
dit_config["pos_emb_cls"] = "rope3d"
dit_config["pos_emb_learnable"] = True
dit_config["pos_emb_interpolation"] = "crop"
dit_config["min_fps"] = 1
dit_config["max_fps"] = 30
dit_config["use_adaln_lora"] = True
dit_config["adaln_lora_dim"] = 256
if dit_config["model_channels"] == 2048:
dit_config["num_blocks"] = 28
dit_config["num_heads"] = 16
elif dit_config["model_channels"] == 5120:
dit_config["num_blocks"] = 36
dit_config["num_heads"] = 40
if dit_config["in_channels"] == 16:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 4.0
dit_config["rope_w_extrapolation_ratio"] = 4.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["in_channels"] == 17:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 3.0
dit_config["rope_w_extrapolation_ratio"] = 3.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
dit_config["extra_h_extrapolation_ratio"] = 1.0
dit_config["extra_w_extrapolation_ratio"] = 1.0
dit_config["extra_t_extrapolation_ratio"] = 1.0
dit_config["rope_enable_fps_modulation"] = False
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

View File

@ -77,6 +77,25 @@ class IMG_TO_IMG(X0):
def calculate_input(self, sigma, noise):
return noise
class COSMOS_RFLOW:
def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return noise * (1.0 - sigma)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * (1.0 - sigma) - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
noise = noise * sigma
noise += latent_image
return noise
def inverse_noise_scaling(self, sigma, latent):
return latent
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None, zsnr=None):
@ -350,3 +369,15 @@ class ModelSamplingFlux(torch.nn.Module):
if percent >= 1.0:
return 0.0
return flux_time_shift(self.shift, 1.0, 1.0 - percent)
class ModelSamplingCosmosRFlow(ModelSamplingContinuousEDM):
def timestep(self, sigma):
return sigma / (sigma + 1)
def sigma(self, timestep):
sigma_max = self.sigma_max
if timestep >= (sigma_max / (sigma_max + 1)):
return sigma_max
return timestep / (1 - timestep)

View File

@ -908,6 +908,44 @@ class CosmosI2V(CosmosT2V):
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
return out
class CosmosT2IPredict2(supported_models_base.BASE):
unet_config = {
"image_model": "cosmos_predict2",
"in_channels": 16,
}
sampling_settings = {
"sigma_data": 1.0,
"sigma_max": 80.0,
"sigma_min": 0.002,
}
unet_extra_config = {}
latent_format = latent_formats.Wan21
memory_usage_factor = 1.6 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
class CosmosI2VPredict2(CosmosT2IPredict2):
unet_config = {
"image_model": "cosmos_predict2",
"in_channels": 17,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, image_to_video=True, device=device)
return out
class Lumina2(supported_models_base.BASE):
unet_config = {
"image_model": "lumina2",
@ -1139,6 +1177,6 @@ class ACEStep(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
models += [SVD_img2vid]