Don't hardcode length of context_img in wan code.

This commit is contained in:
comfyanonymous 2025-04-17 06:25:39 -04:00
parent 1fc00ba4b6
commit 0d720e4367

View File

@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context):
def forward(self, x, context, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context):
def forward(self, x, context, context_img_len):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
"""
context_img = context[:, :257]
context = context[:, 257:]
context_img = context[:, :context_img_len]
context = context[:, context_img_len:]
# compute query, key, value
q = self.norm_q(self.q(x))
@ -193,6 +193,7 @@ class WanAttentionBlock(nn.Module):
e,
freqs,
context,
context_img_len=None,
):
r"""
Args:
@ -213,7 +214,7 @@ class WanAttentionBlock(nn.Module):
x = x + y * e[2]
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context)
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
return x
@ -420,9 +421,12 @@ class WanModel(torch.nn.Module):
# context
context = self.text_embedding(context)
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
@ -430,12 +434,12 @@ class WanModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
# head
x = self.head(x, e)