diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index 1b51a4e4a..1d6edb354 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -539,13 +539,20 @@ class WanModel(torch.nn.Module):
         x = self.unpatchify(x, grid_sizes)
         return x
 
-    def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
         bs, c, t, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
         h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
         w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
+
+        if time_dim_concat is not None:
+            time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
+            x = torch.cat([x, time_dim_concat], dim=2)
+            t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
+
         img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
         img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
         img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index cfd10d726..8ed124277 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1057,6 +1057,11 @@ class WAN21(BaseModel):
         clip_vision_output = kwargs.get("clip_vision_output", None)
         if clip_vision_output is not None:
             out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
+
+        time_dim_concat = kwargs.get("time_dim_concat", None)
+        if time_dim_concat is not None:
+            out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
+
         return out
 
 
diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index c35c4871c..d6097a104 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -345,6 +345,44 @@ class WanCameraImageToVideo:
         out_latent["samples"] = latent
         return (positive, negative, out_latent)
 
+class WanPhantomSubjectToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                             },
+                "optional": {"images": ("IMAGE", ),
+                             }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, width, height, length, batch_size, images=None):
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        cond2 = negative
+        if images is not None:
+            images = comfy.utils.common_upscale(images[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            latent_images = []
+            for i in images:
+                latent_images += [vae.encode(i.unsqueeze(0)[:, :, :, :3])]
+            concat_latent_image = torch.cat(latent_images, dim=2)
+
+            positive = node_helpers.conditioning_set_values(positive, {"time_dim_concat": concat_latent_image})
+            cond2 = node_helpers.conditioning_set_values(negative, {"time_dim_concat": concat_latent_image})
+            negative = node_helpers.conditioning_set_values(negative, {"time_dim_concat": comfy.latent_formats.Wan21().process_out(torch.zeros_like(concat_latent_image))})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, cond2, negative, out_latent)
+
 NODE_CLASS_MAPPINGS = {
     "WanImageToVideo": WanImageToVideo,
     "WanFunControlToVideo": WanFunControlToVideo,
@@ -353,4 +391,5 @@ NODE_CLASS_MAPPINGS = {
     "WanVaceToVideo": WanVaceToVideo,
     "TrimVideoLatent": TrimVideoLatent,
     "WanCameraImageToVideo": WanCameraImageToVideo,
+    "WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
 }
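
A minimal sketch of the shape bookkeeping the new time_dim_concat path in WanModel.forward performs. The sizes are assumptions taken from the node's defaults (81 frames at 832x480, Wan's 16-channel latents, 8x spatial / 4x temporal VAE downscaling, a (1, 2, 2) patch size); x and refs are placeholder names, not identifiers from the patch:

    import torch

    patch_size = (1, 2, 2)                 # assumed WanModel (t, h, w) patch size
    x = torch.zeros(1, 16, 21, 60, 104)    # 81 frames -> ((81 - 1) // 4) + 1 = 21 latent frames
    refs = torch.zeros(1, 16, 2, 60, 104)  # two VAE-encoded reference images, concatenated on dim=2 by the node

    t_len = (x.shape[2] + patch_size[0] // 2) // patch_size[0]   # 21, from the video alone
    x = torch.cat([x, refs], dim=2)                              # references appended on the time axis
    t_len = (x.shape[2] + patch_size[0] // 2) // patch_size[0]   # 23, so the img_ids grid also covers the refs

    print(x.shape, t_len)                  # torch.Size([1, 16, 23, 60, 104]) 23

Note the round trip behind the node's third output: negative_img_text stores Wan21().process_out(torch.zeros_like(concat_latent_image)), so after WAN21.extra_conds runs it through self.process_latent_in the model sees exactly zeroed reference frames on that conditioning, while negative_text keeps the real references.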