diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 5471763bb..e78d846b2 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -212,14 +212,10 @@ class WanAttentionBlock(nn.Module): x = x + y * e[2] - # cross-attention & ffn function - def cross_attn_ffn(x, context, e): - x = x + self.cross_attn(self.norm3(x), context) - y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) - x = x + y * e[5] - return x - - x = cross_attn_ffn(x, context, e) + # cross-attention & ffn + x = x + self.cross_attn(self.norm3(x), context) + y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) + x = x + y * e[5] return x @@ -442,7 +438,6 @@ class WanModel(torch.nn.Module): # unpatchify x = self.unpatchify(x, grid_sizes) return x - # return [u.float() for u in x] def forward(self, x, timestep, context, clip_fea=None, **kwargs): bs, c, t, h, w = x.shape