Wan code small cleanup.

comfyanonymous 2025-02-27 07:22:42 -05:00
parent b07f116dea
commit f4dac8ab6f


@@ -212,14 +212,10 @@ class WanAttentionBlock(nn.Module):
         x = x + y * e[2]

-        # cross-attention & ffn function
-        def cross_attn_ffn(x, context, e):
-            x = x + self.cross_attn(self.norm3(x), context)
-            y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
-            x = x + y * e[5]
-            return x
-
-        x = cross_attn_ffn(x, context, e)
+        # cross-attention & ffn
+        x = x + self.cross_attn(self.norm3(x), context)
+        y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
+        x = x + y * e[5]
         return x
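This change is behavior-preserving: the nested cross_attn_ffn closure was defined and then immediately called exactly once, with the same x, context, and e already in scope, so inlining its body computes the identical result while avoiding a per-forward function allocation. Below is a minimal runnable sketch of the resulting residual pattern; the submodule names follow the diff, but nn.MultiheadAttention, LayerNorm, and the small MLP are placeholder assumptions rather than ComfyUI's actual implementations, and e is assumed to be a sequence of per-channel modulation tensors of which only indices 3..5 (shift, scale, gate) are used here.

# Minimal sketch, not the real WanAttentionBlock: placeholder submodules
# stand in for ComfyUI's, but the forward math mirrors the inlined hunk.
import torch
import torch.nn as nn

class CrossAttnFFNSketch(nn.Module):
    def __init__(self, dim=64, ffn_dim=256, num_heads=8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm3 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(),
                                 nn.Linear(ffn_dim, dim))

    def forward(self, x, context, e):
        # Residual cross-attention over the conditioning tokens.
        attn, _ = self.cross_attn(self.norm3(x), context, context)
        x = x + attn
        # Feed-forward with shift (e[3]) / scale (e[4]) modulation,
        # then a gated (e[5]) residual, matching the inlined lines above.
        y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
        return x + y * e[5]

dim = 64
block = CrossAttnFFNSketch(dim)
x = torch.randn(2, 16, dim)                      # (batch, image tokens, dim)
context = torch.randn(2, 8, dim)                 # (batch, text tokens, dim)
e = [torch.randn(dim) * 0.1 for _ in range(6)]   # modulation terms; only 3..5 used here
print(block(x, context, e).shape)                # torch.Size([2, 16, 64])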
@@ -442,7 +438,6 @@ class WanModel(torch.nn.Module):
         # unpatchify
         x = self.unpatchify(x, grid_sizes)
         return x
-        # return [u.float() for u in x]

     def forward(self, x, timestep, context, clip_fea=None, **kwargs):
         bs, c, t, h, w = x.shape