#Based on Flux code because of weird hunyuan video code license.

import torch
import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention


from dataclasses import dataclass
from einops import repeat

from torch import Tensor, nn

from comfy.ldm.flux.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding
)

import comfy.ldm.common_dit


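# Model hyperparameters, mirroring the Flux-style parameter dataclass.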
@dataclass
class HunyuanVideoParams:
    in_channels: int
    out_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list
    theta: int
    patch_size: list
    qkv_bias: bool
    guidance_embed: bool


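# Holds only the qkv and output projection weights of the refiner's self-attention;
# the attention itself is computed by TokenRefinerBlock via optimized_attention.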
class SelfAttentionRef(nn.Module):
    def __init__(self, dim: int, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)


class TokenRefinerBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        heads,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.heads = heads
        mlp_hidden_dim = hidden_size * 4

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
        self.self_attn = SelfAttentionRef(hidden_size, True, dtype=dtype, device=device, operations=operations)

        self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)

        self.mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

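    # The conditioning vector c (timestep + pooled text) is turned into two gates
    # (mod1, mod2) that scale the attention and MLP residual updates.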
    def forward(self, x, c, mask):
        mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        qkv = self.self_attn.qkv(norm_x)
        q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
        attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)

        x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
        x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
        return x


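# A stack of TokenRefinerBlocks sharing a single conditioning vector and attention mask.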
class IndividualTokenRefiner(nn.Module):
    def __init__(
        self,
        hidden_size,
        heads,
        num_blocks,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                TokenRefinerBlock(
                    hidden_size=hidden_size,
                    heads=heads,
                    dtype=dtype,
                    device=device,
                    operations=operations
                )
                for _ in range(num_blocks)
            ]
        )

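    # Expand the per-token text mask into a pairwise (token, token) mask that is
    # reused by every refiner block.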
    def forward(self, x, c, mask):
        m = None
        if mask is not None:
            m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
            m = m + m.transpose(2, 3)

        for block in self.blocks:
            x = block(x, c, m)
        return x


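# Text conditioning branch: embeds the raw text tokens, then refines them with a
# small transformer conditioned on the timestep and a mean-pooled text summary.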
class TokenRefiner(nn.Module):
    def __init__(
        self,
        text_dim,
        hidden_size,
        heads,
        num_blocks,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()

        self.input_embedder = operations.Linear(text_dim, hidden_size, bias=True, dtype=dtype, device=device)
        self.t_embedder = MLPEmbedder(256, hidden_size, dtype=dtype, device=device, operations=operations)
        self.c_embedder = MLPEmbedder(text_dim, hidden_size, dtype=dtype, device=device, operations=operations)
        self.individual_token_refiner = IndividualTokenRefiner(hidden_size, heads, num_blocks, dtype=dtype, device=device, operations=operations)

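    # x: raw text token embeddings (batch, tokens, text_dim); mask is the text
    # attention mask forwarded to the refiner blocks.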
    def forward(
        self,
        x,
        timesteps,
        mask,
    ):
        t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
        # m = mask.float().unsqueeze(-1)
        # c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
        c = x.sum(dim=1) / x.shape[1]

        c = t + self.c_embedder(c.to(x.dtype))
        x = self.input_embedder(x)
        x = self.individual_token_refiner(x, c, mask)
        return x

class HunyuanVideo(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
        params = HunyuanVideoParams(**kwargs)
        self.params = params
        self.patch_size = params.patch_size
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)

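        # Input embedders: 3D patch embedding for the video latent plus MLP embedders
        # for the timestep, the pooled text vector and (optionally) the distilled
        # guidance scale.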
        self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
        )

        self.txt_in = TokenRefiner(params.context_in_dim, self.hidden_size, self.num_heads, 2, dtype=dtype, device=device, operations=operations)

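        # Backbone: double-stream blocks keep image and text tokens in separate
        # streams; single-stream blocks then run on the concatenated sequence.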
        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    flipped_img_txt=True,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
                for _ in range(params.depth_single_blocks)
            ]
        )

        if final_layer:
            self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        txt_mask: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        control=None,
        transformer_options={},
    ) -> Tensor:
        patches_replace = transformer_options.get("patches_replace", {})

        initial_shape = list(img.shape)
        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))

        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])

        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

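        # Convert an integer/bool text mask into an additive attention bias:
        # valid tokens map to 0, padding to a large negative value.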
        if txt_mask is not None and not torch.is_floating_point(txt_mask):
            txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max

        txt = self.txt_in(txt, timesteps, txt_mask)

        ids = torch.cat((img_ids, txt_ids), dim=1)
        pe = self.pe_embedder(ids)

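        # Additive attention bias over the joint image + text sequence: image
        # positions stay 0, text positions take the bias from txt_mask.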
        img_len = img.shape[1]
        if txt_mask is not None:
            attn_mask_len = img_len + txt.shape[1]
            attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
            attn_mask[:, 0, img_len:] = txt_mask
        else:
            attn_mask = None

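        # transformer_options["patches_replace"]["dit"] lets callers wrap or
        # replace individual blocks.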
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)

            if control is not None: # Controlnet
                control_i = control.get("input")
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        img += add

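        # Single-stream blocks operate on the concatenated image + text tokens.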
        img = torch.cat((img, txt), 1)

        for i, block in enumerate(self.single_blocks):
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
                    return out

                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
                img = out["img"]
            else:
                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)

            if control is not None: # Controlnet
                control_o = control.get("output")
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
                        img[:, : img_len] += add

        img = img[:, : img_len]

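        # Project tokens back to patch values, then un-patchify to the input latent shape.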
        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)

        shape = initial_shape[-3:]
        for i in range(len(shape)):
            shape[i] = shape[i] // self.patch_size[i]
        img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
        img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
        img = img.reshape(initial_shape)
        return img

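    # ComfyUI entry point: build 3D (t, h, w) position ids for the latent patches
    # and zero ids for the text tokens, then run forward_orig.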
    def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
        bs, c, t, h, w = x.shape
        patch_size = self.patch_size
        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
        return out