From 4136502b7a530df28e489f81c0c58b30612f9ece Mon Sep 17 00:00:00 2001
From: thot experiment <94414189+thot-experiment@users.noreply.github.com>
Date: Mon, 12 May 2025 18:10:24 -0700
Subject: [PATCH] implement APG guidance (#8081)

* first pass at implementing AGP

* rename, cleanup code

* fix ruff

* fix modified cond to match ref impl better, support different cond arity
---
 comfy_extras/nodes_apg.py | 76 +++++++++++++++++++++++++++++++++++++++
 nodes.py                  |  1 +
 2 files changed, 77 insertions(+)
 create mode 100644 comfy_extras/nodes_apg.py

diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py
new file mode 100644
index 000000000..1325985b2
--- /dev/null
+++ b/comfy_extras/nodes_apg.py
@@ -0,0 +1,76 @@
+import torch
+
+def project(v0, v1):
+    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
+    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    return v0_parallel, v0_orthogonal
+
+class APG:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
+                "norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize the guidance vector to this value; normalization is disabled at a setting of 0."}),
+                "momentum": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
+            }
+        }
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+    CATEGORY = "sampling/custom_sampling"
+
+    def patch(self, model, eta, norm_threshold, momentum):
+        running_avg = 0
+        prev_sigma = None
+
+        def pre_cfg_function(args):
+            nonlocal running_avg, prev_sigma
+
+            if len(args["conds_out"]) == 1: return args["conds_out"]
+
+            cond = args["conds_out"][0]
+            uncond = args["conds_out"][1]
+            sigma = args["sigma"][0]
+            cond_scale = args["cond_scale"]
+
+            if prev_sigma is not None and sigma > prev_sigma:
+                running_avg = 0
+            prev_sigma = sigma
+
+            guidance = cond - uncond
+
+            if momentum > 0:
+                if not torch.is_tensor(running_avg):
+                    running_avg = guidance
+                else:
+                    running_avg = momentum * running_avg + guidance
+                guidance = running_avg
+
+            if norm_threshold > 0:
+                guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
+                scale = torch.minimum(
+                    torch.ones_like(guidance_norm),
+                    norm_threshold / guidance_norm
+                )
+                guidance = guidance * scale
+
+            guidance_parallel, guidance_orthogonal = project(guidance, cond)
+            modified_guidance = guidance_orthogonal + eta * guidance_parallel
+
+            modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale
+
+            return [modified_cond, uncond] + args["conds_out"][2:]
+
+        m = model.clone()
+        m.set_model_sampler_pre_cfg_function(pre_cfg_function)
+        return (m,)
+
+NODE_CLASS_MAPPINGS = {
+    "APG": APG,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "APG": "Adaptive Projected Guidance",
+}
diff --git a/nodes.py b/nodes.py
index a26a138fa..54e3886a3 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2261,6 +2261,7 @@ def init_builtin_extra_nodes():
         "nodes_optimalsteps.py",
         "nodes_hidream.py",
         "nodes_fresca.py",
+        "nodes_apg.py",
         "nodes_preview_any.py",
         "nodes_ace.py",
         "nodes_string.py",
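
Note (not part of the patch itself): the sketch below reproduces the guidance math from
pre_cfg_function on dummy tensors, so the effect of eta and norm_threshold can be inspected
without a ComfyUI runtime. The tensor shapes and example values are assumptions, and the
momentum running average is omitted because it only matters across consecutive sampling steps.

    import torch

    def project(v0, v1):
        # split v0 into components parallel and orthogonal to v1 (per sample, over C/H/W dims)
        v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
        v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
        return v0_parallel, v0 - v0_parallel

    def apg_guidance(cond, uncond, eta=1.0, norm_threshold=5.0):
        guidance = cond - uncond
        if norm_threshold > 0:
            # rescale so the per-sample L2 norm of the guidance never exceeds norm_threshold
            norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            guidance = guidance * torch.minimum(torch.ones_like(norm), norm_threshold / norm)
        parallel, orthogonal = project(guidance, cond)
        # eta = 1.0 keeps the full parallel component, i.e. ordinary CFG-style guidance
        return orthogonal + eta * parallel

    # placeholder latents standing in for the model's cond/uncond predictions
    cond = torch.randn(1, 4, 64, 64)
    uncond = torch.randn(1, 4, 64, 64)
    print(apg_guidance(cond, uncond, eta=0.5, norm_threshold=2.5).shape)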