diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py index 1fb68664..1acaf15b 100644 --- a/comfy_extras/nodes_cfg.py +++ b/comfy_extras/nodes_cfg.py @@ -40,6 +40,32 @@ class CFGZeroStar: m.set_model_sampler_post_cfg_function(cfg_zero_star) return (m, ) +class CFGNorm: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("patched_model",) + FUNCTION = "patch" + CATEGORY = "advanced/guidance" + + def patch(self, model, strength): + m = model.clone() + def cfg_norm(args): + cond_p = args['cond_denoised'] + pred_text_ = args["denoised"] + + norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True) + norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True) + scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0) + return pred_text_ * scale * strength + + m.set_model_sampler_post_cfg_function(cfg_norm) + return (m, ) + NODE_CLASS_MAPPINGS = { - "CFGZeroStar": CFGZeroStar + "CFGZeroStar": CFGZeroStar, + "CFGNorm": CFGNorm, }