Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-03-15 14:09:36 +00:00)
Change wan rope implementation to the flux one.
Should be more compatible.
commit f37551c1d2 (parent: 63023011b9)
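In short: the diff below deletes Wan's own complex-valued rotary embedding helpers (rope_params / rope_apply) and reuses the flux pipeline already shipping in ComfyUI. Position ids are now built once in forward, converted to rope frequencies by EmbedND, threaded through forward_orig, and applied to q and k with apply_rope inside self-attention.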
```diff
@@ -4,9 +4,11 @@ import math
 import torch
 import torch.nn as nn
+from einops import repeat
 
 from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.math import apply_rope
 
 
 def sinusoidal_embedding_1d(dim, position):
 
     # preprocess
```
```diff
@@ -21,45 +23,6 @@ def sinusoidal_embedding_1d(dim, position):
     return x
 
 
-def rope_params(max_seq_len, dim, theta=10000):
-    assert dim % 2 == 0
-    freqs = torch.outer(
-        torch.arange(max_seq_len),
-        1.0 / torch.pow(theta,
-                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
-    freqs = torch.polar(torch.ones_like(freqs), freqs)
-    return freqs
-
-
-def rope_apply(x, grid_sizes, freqs):
-    n, c = x.size(2), x.size(3) // 2
-
-    # split freqs
-    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-
-    # loop over samples
-    output = []
-    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
-        seq_len = f * h * w
-
-        # precompute multipliers
-        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
-            seq_len, n, -1, 2))
-        freqs_i = torch.cat([
-            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
-            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
-            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
-        ], dim=-1).reshape(seq_len, 1, -1)
-
-        # apply rotary embedding
-        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
-        x_i = torch.cat([x_i, x[i, seq_len:]])
-
-        # append to collection
-        output.append(x_i)
-    return torch.stack(output).to(dtype=x.dtype)
-
-
 class WanRMSNorm(nn.Module):
 
     def __init__(self, dim, eps=1e-5, device=None, dtype=None):
```
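For orientation, the removed rope_params produced one complex multiplier per (position, frequency) pair via torch.polar. A quick shape check of its body with toy arguments (max_seq_len=4, dim=6 are made up for illustration):

```python
import torch

# inlined body of the removed rope_params with toy arguments:
# dim=6 yields dim//2 = 3 frequency columns, complex output
max_seq_len, dim, theta = 4, 6, 10000
freqs = torch.outer(
    torch.arange(max_seq_len),
    1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)  # e^{i*angle}, magnitude 1
print(freqs.shape, freqs.dtype)  # torch.Size([4, 3]) torch.complex128
```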
```diff
@@ -122,10 +85,11 @@ class WanSelfAttention(nn.Module):
             return q, k, v
 
         q, k, v = qkv_fn(x)
+        q, k = apply_rope(q, k, freqs)
 
         x = optimized_attention(
-            q=rope_apply(q, grid_sizes, freqs).view(b, s, n * d),
-            k=rope_apply(k, grid_sizes, freqs).view(b, s, n * d),
+            q=q.view(b, s, n * d),
+            k=k.view(b, s, n * d),
             v=v,
             heads=self.num_heads,
         )
```
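The old rope_apply rotated q/k by multiplying complex pairs; flux's apply_rope expresses the same rotation with real cos/sin arithmetic. The commit message's compatibility claim rests on these being the same operation, which is easy to sanity-check in isolation (a standalone math demo, not ComfyUI code):

```python
import torch

# rotating (real, imag) pairs: complex multiply vs. explicit rotation
x = torch.randn(8, 2)                    # 8 two-component pairs
theta = torch.rand(8) * 6.28             # arbitrary rotation angles
rot = torch.polar(torch.ones(8), theta)  # unit complex numbers e^{i*theta}

# wan style: view pairs as complex numbers and multiply
out_complex = torch.view_as_real(torch.view_as_complex(x) * rot)

# flux style: the same rotation written with cos/sin
cos, sin = theta.cos(), theta.sin()
out_real = torch.stack([x[:, 0] * cos - x[:, 1] * sin,
                        x[:, 0] * sin + x[:, 1] * cos], dim=-1)

assert torch.allclose(out_complex, out_real, atol=1e-6)
```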
```diff
@@ -433,14 +397,8 @@ class WanModel(torch.nn.Module):
         # head
         self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
 
-        # buffers (don't use register_buffer otherwise dtype will be changed in to())
-        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
         d = dim // num_heads
-        self.register_buffer("freqs", torch.cat([
-            rope_params(1024, d - 4 * (d // 6)),
-            rope_params(1024, 2 * (d // 6)),
-            rope_params(1024, 2 * (d // 6))
-        ], dim=1), persistent=False)
+        self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
 
         if model_type == 'i2v':
            self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
```
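Note that axes_dim mirrors the removed buffer exactly: the head dim d is carved into one temporal band of d - 4 * (d // 6) and two spatial bands of 2 * (d // 6) each, so the three rope segments still fill the whole head. Checking the arithmetic with a hypothetical d = 128:

```python
# same split as the removed rope_params concatenation, for an illustrative d
d = 128
axes_dim = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]
print(axes_dim, sum(axes_dim))  # [44, 42, 42] 128
```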
```diff
@@ -453,6 +411,7 @@ class WanModel(torch.nn.Module):
         seq_len=None,
         clip_fea=None,
         y=None,
+        freqs=None,
     ):
         r"""
         Forward pass through the diffusion model
```
```diff
@@ -477,10 +436,6 @@ class WanModel(torch.nn.Module):
         """
         if self.model_type == 'i2v':
             assert clip_fea is not None and y is not None
-        # params
-        # device = self.patch_embedding.weight.device
-        # if self.freqs.device != device:
-        #     self.freqs = self.freqs.to(device)
 
         if y is not None:
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
```
```diff
@@ -523,7 +478,7 @@ class WanModel(torch.nn.Module):
                 e=e0,
                 seq_lens=seq_lens,
                 grid_sizes=grid_sizes,
-                freqs=self.freqs,
+                freqs=freqs,
                 context=context,
                 context_lens=context_lens)
```
```diff
@@ -538,8 +493,20 @@ class WanModel(torch.nn.Module):
         return x
         # return [u.float() for u in x]
 
-    def forward(self, x, t, context, y=None, image=None, **kwargs):
-        return self.forward_orig([x], t, [context], clip_fea=y, y=image)[0]
+    def forward(self, x, timestep, context, y=None, image=None, **kwargs):
+        bs, c, t, h, w = x.shape
+        patch_size = self.patch_size
+        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
+        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
+        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
+        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
+        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
+        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
+        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+
+        freqs = self.rope_embedder(img_ids).movedim(1, 2)
+        return self.forward_orig([x], timestep, [context], clip_fea=y, y=image, freqs=freqs)[0]
 
     def unpatchify(self, x, grid_sizes):
         r"""
```
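The rewritten forward assigns every latent patch a (t, h, w) index triple and flattens them in row-major order; linspace(0, n - 1, steps=n) is just a float arange(n). The (x + p // 2) // p expressions round the latent extent up to whole patches when p is 1 or 2 (e.g. (15 + 1) // 2 = 8 = ceil(15 / 2)). A toy run with made-up grid sizes shows the id layout:

```python
import torch
from einops import repeat

# toy id grid, sizes invented for illustration
t_len, h_len, w_len, bs = 2, 2, 2, 1
img_ids = torch.zeros((t_len, h_len, w_len, 3))
img_ids[:, :, :, 0] += torch.arange(t_len).reshape(-1, 1, 1)  # frame index
img_ids[:, :, :, 1] += torch.arange(h_len).reshape(1, -1, 1)  # row index
img_ids[:, :, :, 2] += torch.arange(w_len).reshape(1, 1, -1)  # column index
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
print(img_ids[0].long().tolist())
# [[0,0,0], [0,0,1], [0,1,0], [0,1,1], [1,0,0], [1,0,1], [1,1,0], [1,1,1]]
```

EmbedND then turns each triple into the concatenated per-axis frequencies that apply_rope consumes, which is the same information the old per-sample loop in rope_apply assembled by hand.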