diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index 49880b606..546ebb225 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -4,9 +4,11 @@ import math
 
 import torch
 import torch.nn as nn
+from einops import repeat
 
 from comfy.ldm.modules.attention import optimized_attention
-
+from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.math import apply_rope
 
 def sinusoidal_embedding_1d(dim, position):
     # preprocess
@@ -21,45 +23,6 @@ def sinusoidal_embedding_1d(dim, position):
     return x
 
 
-def rope_params(max_seq_len, dim, theta=10000):
-    assert dim % 2 == 0
-    freqs = torch.outer(
-        torch.arange(max_seq_len),
-        1.0 / torch.pow(theta,
-                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
-    freqs = torch.polar(torch.ones_like(freqs), freqs)
-    return freqs
-
-
-def rope_apply(x, grid_sizes, freqs):
-    n, c = x.size(2), x.size(3) // 2
-
-    # split freqs
-    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-
-    # loop over samples
-    output = []
-    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
-        seq_len = f * h * w
-
-        # precompute multipliers
-        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
-            seq_len, n, -1, 2))
-        freqs_i = torch.cat([
-            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
-            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
-            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
-        ], dim=-1).reshape(seq_len, 1, -1)
-
-        # apply rotary embedding
-        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
-        x_i = torch.cat([x_i, x[i, seq_len:]])
-
-        # append to collection
-        output.append(x_i)
-    return torch.stack(output).to(dtype=x.dtype)
-
-
 class WanRMSNorm(nn.Module):
 
     def __init__(self, dim, eps=1e-5, device=None, dtype=None):
@@ -122,10 +85,11 @@ class WanSelfAttention(nn.Module):
             return q, k, v
 
         q, k, v = qkv_fn(x)
+        q, k = apply_rope(q, k, freqs)
 
         x = optimized_attention(
-            q=rope_apply(q, grid_sizes, freqs).view(b, s, n * d),
-            k=rope_apply(k, grid_sizes, freqs).view(b, s, n * d),
+            q=q.view(b, s, n * d),
+            k=k.view(b, s, n * d),
             v=v,
             heads=self.num_heads,
         )
@@ -433,14 +397,8 @@ class WanModel(torch.nn.Module):
         # head
         self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
 
-        # buffers (don't use register_buffer otherwise dtype will be changed in to())
-        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
         d = dim // num_heads
-        self.register_buffer("freqs", torch.cat([
-            rope_params(1024, d - 4 * (d // 6)),
-            rope_params(1024, 2 * (d // 6)),
-            rope_params(1024, 2 * (d // 6))
-        ], dim=1), persistent=False)
+        self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
 
         if model_type == 'i2v':
             self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
@@ -453,6 +411,7 @@
         seq_len=None,
         clip_fea=None,
         y=None,
+        freqs=None,
     ):
         r"""
         Forward pass through the diffusion model
@@ -477,10 +436,6 @@
         """
         if self.model_type == 'i2v':
             assert clip_fea is not None and y is not None
-        # params
-        # device = self.patch_embedding.weight.device
-        # if self.freqs.device != device:
-        #     self.freqs = self.freqs.to(device)
 
         if y is not None:
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@@ -523,7 +478,7 @@
             e=e0,
             seq_lens=seq_lens,
             grid_sizes=grid_sizes,
-            freqs=self.freqs,
+            freqs=freqs,
             context=context,
             context_lens=context_lens)
 
@@ -538,8 +493,20 @@
         return x
         # return [u.float() for u in x]
 
-    def forward(self, x, t, context, y=None, image=None, **kwargs):
-        return self.forward_orig([x], t, [context], clip_fea=y, y=image)[0]
+    def forward(self, x, timestep, context, y=None, image=None, **kwargs):
+        bs, c, t, h, w = x.shape
+        patch_size = self.patch_size
+        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
+        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
+        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
+        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
+        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
+        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
+        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+
+        freqs = self.rope_embedder(img_ids).movedim(1, 2)
+        return self.forward_orig([x], timestep, [context], clip_fea=y, y=image, freqs=freqs)[0]
 
     def unpatchify(self, x, grid_sizes):
         r"""
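
Note on the RoPE change above: the per-sample complex-multiply pair rope_params/rope_apply is replaced by the shared flux helpers, with EmbedND building the frequencies once per forward pass from 3-axis (time, height, width) position ids and apply_rope rotating q/k in a single batched op, the frequencies now threaded through forward_orig as the new freqs argument instead of living in a buffer. A minimal standalone sketch of the position-id layout and the axes_dim split (plain torch, toy sizes; head_dim = 128 is an assumed example value, not taken from the model config):

    import torch

    # The three RoPE axes split the head dim the same way the removed
    # rope_params concatenation did: d - 4*(d//6), 2*(d//6), 2*(d//6).
    head_dim = 128  # dim // num_heads; example value, not from the config
    axes_dim = [head_dim - 4 * (head_dim // 6), 2 * (head_dim // 6), 2 * (head_dim // 6)]
    assert sum(axes_dim) == head_dim  # 44 + 42 + 42 == 128

    # img_ids mirrors the new WanModel.forward(): one integer coordinate per
    # axis for every latent patch in a (t_len, h_len, w_len) grid.
    t_len, h_len, w_len = 4, 6, 8  # toy grid (latent size // patch_size)
    img_ids = torch.zeros((t_len, h_len, w_len, 3))
    img_ids[:, :, :, 0] = torch.arange(t_len).reshape(-1, 1, 1)
    img_ids[:, :, :, 1] = torch.arange(h_len).reshape(1, -1, 1)
    img_ids[:, :, :, 2] = torch.arange(w_len).reshape(1, 1, -1)
    img_ids = img_ids.reshape(1, -1, 3)  # (batch, t_len*h_len*w_len, 3)
    assert img_ids.shape == (1, t_len * h_len * w_len, 3)

One behavioral difference worth flagging for review: the removed rope_apply upcast to float64 per sample, while the flux helpers apply the rotation batched (typically in float32), so bit-exact agreement with the old outputs is not expected.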