mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 14:09:36 +00:00
change to direct return
This commit is contained in:
parent
6b68b61644
commit
8a8327fa73
@ -135,8 +135,7 @@ def adams_bashforth_update_few_steps(order, x, tau, model_prev_list, sigma_prev_
|
|||||||
gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
|
gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
|
||||||
gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
|
gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
|
||||||
noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
|
noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
|
||||||
x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
|
return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
|
||||||
return x_t
|
|
||||||
|
|
||||||
def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
|
def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
|
||||||
"""
|
"""
|
||||||
@ -172,8 +171,7 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li
|
|||||||
gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
|
gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
|
||||||
gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
|
gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
|
||||||
noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
|
noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
|
||||||
x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
|
return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
|
||||||
return x_t
|
|
||||||
|
|
||||||
# Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract
|
# Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract
|
||||||
def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):
|
def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):
|
||||||
|
@ -1211,12 +1211,10 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
|||||||
|
|
||||||
if sigmas[-1] == 0:
|
if sigmas[-1] == 0:
|
||||||
# Denoising step
|
# Denoising step
|
||||||
x = model_prev_list[-1]
|
return model_prev_list[-1]
|
||||||
else:
|
return sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0,
|
||||||
x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0,
|
|
||||||
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
|
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
|
||||||
noise=0, sigma=sigmas[-1])
|
noise=0, sigma=sigmas[-1])
|
||||||
return x
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
|
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
|
||||||
|
Loading…
Reference in New Issue
Block a user