mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-16 06:27:15 +00:00
Make the SkipLayerGuidanceDIT node work on WAN.
This commit is contained in:
parent
9c98c6358b
commit
6a0daa79b6
@ -384,6 +384,7 @@ class WanModel(torch.nn.Module):
|
|||||||
context,
|
context,
|
||||||
clip_fea=None,
|
clip_fea=None,
|
||||||
freqs=None,
|
freqs=None,
|
||||||
|
transformer_options={},
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Forward pass through the diffusion model
|
Forward pass through the diffusion model
|
||||||
@ -429,8 +430,18 @@ class WanModel(torch.nn.Module):
|
|||||||
freqs=freqs,
|
freqs=freqs,
|
||||||
context=context)
|
context=context)
|
||||||
|
|
||||||
for block in self.blocks:
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
x = block(x, **kwargs)
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
if ("double_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||||
|
return out
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||||
|
x = out["img"]
|
||||||
|
else:
|
||||||
|
x = block(x, e=e0, freqs=freqs, context=context)
|
||||||
|
|
||||||
# head
|
# head
|
||||||
x = self.head(x, e)
|
x = self.head(x, e)
|
||||||
@ -439,7 +450,7 @@ class WanModel(torch.nn.Module):
|
|||||||
x = self.unpatchify(x, grid_sizes)
|
x = self.unpatchify(x, grid_sizes)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x, timestep, context, clip_fea=None, **kwargs):
|
def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
|
||||||
bs, c, t, h, w = x.shape
|
bs, c, t, h, w = x.shape
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
@ -453,7 +464,7 @@ class WanModel(torch.nn.Module):
|
|||||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||||
|
|
||||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]
|
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
|
||||||
|
|
||||||
def unpatchify(self, x, grid_sizes):
|
def unpatchify(self, x, grid_sizes):
|
||||||
r"""
|
r"""
|
||||||
|
Loading…
Reference in New Issue
Block a user