Wan code small cleanup.

This commit is contained in:
comfyanonymous 2025-02-27 07:22:42 -05:00
parent b07f116dea
commit f4dac8ab6f

View File

@ -212,14 +212,10 @@ class WanAttentionBlock(nn.Module):
x = x + y * e[2]
# cross-attention & ffn function
def cross_attn_ffn(x, context, e):
x = x + self.cross_attn(self.norm3(x), context)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
return x
x = cross_attn_ffn(x, context, e)
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
return x
@ -442,7 +438,6 @@ class WanModel(torch.nn.Module):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
# return [u.float() for u in x]
def forward(self, x, timestep, context, clip_fea=None, **kwargs):
bs, c, t, h, w = x.shape