From 84fdaf7b0ef4d030723bc3b350282dc6c92743f6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Mar 2025 05:08:49 -0400 Subject: [PATCH] Add CFGZeroStar node. Works on all models that use a negative prompt but is meant for rectified flow models. --- comfy_extras/nodes_cfg.py | 45 +++++++++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 46 insertions(+) create mode 100644 comfy_extras/nodes_cfg.py diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py new file mode 100644 index 00000000..1fb68664 --- /dev/null +++ b/comfy_extras/nodes_cfg.py @@ -0,0 +1,45 @@ +import torch + +# https://github.com/WeichenFan/CFG-Zero-star +def optimized_scale(positive, negative): + positive_flat = positive.reshape(positive.shape[0], -1) + negative_flat = negative.reshape(negative.shape[0], -1) + + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1)) + +class CFGZeroStar: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("patched_model",) + FUNCTION = "patch" + CATEGORY = "advanced/guidance" + + def patch(self, model): + m = model.clone() + def cfg_zero_star(args): + guidance_scale = args['cond_scale'] + x = args['input'] + cond_p = args['cond_denoised'] + uncond_p = args['uncond_denoised'] + out = args["denoised"] + alpha = optimized_scale(x - cond_p, x - uncond_p) + + return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha) + m.set_model_sampler_post_cfg_function(cfg_zero_star) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "CFGZeroStar": CFGZeroStar +} diff --git a/nodes.py b/nodes.py index 27ef743b..272c2a25 100644 --- a/nodes.py +++ b/nodes.py @@ -2267,6 +2267,7 @@ def init_builtin_extra_nodes(): "nodes_lotus.py", "nodes_hunyuan3d.py", "nodes_primitive.py", + "nodes_cfg.py", ] import_failed = []