Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-07-14 03:16:59 +08:00)
Singlestep DPM++ SDE for RF (#8627)
Refactor the algorithm, and apply alpha scaling.
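For context on the change: the refactor replaces the log-sigma variables t = -log(sigma) with half-log-SNR values lambda = log(alpha / sigma), so the same second-order update also covers rectified-flow / variance-preserving schedules where alpha != 1, and the new alpha_s / alpha_s_1 / alpha_t factors in the diff follow from the identity alpha = sigma * exp(lambda). The snippet below is only a simplified illustration of that convention, not the ComfyUI helpers themselves: the real half_log_snr_to_sigma, sigma_to_half_log_snr and offset_first_sigma_for_snr take a model_sampling object rather than alpha directly, and the alpha = 1 - sigma line assumes ComfyUI's rectified-flow parameterization.

import torch

def half_log_snr(alpha, sigma):
    # lambda = log(alpha / sigma), i.e. half of log(alpha^2 / sigma^2) (the log-SNR).
    return (alpha / sigma).log()

sigma = torch.tensor(0.7)

# eps / VE-style models: alpha == 1, so lambda reduces to -log(sigma) and the new
# formulation collapses back to the old t = -log(sigma) one.
assert torch.allclose(half_log_snr(torch.tensor(1.0), sigma), -sigma.log())

# Rectified flow (x_t = (1 - sigma) * x_0 + sigma * noise): alpha = 1 - sigma, and
# alpha = sigma * exp(lambda) is the scaling applied as alpha_s / alpha_s_1 / alpha_t in the diff.
alpha = 1.0 - sigma
lam = half_log_snr(alpha, sigma)
assert torch.allclose(sigma * lam.exp(), alpha)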
commit 8042eb20c6
parent bd9f166c12
@@ -710,6 +710,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None
         # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
     return x
 
+
 @torch.no_grad()
 def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
     """DPM-Solver++ (stochastic)."""
@@ -721,38 +722,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
     seed = extra_args.get("seed", None)
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
-    sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
+    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         if sigmas[i + 1] == 0:
-            # Euler method
-            d = to_d(x, sigmas[i], denoised)
-            dt = sigmas[i + 1] - sigmas[i]
-            x = x + d * dt
+            # Denoising step
+            x = denoised
         else:
             # DPM-Solver++
-            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
-            h = t_next - t
-            s = t + h * r
+            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+            h = lambda_t - lambda_s
+            lambda_s_1 = lambda_s + r * h
             fac = 1 / (2 * r)
 
+            sigma_s_1 = sigma_fn(lambda_s_1)
+
+            alpha_s = sigmas[i] * lambda_s.exp()
+            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+            alpha_t = sigmas[i + 1] * lambda_t.exp()
+
             # Step 1
-            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
-            s_ = t_fn(sd)
-            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
-            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
-            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
+            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
+            lambda_s_1_ = sd.log().neg()
+            h_ = lambda_s_1_ - lambda_s
+            x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
+            if eta > 0 and s_noise > 0:
+                x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
+            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
 
             # Step 2
-            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
-            t_next_ = t_fn(sd)
+            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
+            lambda_t_ = sd.log().neg()
+            h_ = lambda_t_ - lambda_s
             denoised_d = (1 - fac) * denoised + fac * denoised_2
-            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
-            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
+            x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
+            if eta > 0 and s_noise > 0:
+                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
     return x
 
 
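For readers who want the algorithm outside the diff context, the following is a standalone paraphrase of the updated loop body under the conventions above. It is a sketch, not the ComfyUI implementation: denoise, lambda_fn, sigma_fn and ancestral_split are illustrative stand-ins for the model wrapper, the model_sampling-based helpers and get_ancestral_step, and plain Gaussian noise stands in for the BrownianTreeNoiseSampler.

import torch

def ancestral_split(sigma_from, sigma_to, eta=1.0):
    # k-diffusion-style ancestral split: sigma_to^2 = sigma_down^2 + sigma_up^2.
    if eta == 0:
        return sigma_to, sigma_to * 0
    sigma_up = torch.minimum(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up

def dpmpp_sde_rf_step(x, denoise, sigma, sigma_next, lambda_fn, sigma_fn, eta=1.0, s_noise=1.0, r=0.5):
    # One single-step DPM++ SDE update in the alpha-scaled, half-log-SNR form used by the diff.
    # denoise(x, sigma) is a data-prediction model; lambda_fn / sigma_fn map sigma <-> half-log-SNR
    # for the model's schedule (stand-ins for the ComfyUI helpers).
    denoised = denoise(x, sigma)
    if sigma_next == 0:
        return denoised  # final denoising step

    lambda_s, lambda_t = lambda_fn(sigma), lambda_fn(sigma_next)
    h = lambda_t - lambda_s
    lambda_s_1 = lambda_s + r * h
    fac = 1 / (2 * r)
    sigma_s_1 = sigma_fn(lambda_s_1)

    # alpha = sigma * exp(lambda) at the start, intermediate and end noise levels.
    alpha_s = sigma * lambda_s.exp()
    alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
    alpha_t = sigma_next * lambda_t.exp()

    # Step 1: ancestral step to the intermediate half-log-SNR; exp(-lambda) = sigma / alpha.
    sd, su = ancestral_split(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
    h_ = sd.log().neg() - lambda_s
    x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
    if eta > 0 and s_noise > 0:
        x_2 = x_2 + alpha_s_1 * torch.randn_like(x) * s_noise * su
    denoised_2 = denoise(x_2, sigma_s_1)

    # Step 2: same update over the full interval with the midpoint-weighted denoised estimate.
    sd, su = ancestral_split(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
    h_ = sd.log().neg() - lambda_s
    denoised_d = (1 - fac) * denoised + fac * denoised_2
    x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
    if eta > 0 and s_noise > 0:
        x = x + alpha_t * torch.randn_like(x) * s_noise * su
    return x

The eta > 0 and s_noise > 0 guards mirror the diff: when no noise is injected the step stays fully deterministic and the noise sampler is never queried.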