change to direct return

This commit is contained in:
chaObserv 2025-01-06 18:58:27 +08:00
parent 6b68b61644
commit 8a8327fa73
2 changed files with 4 additions and 8 deletions

View File

@ -135,8 +135,7 @@ def adams_bashforth_update_few_steps(order, x, tau, model_prev_list, sigma_prev_
gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
return x_t
return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
"""
@ -172,8 +171,7 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li
gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
return x_t
return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
# Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract
def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):

View File

@ -1211,12 +1211,10 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
if sigmas[-1] == 0:
# Denoising step
x = model_prev_list[-1]
else:
x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0,
return model_prev_list[-1]
return sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=0, sigma=sigmas[-1])
return x
@torch.no_grad()
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):