mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 05:57:20 +00:00
change to direct return
This commit is contained in:
parent
6b68b61644
commit
8a8327fa73
@ -135,8 +135,7 @@ def adams_bashforth_update_few_steps(order, x, tau, model_prev_list, sigma_prev_
|
||||
gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
|
||||
gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
|
||||
noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
|
||||
x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
|
||||
return x_t
|
||||
return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
|
||||
|
||||
def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
|
||||
"""
|
||||
@ -172,8 +171,7 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li
|
||||
gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
|
||||
gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
|
||||
noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
|
||||
x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
|
||||
return x_t
|
||||
return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
|
||||
|
||||
# Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract
|
||||
def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):
|
||||
|
@ -1211,12 +1211,10 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
|
||||
if sigmas[-1] == 0:
|
||||
# Denoising step
|
||||
x = model_prev_list[-1]
|
||||
else:
|
||||
x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0,
|
||||
return model_prev_list[-1]
|
||||
return sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0,
|
||||
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
|
||||
noise=0, sigma=sigmas[-1])
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
|
||||
|
Loading…
Reference in New Issue
Block a user