diff --git a/comfy/model_base.py b/comfy/model_base.py index eec70d5de..315b5d1e3 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -992,7 +992,8 @@ class WAN21(BaseModel): def concat_cond(self, **kwargs): noise = kwargs.get("noise", None) - if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]: + extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1] + if extra_channels == 0: return None image = kwargs.get("concat_latent_image", None) @@ -1000,12 +1001,16 @@ class WAN21(BaseModel): if image is None: image = torch.zeros_like(noise) + shape_image = list(noise.shape) + shape_image[1] = extra_channels + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + for i in range(0, image.shape[1], 16): + image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16]) + image = utils.resize_to_batch_size(image, noise.shape[0]) - image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") - image = self.process_latent_in(image) - image = utils.resize_to_batch_size(image, noise.shape[0]) - - if not self.image_to_video: + if not self.image_to_video or extra_channels == image.shape[1]: return image mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index fad00d35b..2a6a61560 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -969,12 +969,24 @@ class WAN21_I2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", "model_type": "i2v", + "in_dim": 36, } def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21(self, image_to_video=True, device=device) return out +class WAN21_FunControl2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "i2v", + "in_dim": 48, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21(self, image_to_video=False, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1013,6 +1025,6 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2): latent_format = latent_formats.Hunyuan3Dv2mini -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index dc30eb546..428874bcc 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -3,6 +3,7 @@ import node_helpers import torch import comfy.model_management import comfy.utils +import comfy.latent_formats class WanImageToVideo: @@ -49,6 +50,56 @@ class WanImageToVideo: return (positive, negative, out_latent) +class WanFunControlToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "control_video": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + concat_latent = concat_latent.repeat(1, 2, 1, 1, 1) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(control_video[:, :, :, :3]) + concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, + "WanFunControlToVideo": WanFunControlToVideo, }