From 884ea653c8d6fe19b3724f45a04a0d74cd881f2f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Apr 2023 11:05:15 -0400 Subject: [PATCH] Add a way for nodes to set a custom CFG function. --- comfy/samplers.py | 5 ++++- comfy/sd.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index ed36442a..05af6fe8 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -211,7 +211,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con max_total_area = model_management.maximum_batch_area() cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) - return uncond + (cond - uncond) * cond_scale + if "sampler_cfg_function" in model_options: + return model_options["sampler_cfg_function"](cond, uncond, cond_scale) + else: + return uncond + (cond - uncond) * cond_scale class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser): diff --git a/comfy/sd.py b/comfy/sd.py index 9c632e24..1d777474 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -250,6 +250,9 @@ class ModelPatcher: def set_model_tomesd(self, ratio): self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio} + def set_model_sampler_cfg_function(self, sampler_cfg_function): + self.model_options["sampler_cfg_function"] = sampler_cfg_function + def model_dtype(self): return self.model.diffusion_model.dtype