From f85c08df0615a587e0974678b01199b88a1caae0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 22 May 2025 16:22:26 -0700 Subject: [PATCH] Make VACE conditionings stackable. (#8240) --- comfy/ldm/wan/model.py | 10 +++++++--- comfy/model_base.py | 21 +++++++++++++-------- comfy_extras/nodes_wan.py | 5 +++-- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index a996dedf4..1b51a4e4a 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -635,7 +635,7 @@ class VaceWanModel(WanModel): t, context, vace_context, - vace_strength=1.0, + vace_strength, clip_fea=None, freqs=None, transformer_options={}, @@ -661,8 +661,11 @@ class VaceWanModel(WanModel): context = torch.concat([context_clip, context], dim=1) context_img_len = clip_fea.shape[-2] + orig_shape = list(vace_context.shape) + vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:]) c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype) c = c.flatten(2).transpose(1, 2) + c = list(c.split(orig_shape[0], dim=0)) # arguments x_orig = x @@ -682,8 +685,9 @@ class VaceWanModel(WanModel): ii = self.vace_layers_mapping.get(i, None) if ii is not None: - c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) - x += c_skip * vace_strength + for iii in range(len(c)): + c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x += c_skip * vace_strength[iii] del c_skip # head x = self.head(x, e) diff --git a/comfy/model_base.py b/comfy/model_base.py index f475e837e..fb4724690 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1062,20 +1062,25 @@ class WAN21_Vace(WAN21): vace_frames = kwargs.get("vace_frames", None) if vace_frames is None: noise_shape[1] = 32 - vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype) - - for i in range(0, vace_frames.shape[1], 16): - vace_frames = vace_frames.clone() - vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16]) + vace_frames = [torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)] mask = kwargs.get("vace_mask", None) if mask is None: noise_shape[1] = 64 - mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype) + mask = [torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)] * len(vace_frames) - out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1)) + vace_frames_out = [] + for j in range(len(vace_frames)): + vf = vace_frames[j].clone() + for i in range(0, vf.shape[1], 16): + vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16]) + vf = torch.cat([vf, mask[j]], dim=1) + vace_frames_out.append(vf) - vace_strength = kwargs.get("vace_strength", 1.0) + vace_frames = torch.stack(vace_frames_out, dim=1) + out['vace_context'] = comfy.conds.CONDRegular(vace_frames) + + vace_strength = kwargs.get("vace_strength", [1.0] * len(vace_frames_out)) out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) return out diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index a91b4aba9..c35c4871c 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -268,8 +268,9 @@ class WanVaceToVideo: trim_latent = reference_image.shape[2] mask = mask.unsqueeze(0) - positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength}) - negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength}) + + positive = node_helpers.conditioning_set_values(positive, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True) latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device()) out_latent = {}