mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-04 22:37:15 +08:00
Basic initial support for cosmos predict2 text to image 2B and 14B models. (#8517)
This commit is contained in:
parent
c6529c0d77
commit
251f54a2ad
@ -26,16 +26,6 @@ from torch import nn
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||
return t_out
|
||||
|
||||
|
||||
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
|
||||
if name == "I":
|
||||
return nn.Identity()
|
||||
|
@ -66,6 +66,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
|
||||
h_extrapolation_ratio: float = 1.0,
|
||||
w_extrapolation_ratio: float = 1.0,
|
||||
t_extrapolation_ratio: float = 1.0,
|
||||
enable_fps_modulation: bool = True,
|
||||
device=None,
|
||||
**kwargs, # used for compatibility with other positional embeddings; unused in this class
|
||||
):
|
||||
@ -75,6 +76,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
|
||||
self.base_fps = base_fps
|
||||
self.max_h = len_h
|
||||
self.max_w = len_w
|
||||
self.enable_fps_modulation = enable_fps_modulation
|
||||
|
||||
dim = head_dim
|
||||
dim_h = dim // 6 * 2
|
||||
@ -143,7 +145,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
|
||||
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
|
||||
|
||||
# apply sequence scaling in temporal dimension
|
||||
if fps is None: # image case
|
||||
if fps is None or self.enable_fps_modulation is False: # image case
|
||||
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
|
||||
else:
|
||||
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
|
||||
|
868
comfy/ldm/cosmos/predict2.py
Normal file
868
comfy/ldm/cosmos/predict2.py
Normal file
@ -0,0 +1,868 @@
|
||||
# original code from: https://github.com/nvidia-cosmos/cosmos-predict2
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
import logging
|
||||
from typing import Callable, Optional, Tuple
|
||||
import math
|
||||
|
||||
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
||||
from torchvision import transforms
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||
return t_out
|
||||
|
||||
|
||||
# ---------------------- Feed Forward Network -----------------------
|
||||
class GPT2FeedForward(nn.Module):
|
||||
def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.activation = nn.GELU()
|
||||
self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
|
||||
self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
|
||||
|
||||
self._layer_id = None
|
||||
self._dim = d_model
|
||||
self._hidden_dim = d_ff
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.layer1(x)
|
||||
|
||||
x = self.activation(x)
|
||||
x = self.layer2(x)
|
||||
return x
|
||||
|
||||
|
||||
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes multi-head attention using PyTorch's native implementation.
|
||||
|
||||
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
|
||||
It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product
|
||||
attention, and rearranges the output back to the original format.
|
||||
|
||||
The input tensor names use the following dimension conventions:
|
||||
|
||||
- B: batch size
|
||||
- S: sequence length
|
||||
- H: number of attention heads
|
||||
- D: head dimension
|
||||
|
||||
Args:
|
||||
q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
|
||||
k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
|
||||
v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)
|
||||
|
||||
Returns:
|
||||
Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
|
||||
"""
|
||||
in_q_shape = q_B_S_H_D.shape
|
||||
in_k_shape = k_B_S_H_D.shape
|
||||
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
|
||||
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
result_B_S_HD = rearrange(
|
||||
optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, skip_output_reshape=True), "b h ... l -> b ... (h l)"
|
||||
)
|
||||
|
||||
return result_B_S_HD
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
A flexible attention module supporting both self-attention and cross-attention mechanisms.
|
||||
|
||||
This module implements a multi-head attention layer that can operate in either self-attention
|
||||
or cross-attention mode. The mode is determined by whether a context dimension is provided.
|
||||
The implementation uses scaled dot-product attention and supports optional bias terms and
|
||||
dropout regularization.
|
||||
|
||||
Args:
|
||||
query_dim (int): The dimensionality of the query vectors.
|
||||
context_dim (int, optional): The dimensionality of the context (key/value) vectors.
|
||||
If None, the module operates in self-attention mode using query_dim. Default: None
|
||||
n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
|
||||
head_dim (int, optional): The dimension of each attention head. Default: 64
|
||||
dropout (float, optional): Dropout probability applied to the output. Default: 0.0
|
||||
qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd"
|
||||
backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine"
|
||||
|
||||
Examples:
|
||||
>>> # Self-attention with 512 dimensions and 8 heads
|
||||
>>> self_attn = Attention(query_dim=512)
|
||||
>>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim)
|
||||
>>> out = self_attn(x) # (32, 16, 512)
|
||||
|
||||
>>> # Cross-attention
|
||||
>>> cross_attn = Attention(query_dim=512, context_dim=256)
|
||||
>>> query = torch.randn(32, 16, 512)
|
||||
>>> context = torch.randn(32, 8, 256)
|
||||
>>> out = cross_attn(query, context) # (32, 16, 512)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
context_dim: Optional[int] = None,
|
||||
n_heads: int = 8,
|
||||
head_dim: int = 64,
|
||||
dropout: float = 0.0,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
logging.debug(
|
||||
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||
f"{n_heads} heads with a dimension of {head_dim}."
|
||||
)
|
||||
self.is_selfattn = context_dim is None # self attention
|
||||
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
inner_dim = head_dim * n_heads
|
||||
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = head_dim
|
||||
self.query_dim = query_dim
|
||||
self.context_dim = context_dim
|
||||
|
||||
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.v_norm = nn.Identity()
|
||||
|
||||
self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
|
||||
self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
|
||||
|
||||
self.attn_op = torch_attention_op
|
||||
|
||||
self._query_dim = query_dim
|
||||
self._context_dim = context_dim
|
||||
self._inner_dim = inner_dim
|
||||
|
||||
def compute_qkv(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
q = self.q_proj(x)
|
||||
context = x if context is None else context
|
||||
k = self.k_proj(context)
|
||||
v = self.v_proj(context)
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
def apply_norm_and_rotary_pos_emb(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
v = self.v_norm(v)
|
||||
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
|
||||
q = apply_rotary_pos_emb(q, rope_emb)
|
||||
k = apply_rotary_pos_emb(k, rope_emb)
|
||||
return q, k, v
|
||||
|
||||
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
result = self.attn_op(q, k, v) # [B, S, H, D]
|
||||
return self.output_dropout(self.output_proj(result))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): The query tensor of shape [B, Mq, K]
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
|
||||
return self.compute_attention(q, k, v)
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
|
||||
assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
|
||||
timesteps = timesteps_B_T.flatten().float()
|
||||
half_dim = self.num_channels // 2
|
||||
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
|
||||
exponent = exponent / (half_dim - 0.0)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
sin_emb = torch.sin(emb)
|
||||
cos_emb = torch.cos(emb)
|
||||
emb = torch.cat([cos_emb, sin_emb], dim=-1)
|
||||
|
||||
return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
logging.debug(
|
||||
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
|
||||
)
|
||||
self.in_dim = in_features
|
||||
self.out_dim = out_features
|
||||
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
|
||||
self.activation = nn.SiLU()
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
if use_adaln_lora:
|
||||
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
|
||||
else:
|
||||
self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
emb = self.linear_1(sample)
|
||||
emb = self.activation(emb)
|
||||
emb = self.linear_2(emb)
|
||||
|
||||
if self.use_adaln_lora:
|
||||
adaln_lora_B_T_3D = emb
|
||||
emb_B_T_D = sample
|
||||
else:
|
||||
adaln_lora_B_T_3D = None
|
||||
emb_B_T_D = emb
|
||||
|
||||
return emb_B_T_D, adaln_lora_B_T_3D
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
|
||||
depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
|
||||
making it suitable for video and image processing tasks. It supports dividing the input into patches
|
||||
and embedding each patch into a vector of size `out_channels`.
|
||||
|
||||
Parameters:
|
||||
- spatial_patch_size (int): The size of each spatial patch.
|
||||
- temporal_patch_size (int): The size of each temporal patch.
|
||||
- in_channels (int): Number of input channels. Default: 3.
|
||||
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
|
||||
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spatial_patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 768,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.spatial_patch_size = spatial_patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
Rearrange(
|
||||
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
|
||||
r=temporal_patch_size,
|
||||
m=spatial_patch_size,
|
||||
n=spatial_patch_size,
|
||||
),
|
||||
operations.Linear(
|
||||
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
|
||||
),
|
||||
)
|
||||
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the PatchEmbed module.
|
||||
|
||||
Parameters:
|
||||
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
|
||||
B is the batch size,
|
||||
C is the number of channels,
|
||||
T is the temporal dimension,
|
||||
H is the height, and
|
||||
W is the width of the input.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
|
||||
"""
|
||||
assert x.dim() == 5
|
||||
_, _, T, H, W = x.shape
|
||||
assert (
|
||||
H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
|
||||
), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
|
||||
assert T % self.temporal_patch_size == 0
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of video DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
spatial_patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
out_channels: int,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
self.n_adaln_chunks = 2
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
self.adaln_lora_dim = adaln_lora_dim
|
||||
if use_adaln_lora:
|
||||
self.adaln_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
else:
|
||||
self.adaln_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.use_adaln_lora:
|
||||
assert adaln_lora_B_T_3D is not None
|
||||
shift_B_T_D, scale_B_T_D = (
|
||||
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
|
||||
).chunk(2, dim=-1)
|
||||
else:
|
||||
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
|
||||
|
||||
shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
|
||||
scale_B_T_D, "b t d -> b t 1 1 d"
|
||||
)
|
||||
|
||||
def _fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
_norm_layer: nn.Module,
|
||||
_scale_B_T_1_1_D: torch.Tensor,
|
||||
_shift_B_T_1_1_D: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
|
||||
|
||||
x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)
|
||||
x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
|
||||
return x_B_T_H_W_O
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""
|
||||
A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
|
||||
Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.
|
||||
|
||||
Parameters:
|
||||
x_dim (int): Dimension of input features
|
||||
context_dim (int): Dimension of context features for cross-attention
|
||||
num_heads (int): Number of attention heads
|
||||
mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
|
||||
adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256
|
||||
|
||||
The block applies the following sequence:
|
||||
1. Self-attention with AdaLN modulation
|
||||
2. Cross-attention with AdaLN modulation
|
||||
3. MLP with AdaLN modulation
|
||||
|
||||
Each component uses skip connections and layer normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_dim: int,
|
||||
context_dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.x_dim = x_dim
|
||||
self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||
self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||
self.cross_attn = Attention(
|
||||
x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||
self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
if self.use_adaln_lora:
|
||||
self.adaln_modulation_self_attn = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
self.adaln_modulation_cross_attn = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
self.adaln_modulation_mlp = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
else:
|
||||
self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
|
||||
self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
|
||||
self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
|
||||
if self.use_adaln_lora:
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
|
||||
self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
|
||||
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
|
||||
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
else:
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
|
||||
emb_B_T_D
|
||||
).chunk(3, dim=-1)
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
|
||||
emb_B_T_D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
|
||||
|
||||
# Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
|
||||
shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
|
||||
scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
|
||||
gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
B, T, H, W, D = x_B_T_H_W_D.shape
|
||||
|
||||
def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
|
||||
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
self.layer_norm_self_attn,
|
||||
scale_self_attn_B_T_1_1_D,
|
||||
shift_self_attn_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = rearrange(
|
||||
self.self_attn(
|
||||
# normalized_x_B_T_HW_D,
|
||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
|
||||
|
||||
def _x_fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
layer_norm_cross_attn: Callable,
|
||||
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
_normalized_x_B_T_H_W_D = _fn(
|
||||
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
|
||||
)
|
||||
_result_B_T_H_W_D = rearrange(
|
||||
self.cross_attn(
|
||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
return _result_B_T_H_W_D
|
||||
|
||||
result_B_T_H_W_D = _x_fn(
|
||||
x_B_T_H_W_D,
|
||||
self.layer_norm_cross_attn,
|
||||
scale_cross_attn_B_T_1_1_D,
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
self.layer_norm_mlp,
|
||||
scale_mlp_B_T_1_1_D,
|
||||
shift_mlp_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
|
||||
return x_B_T_H_W_D
|
||||
|
||||
|
||||
class MiniTrainDIT(nn.Module):
|
||||
"""
|
||||
A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1)
|
||||
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
||||
|
||||
Args:
|
||||
max_img_h (int): Maximum height of the input images.
|
||||
max_img_w (int): Maximum width of the input images.
|
||||
max_frames (int): Maximum number of frames in the video sequence.
|
||||
in_channels (int): Number of input channels (e.g., RGB channels for color images).
|
||||
out_channels (int): Number of output channels.
|
||||
patch_spatial (tuple): Spatial resolution of patches for input processing.
|
||||
patch_temporal (int): Temporal resolution of patches for input processing.
|
||||
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
|
||||
model_channels (int): Base number of channels used throughout the model.
|
||||
num_blocks (int): Number of transformer blocks.
|
||||
num_heads (int): Number of heads in the multi-head attention layers.
|
||||
mlp_ratio (float): Expansion ratio for MLP blocks.
|
||||
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
|
||||
pos_emb_cls (str): Type of positional embeddings.
|
||||
pos_emb_learnable (bool): Whether positional embeddings are learnable.
|
||||
pos_emb_interpolation (str): Method for interpolating positional embeddings.
|
||||
min_fps (int): Minimum frames per second.
|
||||
max_fps (int): Maximum frames per second.
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
|
||||
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
|
||||
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
|
||||
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
|
||||
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
|
||||
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
|
||||
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
|
||||
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
|
||||
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_img_h: int,
|
||||
max_img_w: int,
|
||||
max_frames: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
patch_spatial: int, # tuple,
|
||||
patch_temporal: int,
|
||||
concat_padding_mask: bool = True,
|
||||
# attention settings
|
||||
model_channels: int = 768,
|
||||
num_blocks: int = 10,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 4.0,
|
||||
# cross attention settings
|
||||
crossattn_emb_channels: int = 1024,
|
||||
# positional embedding settings
|
||||
pos_emb_cls: str = "sincos",
|
||||
pos_emb_learnable: bool = False,
|
||||
pos_emb_interpolation: str = "crop",
|
||||
min_fps: int = 1,
|
||||
max_fps: int = 30,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
rope_h_extrapolation_ratio: float = 1.0,
|
||||
rope_w_extrapolation_ratio: float = 1.0,
|
||||
rope_t_extrapolation_ratio: float = 1.0,
|
||||
extra_per_block_abs_pos_emb: bool = False,
|
||||
extra_h_extrapolation_ratio: float = 1.0,
|
||||
extra_w_extrapolation_ratio: float = 1.0,
|
||||
extra_t_extrapolation_ratio: float = 1.0,
|
||||
rope_enable_fps_modulation: bool = True,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.max_img_h = max_img_h
|
||||
self.max_img_w = max_img_w
|
||||
self.max_frames = max_frames
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.patch_spatial = patch_spatial
|
||||
self.patch_temporal = patch_temporal
|
||||
self.num_heads = num_heads
|
||||
self.num_blocks = num_blocks
|
||||
self.model_channels = model_channels
|
||||
self.concat_padding_mask = concat_padding_mask
|
||||
# positional embedding settings
|
||||
self.pos_emb_cls = pos_emb_cls
|
||||
self.pos_emb_learnable = pos_emb_learnable
|
||||
self.pos_emb_interpolation = pos_emb_interpolation
|
||||
self.min_fps = min_fps
|
||||
self.max_fps = max_fps
|
||||
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
|
||||
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
|
||||
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
|
||||
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
|
||||
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
|
||||
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
|
||||
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
|
||||
self.rope_enable_fps_modulation = rope_enable_fps_modulation
|
||||
|
||||
self.build_pos_embed(device=device, dtype=dtype)
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
self.adaln_lora_dim = adaln_lora_dim
|
||||
self.t_embedder = nn.Sequential(
|
||||
Timesteps(model_channels),
|
||||
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
|
||||
)
|
||||
|
||||
in_channels = in_channels + 1 if concat_padding_mask else in_channels
|
||||
self.x_embedder = PatchEmbed(
|
||||
spatial_patch_size=patch_spatial,
|
||||
temporal_patch_size=patch_temporal,
|
||||
in_channels=in_channels,
|
||||
out_channels=model_channels,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Block(
|
||||
x_dim=model_channels,
|
||||
context_dim=crossattn_emb_channels,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
use_adaln_lora=use_adaln_lora,
|
||||
adaln_lora_dim=adaln_lora_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer(
|
||||
hidden_size=self.model_channels,
|
||||
spatial_patch_size=self.patch_spatial,
|
||||
temporal_patch_size=self.patch_temporal,
|
||||
out_channels=self.out_channels,
|
||||
use_adaln_lora=self.use_adaln_lora,
|
||||
adaln_lora_dim=self.adaln_lora_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
def build_pos_embed(self, device=None, dtype=None) -> None:
|
||||
if self.pos_emb_cls == "rope3d":
|
||||
cls_type = VideoRopePosition3DEmb
|
||||
else:
|
||||
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
|
||||
|
||||
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
|
||||
kwargs = dict(
|
||||
model_channels=self.model_channels,
|
||||
len_h=self.max_img_h // self.patch_spatial,
|
||||
len_w=self.max_img_w // self.patch_spatial,
|
||||
len_t=self.max_frames // self.patch_temporal,
|
||||
max_fps=self.max_fps,
|
||||
min_fps=self.min_fps,
|
||||
is_learnable=self.pos_emb_learnable,
|
||||
interpolation=self.pos_emb_interpolation,
|
||||
head_dim=self.model_channels // self.num_heads,
|
||||
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
|
||||
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
|
||||
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
|
||||
enable_fps_modulation=self.rope_enable_fps_modulation,
|
||||
device=device,
|
||||
)
|
||||
self.pos_embedder = cls_type(
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
|
||||
if self.extra_per_block_abs_pos_emb:
|
||||
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
|
||||
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
|
||||
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
|
||||
kwargs["device"] = device
|
||||
kwargs["dtype"] = dtype
|
||||
self.extra_pos_embedder = LearnablePosEmbAxis(
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
|
||||
def prepare_embedded_sequence(
|
||||
self,
|
||||
x_B_C_T_H_W: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
|
||||
|
||||
Args:
|
||||
x_B_C_T_H_W (torch.Tensor): video
|
||||
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
|
||||
If None, a default value (`self.base_fps`) will be used.
|
||||
padding_mask (Optional[torch.Tensor]): current it is not used
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
|
||||
- An optional positional embedding tensor, returned only if the positional embedding class
|
||||
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
|
||||
|
||||
Notes:
|
||||
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
|
||||
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
|
||||
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
|
||||
the `self.pos_embedder` with the shape [T, H, W].
|
||||
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
|
||||
`self.pos_embedder` with the fps tensor.
|
||||
- Otherwise, the positional embeddings are generated without considering fps.
|
||||
"""
|
||||
if self.concat_padding_mask:
|
||||
if padding_mask is None:
|
||||
padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
|
||||
else:
|
||||
padding_mask = transforms.functional.resize(
|
||||
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
x_B_C_T_H_W = torch.cat(
|
||||
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
|
||||
)
|
||||
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
|
||||
|
||||
if self.extra_per_block_abs_pos_emb:
|
||||
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
|
||||
else:
|
||||
extra_pos_emb = None
|
||||
|
||||
if "rope" in self.pos_emb_cls.lower():
|
||||
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
|
||||
|
||||
return x_B_T_H_W_D, None, extra_pos_emb
|
||||
|
||||
def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
|
||||
x_B_C_Tt_Hp_Wp = rearrange(
|
||||
x_B_T_H_W_M,
|
||||
"B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
|
||||
p1=self.patch_spatial,
|
||||
p2=self.patch_spatial,
|
||||
t=self.patch_temporal,
|
||||
)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
x_B_C_T_H_W = x
|
||||
timesteps_B_T = timesteps
|
||||
crossattn_emb = context
|
||||
"""
|
||||
Args:
|
||||
x: (B, C, T, H, W) tensor of spatial-temp inputs
|
||||
timesteps: (B, ) tensor of timesteps
|
||||
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
|
||||
"""
|
||||
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
|
||||
x_B_C_T_H_W,
|
||||
fps=fps,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
|
||||
if timesteps_B_T.ndim == 1:
|
||||
timesteps_B_T = timesteps_B_T.unsqueeze(1)
|
||||
t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
|
||||
t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)
|
||||
|
||||
# for logging purpose
|
||||
affline_scale_log_info = {}
|
||||
affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
|
||||
self.affline_scale_log_info = affline_scale_log_info
|
||||
self.affline_emb = t_embedding_B_T_D
|
||||
self.crossattn_emb = crossattn_emb
|
||||
|
||||
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||
assert (
|
||||
x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||
), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"
|
||||
|
||||
block_kwargs = {
|
||||
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
|
||||
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
|
||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
}
|
||||
for block in self.blocks:
|
||||
x_B_T_H_W_D = block(
|
||||
x_B_T_H_W_D,
|
||||
t_embedding_B_T_D,
|
||||
crossattn_emb,
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||
return x_B_C_Tt_Hp_Wp
|
@ -34,6 +34,7 @@ import comfy.ldm.flux.model
|
||||
import comfy.ldm.lightricks.model
|
||||
import comfy.ldm.hunyuan_video.model
|
||||
import comfy.ldm.cosmos.model
|
||||
import comfy.ldm.cosmos.predict2
|
||||
import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
@ -48,6 +49,7 @@ import comfy.ops
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
import comfy.latent_formats
|
||||
import comfy.model_sampling
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
@ -63,38 +65,39 @@ class ModelType(Enum):
|
||||
V_PREDICTION_CONTINUOUS = 7
|
||||
FLUX = 8
|
||||
IMG_TO_IMG = 9
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
|
||||
FLOW_COSMOS = 10
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
s = ModelSamplingDiscrete
|
||||
s = comfy.model_sampling.ModelSamplingDiscrete
|
||||
|
||||
if model_type == ModelType.EPS:
|
||||
c = EPS
|
||||
c = comfy.model_sampling.EPS
|
||||
elif model_type == ModelType.V_PREDICTION:
|
||||
c = V_PREDICTION
|
||||
c = comfy.model_sampling.V_PREDICTION
|
||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousEDM
|
||||
c = comfy.model_sampling.V_PREDICTION
|
||||
s = comfy.model_sampling.ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.FLOW:
|
||||
c = comfy.model_sampling.CONST
|
||||
s = comfy.model_sampling.ModelSamplingDiscreteFlow
|
||||
elif model_type == ModelType.STABLE_CASCADE:
|
||||
c = EPS
|
||||
s = StableCascadeSampling
|
||||
c = comfy.model_sampling.EPS
|
||||
s = comfy.model_sampling.StableCascadeSampling
|
||||
elif model_type == ModelType.EDM:
|
||||
c = EDM
|
||||
s = ModelSamplingContinuousEDM
|
||||
c = comfy.model_sampling.EDM
|
||||
s = comfy.model_sampling.ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousV
|
||||
c = comfy.model_sampling.V_PREDICTION
|
||||
s = comfy.model_sampling.ModelSamplingContinuousV
|
||||
elif model_type == ModelType.FLUX:
|
||||
c = comfy.model_sampling.CONST
|
||||
s = comfy.model_sampling.ModelSamplingFlux
|
||||
elif model_type == ModelType.IMG_TO_IMG:
|
||||
c = comfy.model_sampling.IMG_TO_IMG
|
||||
elif model_type == ModelType.FLOW_COSMOS:
|
||||
c = comfy.model_sampling.COSMOS_RFLOW
|
||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@ -998,6 +1001,22 @@ class CosmosVideo(BaseModel):
|
||||
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
|
||||
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
|
||||
|
||||
class CosmosPredict2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.predict2.MiniTrainDIT)
|
||||
self.image_to_video = image_to_video
|
||||
if self.image_to_video:
|
||||
self.concat_keys = ("mask_inverted",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
|
||||
return out
|
||||
|
||||
class Lumina2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
|
||||
|
@ -407,6 +407,53 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["text_emb_dim"] = 2048
|
||||
return dit_config
|
||||
|
||||
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cosmos_predict2"
|
||||
dit_config["max_img_h"] = 240
|
||||
dit_config["max_img_w"] = 240
|
||||
dit_config["max_frames"] = 128
|
||||
concat_padding_mask = True
|
||||
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["patch_spatial"] = 2
|
||||
dit_config["patch_temporal"] = 1
|
||||
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["concat_padding_mask"] = concat_padding_mask
|
||||
dit_config["crossattn_emb_channels"] = 1024
|
||||
dit_config["pos_emb_cls"] = "rope3d"
|
||||
dit_config["pos_emb_learnable"] = True
|
||||
dit_config["pos_emb_interpolation"] = "crop"
|
||||
dit_config["min_fps"] = 1
|
||||
dit_config["max_fps"] = 30
|
||||
|
||||
dit_config["use_adaln_lora"] = True
|
||||
dit_config["adaln_lora_dim"] = 256
|
||||
if dit_config["model_channels"] == 2048:
|
||||
dit_config["num_blocks"] = 28
|
||||
dit_config["num_heads"] = 16
|
||||
elif dit_config["model_channels"] == 5120:
|
||||
dit_config["num_blocks"] = 36
|
||||
dit_config["num_heads"] = 40
|
||||
|
||||
if dit_config["in_channels"] == 16:
|
||||
dit_config["extra_per_block_abs_pos_emb"] = False
|
||||
dit_config["rope_h_extrapolation_ratio"] = 4.0
|
||||
dit_config["rope_w_extrapolation_ratio"] = 4.0
|
||||
dit_config["rope_t_extrapolation_ratio"] = 1.0
|
||||
elif dit_config["in_channels"] == 17:
|
||||
dit_config["extra_per_block_abs_pos_emb"] = False
|
||||
dit_config["rope_h_extrapolation_ratio"] = 3.0
|
||||
dit_config["rope_w_extrapolation_ratio"] = 3.0
|
||||
dit_config["rope_t_extrapolation_ratio"] = 1.0
|
||||
|
||||
dit_config["extra_h_extrapolation_ratio"] = 1.0
|
||||
dit_config["extra_w_extrapolation_ratio"] = 1.0
|
||||
dit_config["extra_t_extrapolation_ratio"] = 1.0
|
||||
dit_config["rope_enable_fps_modulation"] = False
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
|
@ -77,6 +77,25 @@ class IMG_TO_IMG(X0):
|
||||
def calculate_input(self, sigma, noise):
|
||||
return noise
|
||||
|
||||
class COSMOS_RFLOW:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = (sigma / (sigma + 1))
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
return noise * (1.0 - sigma)
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = (sigma / (sigma + 1))
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input * (1.0 - sigma) - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
noise = noise * sigma
|
||||
noise += latent_image
|
||||
return noise
|
||||
|
||||
def inverse_noise_scaling(self, sigma, latent):
|
||||
return latent
|
||||
|
||||
class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def __init__(self, model_config=None, zsnr=None):
|
||||
@ -350,3 +369,15 @@ class ModelSamplingFlux(torch.nn.Module):
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
return flux_time_shift(self.shift, 1.0, 1.0 - percent)
|
||||
|
||||
|
||||
class ModelSamplingCosmosRFlow(ModelSamplingContinuousEDM):
|
||||
def timestep(self, sigma):
|
||||
return sigma / (sigma + 1)
|
||||
|
||||
def sigma(self, timestep):
|
||||
sigma_max = self.sigma_max
|
||||
if timestep >= (sigma_max / (sigma_max + 1)):
|
||||
return sigma_max
|
||||
|
||||
return timestep / (1 - timestep)
|
||||
|
@ -908,6 +908,44 @@ class CosmosI2V(CosmosT2V):
|
||||
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class CosmosT2IPredict2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "cosmos_predict2",
|
||||
"in_channels": 16,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"sigma_data": 1.0,
|
||||
"sigma_max": 80.0,
|
||||
"sigma_min": 0.002,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Wan21
|
||||
|
||||
memory_usage_factor = 1.6 #TODO
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CosmosPredict2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
|
||||
|
||||
class CosmosI2VPredict2(CosmosT2IPredict2):
|
||||
unet_config = {
|
||||
"image_model": "cosmos_predict2",
|
||||
"in_channels": 17,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CosmosPredict2(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class Lumina2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "lumina2",
|
||||
@ -1139,6 +1177,6 @@ class ACEStep(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
Loading…
x
Reference in New Issue
Block a user