diff --git a/comfy/k_diffusion/sa_solver.py b/comfy/k_diffusion/sa_solver.py index cb4bece5..9cc81d1f 100644 --- a/comfy/k_diffusion/sa_solver.py +++ b/comfy/k_diffusion/sa_solver.py @@ -135,8 +135,7 @@ def adams_bashforth_update_few_steps(order, x, tau, model_prev_list, sigma_prev_ gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)] gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t) noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise - x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part - return x_t + return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma): """ @@ -172,8 +171,7 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)] gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t) noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise - x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part - return x_t + return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part # Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma): diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 4adc828f..85617b68 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1211,12 +1211,10 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F if sigmas[-1] == 0: # Denoising step - x = model_prev_list[-1] - else: - x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0, + return model_prev_list[-1] + return sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0, model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, noise=0, sigma=sigmas[-1]) - return x @torch.no_grad() def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):