mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Implement my alternative take on CFG++ as the euler_pp sampler.
Add euler_ancestral_pp which is the ancestral version of euler with the same modification.
This commit is contained in:
parent
90aebb6c86
commit
69d710e40f
@ -7,7 +7,7 @@ import torchsde
|
|||||||
from tqdm.auto import trange, tqdm
|
from tqdm.auto import trange, tqdm
|
||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
|
import comfy.model_patcher
|
||||||
|
|
||||||
def append_zero(x):
|
def append_zero(x):
|
||||||
return torch.cat([x, x.new_zeros([1])])
|
return torch.cat([x, x.new_zeros([1])])
|
||||||
@ -945,3 +945,56 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
|||||||
buffer_model.append(d_cur.detach())
|
buffer_model.append(d_cur.detach())
|
||||||
|
|
||||||
return x_next
|
return x_next
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
|
||||||
|
temp = [0]
|
||||||
|
def post_cfg_function(args):
|
||||||
|
temp[0] = args["uncond_denoised"]
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
model_options = extra_args.get("model_options", {}).copy()
|
||||||
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
sigma_hat = sigmas[i]
|
||||||
|
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||||
|
d = to_d(x - denoised + temp[0], sigma_hat, denoised)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||||
|
dt = sigmas[i + 1] - sigma_hat
|
||||||
|
# Euler method
|
||||||
|
x = x + d * dt
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_euler_ancestral_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
"""Ancestral sampling with Euler method steps."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||||
|
|
||||||
|
temp = [0]
|
||||||
|
def post_cfg_function(args):
|
||||||
|
temp[0] = args["uncond_denoised"]
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
model_options = extra_args.get("model_options", {}).copy()
|
||||||
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
d = to_d(x - denoised + temp[0], sigmas[i], denoised)
|
||||||
|
# Euler method
|
||||||
|
dt = sigma_down - sigmas[i]
|
||||||
|
x = x + d * dt
|
||||||
|
if sigmas[i + 1] > 0:
|
||||||
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||||
|
return x
|
||||||
|
@ -537,7 +537,7 @@ class Sampler:
|
|||||||
sigma = float(sigmas[0])
|
sigma = float(sigmas[0])
|
||||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||||
|
|
||||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_pp", "euler_ancestral", "euler_ancestral_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v"]
|
"ipndm", "ipndm_v"]
|
||||||
|
@ -82,29 +82,6 @@ def sample_euler_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable
|
|||||||
x = denoised + sigmas[i + 1] * d
|
x = denoised + sigmas[i + 1] * d
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def sample_euler_cfgpp_alt(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
|
||||||
extra_args = {} if extra_args is None else extra_args
|
|
||||||
|
|
||||||
temp = [0]
|
|
||||||
def post_cfg_function(args):
|
|
||||||
temp[0] = args["uncond_denoised"]
|
|
||||||
return args["denoised"]
|
|
||||||
|
|
||||||
model_options = extra_args.get("model_options", {}).copy()
|
|
||||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
|
||||||
|
|
||||||
s_in = x.new_ones([x.shape[0]])
|
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
|
||||||
sigma_hat = sigmas[i]
|
|
||||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
|
||||||
d = to_d(x - denoised + temp[0], sigma_hat, denoised)
|
|
||||||
if callback is not None:
|
|
||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
|
||||||
dt = sigmas[i + 1] - sigma_hat
|
|
||||||
# Euler method
|
|
||||||
x = x + d * dt
|
|
||||||
return x
|
|
||||||
|
|
||||||
class SamplerEulerCFGpp:
|
class SamplerEulerCFGpp:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -122,7 +99,7 @@ class SamplerEulerCFGpp:
|
|||||||
if version == "regular":
|
if version == "regular":
|
||||||
sampler = comfy.samplers.KSAMPLER(sample_euler_cfgpp)
|
sampler = comfy.samplers.KSAMPLER(sample_euler_cfgpp)
|
||||||
else:
|
else:
|
||||||
sampler = comfy.samplers.KSAMPLER(sample_euler_cfgpp_alt)
|
sampler = comfy.samplers.ksampler("euler_pp")
|
||||||
return (sampler, )
|
return (sampler, )
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
Loading…
Reference in New Issue
Block a user