diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 605b1092..06dcc0b1 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1095,10 +1095,6 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F if len(sigmas) <= 1: return x - if sigmas[-1] == 0: - sigmas = sigmas.clone() - sigmas[-1] = 0.001 - extra_args = {} if extra_args is None else extra_args if tau_func is None: model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') @@ -1115,7 +1111,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F for i in trange(len(sigmas) - 1, disable=disable): sigma = sigmas[i] if i == 0: - # Init the initial values. + # Init the initial values denoised = model(x, sigma * s_in, **extra_args) model_prev_list.append(denoised) sigma_prev_list.append(sigma) @@ -1134,22 +1130,19 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Evaluation step denoised = model(x_p, sigma * s_in, **extra_args) - - # Update model_list model_prev_list.append(denoised) # Corrector step if corrector_order_used > 0: x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val, model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, - noise=noise, sigma=sigma) - + noise=noise, sigma=sigma) else: x = x_p del noise, x_p - # Evaluation step if mode = pece and step != steps + # Evaluation step for PECE if corrector_order_used > 0 and pc_mode == 'PECE': del model_prev_list[-1] denoised = model(x, sigma * s_in, **extra_args) @@ -1163,10 +1156,13 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F if callback is not None: callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]}) - # Extra final step - x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0, - model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, - noise=0, sigma=sigmas[-1]) + if sigmas[-1] == 0: + # Denoising step + x = model_prev_list[-1] + else: + x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0, + model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, + noise=0, sigma=sigmas[-1]) return x @torch.no_grad()