From 84fdaf7b0ef4d030723bc3b350282dc6c92743f6 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 26 Mar 2025 05:08:49 -0400
Subject: [PATCH 1/3] Add CFGZeroStar node.

It works on all models that use a negative prompt but is meant for
rectified flow models. The node patches the sampler with a post-CFG
function that rescales the unconditional prediction by the least-squares
projection coefficient st* = (v_cond . v_uncond) / ||v_uncond||^2.
---
 comfy_extras/nodes_cfg.py | 45 +++++++++++++++++++++++++++++++++++++++
 nodes.py                  |  1 +
 2 files changed, 46 insertions(+)
 create mode 100644 comfy_extras/nodes_cfg.py

diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py
new file mode 100644
index 00000000..1fb68664
--- /dev/null
+++ b/comfy_extras/nodes_cfg.py
@@ -0,0 +1,45 @@
+import torch
+
+# https://github.com/WeichenFan/CFG-Zero-star
+def optimized_scale(positive, negative):
+    positive_flat = positive.reshape(positive.shape[0], -1)
+    negative_flat = negative.reshape(negative.shape[0], -1)
+
+    # Calculate the dot product between the conditional and unconditional predictions
+    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
+
+    # Squared norm of the unconditional prediction (the epsilon avoids division by zero)
+    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
+
+    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+    st_star = dot_product / squared_norm
+
+    return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
+
+class CFGZeroStar:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model": ("MODEL",),
+                             }}
+    RETURN_TYPES = ("MODEL",)
+    RETURN_NAMES = ("patched_model",)
+    FUNCTION = "patch"
+    CATEGORY = "advanced/guidance"
+
+    def patch(self, model):
+        m = model.clone()
+        def cfg_zero_star(args):
+            guidance_scale = args['cond_scale']
+            x = args['input']
+            cond_p = args['cond_denoised']
+            uncond_p = args['uncond_denoised']
+            out = args["denoised"]
+            alpha = optimized_scale(x - cond_p, x - uncond_p)
+
+            return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
+        m.set_model_sampler_post_cfg_function(cfg_zero_star)
+        return (m, )
+
+NODE_CLASS_MAPPINGS = {
+    "CFGZeroStar": CFGZeroStar
+}
diff --git a/nodes.py b/nodes.py
index 27ef743b..272c2a25 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2267,6 +2267,7 @@ def init_builtin_extra_nodes():
         "nodes_lotus.py",
         "nodes_hunyuan3d.py",
         "nodes_primitive.py",
+        "nodes_cfg.py",
     ]
 
     import_failed = []

From 3661c833bcc41b788a7c9f0e7bc48524f8ee5f82 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 26 Mar 2025 19:54:54 -0400
Subject: [PATCH 2/3] Support the WAN 2.1 fun control models.

Use the new WanFunControlToVideo node.
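The fun control checkpoints use in_dim 48, i.e. 32 concat channels on top
of the 16 noise channels. WanFunControlToVideo packs those 32 channels as
two stacked Wan latents: the encoded control video in the first 16 and the
encoded start image in the last 16. A rough sketch of the layout (shapes
and values are illustrative only, not part of the diff; latent sizes
assume 81 frames at 832x480):

    import torch

    batch = 1
    t, h, w = (81 - 1) // 4 + 1, 480 // 8, 832 // 8    # 21, 60, 104

    concat_latent = torch.zeros(batch, 32, t, h, w)    # the 32 extra concat channels
    control_latent = torch.randn(batch, 16, t, h, w)   # stands in for vae.encode(control_video)
    start_latent = torch.randn(batch, 16, 1, h, w)     # stands in for vae.encode(start_image)

    # channels 0-15 carry the control video, channels 16-31 the start image
    concat_latent[:, :16, :control_latent.shape[2]] = control_latent
    concat_latent[:, 16:, :start_latent.shape[2]] = start_latent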
---
 comfy/model_base.py       | 17 ++++++++-----
 comfy/supported_models.py | 14 ++++++++++-
 comfy_extras/nodes_wan.py | 51 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 75 insertions(+), 7 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index eec70d5d..315b5d1e 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -992,7 +992,8 @@ class WAN21(BaseModel):
 
     def concat_cond(self, **kwargs):
         noise = kwargs.get("noise", None)
-        if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
+        extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
+        if extra_channels == 0:
             return None
 
         image = kwargs.get("concat_latent_image", None)
@@ -1000,12 +1001,16 @@
 
         if image is None:
-            image = torch.zeros_like(noise)
+            shape_image = list(noise.shape)
+            shape_image[1] = extra_channels
+            image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
+        else:
+            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            for i in range(0, image.shape[1], 16):
+                image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
+            image = utils.resize_to_batch_size(image, noise.shape[0])
 
-        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
-        image = self.process_latent_in(image)
-        image = utils.resize_to_batch_size(image, noise.shape[0])
-
-        if not self.image_to_video:
+        if not self.image_to_video or extra_channels == image.shape[1]:
             return image
 
         mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index fad00d35..2a6a6156 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -969,12 +969,24 @@ class WAN21_I2V(WAN21_T2V):
     unet_config = {
         "image_model": "wan2.1",
         "model_type": "i2v",
+        "in_dim": 36,
     }
 
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.WAN21(self, image_to_video=True, device=device)
         return out
 
+class WAN21_FunControl2V(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "i2v",
+        "in_dim": 48,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21(self, image_to_video=False, device=device)
+        return out
+
 class Hunyuan3Dv2(supported_models_base.BASE):
     unet_config = {
         "image_model": "hunyuan3d2",
@@ -1013,6 +1025,6 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
 
     latent_format = latent_formats.Hunyuan3Dv2mini
 
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
 
 models += [SVD_img2vid]
diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index dc30eb54..428874bc 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -3,6 +3,7 @@ import node_helpers
 import torch
 import comfy.model_management
 import comfy.utils
+import comfy.latent_formats
 
 
 class WanImageToVideo:
@@ -49,6 +50,56 @@ class WanImageToVideo:
 
         return (positive, negative, out_latent)
 
+class WanFunControlToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                             },
+                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                             "start_image": ("IMAGE", ),
+                             "control_video": ("IMAGE", ),
+                             }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
+        concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
+
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            concat_latent_image = vae.encode(start_image[:, :, :, :3])
+            concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
+
+        if control_video is not None:
+            control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            concat_latent_image = vae.encode(control_video[:, :, :, :3])
+            concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
+
+        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
+        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
 NODE_CLASS_MAPPINGS = {
     "WanImageToVideo": WanImageToVideo,
+    "WanFunControlToVideo": WanFunControlToVideo,
 }

From 0a1f8869c9998bbfcfeb2e97aa96a6d3e0a2b5df Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 27 Mar 2025 11:13:27 -0400
Subject: [PATCH 3/3] Add WanFunInpaintToVideo node for the Wan fun inpaint models.
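The node paints unknown frames flat gray (0.5), encodes the whole clip
once, and passes a concat mask alongside the concat latent. One latent
frame covers four pixel-space frames, so the per-frame keep/generate mask
is folded into four mask channels per latent frame; model_base.WAN21 now
keeps a 4-channel mask as-is instead of averaging and repeating it. A
rough sketch of the packing (values are illustrative only, not part of
the diff):

    import torch

    t_latent, h, w = 21, 60, 104                  # latent dims for 81 frames at 832x480
    mask = torch.ones(1, 1, t_latent * 4, h, w)   # 1.0 = generate, 0.0 = keep
    mask[:, :, :4] = 0.0                          # keep the start frame (plus its 3 packed neighbors)

    # fold groups of 4 temporal steps into the channel dimension
    packed = mask.view(1, t_latent, 4, h, w).transpose(1, 2)
    print(packed.shape)                           # torch.Size([1, 4, 21, 60, 104])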
---
 comfy/model_base.py       |  7 +++--
 comfy_extras/nodes_wan.py | 54 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 59 insertions(+), 2 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index 315b5d1e..8f588e2b 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1017,11 +1017,14 @@ class WAN21(BaseModel):
         if mask is None:
             mask = torch.zeros_like(noise)[:, :4]
         else:
-            mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
+            if mask.shape[1] != 4:
+                mask = torch.mean(mask, dim=1, keepdim=True)
+            mask = 1.0 - mask
             mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
             if mask.shape[-3] < noise.shape[-3]:
                 mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
-            mask = mask.repeat(1, 4, 1, 1, 1)
+            if mask.shape[1] == 1:
+                mask = mask.repeat(1, 4, 1, 1, 1)
             mask = utils.resize_to_batch_size(mask, noise.shape[0])
 
         return torch.cat((mask, image), dim=1)
diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index 428874bc..2d0f31ac 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -99,7 +99,61 @@ class WanFunControlToVideo:
         out_latent["samples"] = latent
         return (positive, negative, out_latent)
 
+class WanFunInpaintToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                             },
+                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                             "start_image": ("IMAGE", ),
+                             "end_image": ("IMAGE", ),
+                             }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+        if end_image is not None:
+            end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+
+        image = torch.ones((length, height, width, 3)) * 0.5
+        mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
+
+        if start_image is not None:
+            image[:start_image.shape[0]] = start_image
+            mask[:, :, :start_image.shape[0] + 3] = 0.0
+
+        if end_image is not None:
+            image[-end_image.shape[0]:] = end_image
+            mask[:, :, -end_image.shape[0]:] = 0.0
+
+        concat_latent_image = vae.encode(image[:, :, :, :3])
+        mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
+        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
 NODE_CLASS_MAPPINGS = {
     "WanImageToVideo": WanImageToVideo,
     "WanFunControlToVideo": WanFunControlToVideo,
+    "WanFunInpaintToVideo": WanFunInpaintToVideo,
 }
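Postscript (reviewer note, not part of any patch): a quick standalone
check of the CFGZeroStar combination from patch 1. optimized_scale
returns the least-squares coefficient st* that best matches the
unconditional prediction to the conditional one; when the two noise
predictions coincide, st* is 1 and the post-CFG function returns the
ordinary CFG result unchanged:

    import torch

    def optimized_scale(positive, negative):
        positive_flat = positive.reshape(positive.shape[0], -1)
        negative_flat = negative.reshape(negative.shape[0], -1)
        dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
        squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
        return (dot_product / squared_norm).reshape([positive.shape[0]] + [1] * (positive.ndim - 1))

    x = torch.randn(2, 16, 4, 8, 8)        # stand-in latent batch
    uncond_p = torch.randn_like(x)
    cond_p = uncond_p.clone()              # identical predictions -> st* == 1
    guidance_scale = 6.0

    # plain CFG on denoised outputs, then the CFGZeroStar correction on top
    out = uncond_p + guidance_scale * (cond_p - uncond_p)
    alpha = optimized_scale(x - cond_p, x - uncond_p)
    patched = out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)

    assert torch.allclose(patched, out, atol=1e-4)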