Change the wan RoPE implementation to the flux one.

Should be more compatible.
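For reference, the flux-style path precomputes one rotary table per position axis (what EmbedND does for the t/h/w ids built in forward) and rotates q/k pairs with it inside attention (apply_rope), instead of wan's per-sample complex-multiply rope_apply. A minimal self-contained sketch of that math follows; rope_1d and apply_rope_sketch are illustrative names, not the ComfyUI API:

import torch

def rope_1d(pos, dim, theta=10000.0):
    # Rotary table for one axis: (..., n) positions -> (..., n, dim // 2, 2, 2)
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)
    angles = pos.unsqueeze(-1).to(torch.float64) * omega          # (..., n, dim // 2)
    rot = torch.stack([torch.cos(angles), -torch.sin(angles),
                       torch.sin(angles), torch.cos(angles)], dim=-1)
    return rot.reshape(*rot.shape[:-1], 2, 2).float()

def apply_rope_sketch(q, k, freqs):
    # q, k: (batch, heads, seq, head_dim); freqs: (batch, 1, seq, head_dim // 2, 2, 2)
    def rotate(x):
        x_ = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        out = freqs[..., 0] * x_[..., 0] + freqs[..., 1] * x_[..., 1]
        return out.reshape(*x.shape).type_as(x)
    return rotate(q), rotate(k)

# Per-axis tables (t, h, w) are concatenated along the feature dimension,
# mirroring EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]).
d = 128                                                    # head dim (illustrative)
axes_dim = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]  # 44 + 42 + 42 = 128
ids = torch.zeros(1, 16, 3)                                # (batch, tokens, axes): t/h/w index per token
freqs = torch.cat([rope_1d(ids[..., i], axes_dim[i]) for i in range(3)], dim=-3).unsqueeze(1)
q = torch.randn(1, 8, 16, d)
k = torch.randn(1, 8, 16, d)
q, k = apply_rope_sketch(q, k, freqs)                      # rotated q/k, same shapes as the inputs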
comfyanonymous 2025-02-25 19:11:14 -05:00
parent 63023011b9
commit f37551c1d2


@@ -4,9 +4,11 @@ import math
import torch
import torch.nn as nn
from einops import repeat
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
def sinusoidal_embedding_1d(dim, position):
    # preprocess
@@ -21,45 +23,6 @@ def sinusoidal_embedding_1d(dim, position):
    return x
def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta,
                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs

def rope_apply(x, grid_sizes, freqs):
    n, c = x.size(2), x.size(3) // 2

    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
            seq_len, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        x_i = torch.cat([x_i, x[i, seq_len:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).to(dtype=x.dtype)
class WanRMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5, device=None, dtype=None):
@@ -122,10 +85,11 @@ class WanSelfAttention(nn.Module):
            return q, k, v

        q, k, v = qkv_fn(x)
        q, k = apply_rope(q, k, freqs)

        x = optimized_attention(
            q=rope_apply(q, grid_sizes, freqs).view(b, s, n * d),
            k=rope_apply(k, grid_sizes, freqs).view(b, s, n * d),
            q=q.view(b, s, n * d),
            k=k.view(b, s, n * d),
            v=v,
            heads=self.num_heads,
        )
@@ -433,14 +397,8 @@ class WanModel(torch.nn.Module):
        # head
        self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.register_buffer("freqs", torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ], dim=1), persistent=False)
        self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])

        if model_type == 'i2v':
            self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
@@ -453,6 +411,7 @@ class WanModel(torch.nn.Module):
        seq_len=None,
        clip_fea=None,
        y=None,
        freqs=None,
    ):
        r"""
        Forward pass through the diffusion model
@@ -477,10 +436,6 @@ class WanModel(torch.nn.Module):
        """
        if self.model_type == 'i2v':
            assert clip_fea is not None and y is not None
        # params
        # device = self.patch_embedding.weight.device
        # if self.freqs.device != device:
        #     self.freqs = self.freqs.to(device)

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@@ -523,7 +478,7 @@ class WanModel(torch.nn.Module):
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            freqs=freqs,
            context=context,
            context_lens=context_lens)
@@ -538,8 +493,20 @@ class WanModel(torch.nn.Module):
        return x
        # return [u.float() for u in x]

    def forward(self, x, t, context, y=None, image=None, **kwargs):
        return self.forward_orig([x], t, [context], clip_fea=y, y=image)[0]
    def forward(self, x, timestep, context, y=None, image=None, **kwargs):
        bs, c, t, h, w = x.shape
        patch_size = self.patch_size
        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

        freqs = self.rope_embedder(img_ids).movedim(1, 2)
        return self.forward_orig([x], timestep, [context], clip_fea=y, y=image, freqs=freqs)[0]
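As a rough shape check (not part of the commit; the sizes and patch_size=(1, 2, 2) below are illustrative), the position-id grid built in the new forward works out like this:

import torch
from einops import repeat

bs, t, h, w = 1, 16, 60, 104                        # latent frames / height / width (illustrative)
patch_size = (1, 2, 2)
t_len = (t + patch_size[0] // 2) // patch_size[0]   # 16
h_len = (h + patch_size[1] // 2) // patch_size[1]   # 30 (rounds up when the size is odd)
w_len = (w + patch_size[2] // 2) // patch_size[2]   # 52

img_ids = torch.zeros((t_len, h_len, w_len, 3))
img_ids[..., 0] += torch.linspace(0, t_len - 1, steps=t_len).reshape(-1, 1, 1)
img_ids[..., 1] += torch.linspace(0, h_len - 1, steps=h_len).reshape(1, -1, 1)
img_ids[..., 2] += torch.linspace(0, w_len - 1, steps=w_len).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)  # (1, 16 * 30 * 52, 3) = (1, 24960, 3)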
    def unpatchify(self, x, grid_sizes):
        r"""