mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 05:57:20 +00:00
Add DPM-Solver++(2M) SDE and exponential scheduler.
exponential scheduler is the one recommended with this sampler.
This commit is contained in:
parent
516119ad83
commit
069657fbf3
@ -605,3 +605,46 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
||||||
old_denoised = denoised
|
old_denoised = denoised
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||||
|
"""DPM-Solver++(2M) SDE."""
|
||||||
|
|
||||||
|
if solver_type not in {'heun', 'midpoint'}:
|
||||||
|
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||||
|
|
||||||
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
|
old_denoised = None
|
||||||
|
h_last = None
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
if sigmas[i + 1] == 0:
|
||||||
|
# Denoising step
|
||||||
|
x = denoised
|
||||||
|
else:
|
||||||
|
# DPM-Solver++(2M) SDE
|
||||||
|
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||||
|
h = s - t
|
||||||
|
eta_h = eta * h
|
||||||
|
|
||||||
|
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
|
||||||
|
|
||||||
|
if old_denoised is not None:
|
||||||
|
r = h_last / h
|
||||||
|
if solver_type == 'heun':
|
||||||
|
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
|
||||||
|
elif solver_type == 'midpoint':
|
||||||
|
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||||
|
|
||||||
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||||
|
|
||||||
|
old_denoised = denoised
|
||||||
|
h_last = h
|
||||||
|
return x
|
||||||
|
@ -495,10 +495,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
|
|||||||
|
|
||||||
|
|
||||||
class KSampler:
|
class KSampler:
|
||||||
SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"]
|
SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
|
||||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
|
||||||
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
|
"dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||||
|
|
||||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -532,6 +532,8 @@ class KSampler:
|
|||||||
|
|
||||||
if self.scheduler == "karras":
|
if self.scheduler == "karras":
|
||||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
||||||
|
elif self.scheduler == "exponential":
|
||||||
|
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
||||||
elif self.scheduler == "normal":
|
elif self.scheduler == "normal":
|
||||||
sigmas = self.model_wrap.get_sigmas(steps)
|
sigmas = self.model_wrap.get_sigmas(steps)
|
||||||
elif self.scheduler == "simple":
|
elif self.scheduler == "simple":
|
||||||
|
Loading…
Reference in New Issue
Block a user