Fix res_multistep_ancestral sampler (#8030)

This commit is contained in:
Pam 2025-05-10 05:14:13 +05:00 committed by GitHub
parent ae60b150e5
commit 1b3bf0a5da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1277,6 +1277,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
phi1_fn = lambda t: torch.expm1(t) / t phi1_fn = lambda t: torch.expm1(t) / t
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
old_sigma_down = None
old_denoised = None old_denoised = None
uncond_denoised = None uncond_denoised = None
def post_cfg_function(args): def post_cfg_function(args):
@ -1304,9 +1305,9 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
x = x + d * dt x = x + d * dt
else: else:
# Second order multistep method in https://arxiv.org/pdf/2308.02157 # Second order multistep method in https://arxiv.org/pdf/2308.02157
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1]) t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
h = t_next - t h = t_next - t
c2 = (t_prev - t) / h c2 = (t_prev - t_old) / h
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h) phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0) b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
@ -1326,6 +1327,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
old_denoised = uncond_denoised old_denoised = uncond_denoised
else: else:
old_denoised = denoised old_denoised = denoised
old_sigma_down = sigma_down
return x return x
@torch.no_grad() @torch.no_grad()