diff --git a/comfy/k_diffusion/sa_solver.py b/comfy/k_diffusion/sa_solver.py
index 9cc81d1f..b784ecbf 100644
--- a/comfy/k_diffusion/sa_solver.py
+++ b/comfy/k_diffusion/sa_solver.py
@@ -17,27 +17,27 @@ def get_coefficients_exponential_positive(order, interval_start, interval_end, t
     interval_start_cov = (1 + tau ** 2) * interval_start
 
     if order == 0:
-        return (torch.exp(interval_end_cov) 
+        return (torch.exp(interval_end_cov)
                 * (1 - torch.exp(-(interval_end_cov - interval_start_cov)))
                 / ((1 + tau ** 2))
                 )
     elif order == 1:
-        return (torch.exp(interval_end_cov) 
+        return (torch.exp(interval_end_cov)
                 * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp(-(interval_end_cov - interval_start_cov)))
                 / ((1 + tau ** 2) ** 2)
                 )
     elif order == 2:
-        return (torch.exp(interval_end_cov) 
-                * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2) 
-                    - (interval_start_cov ** 2 - 2 * interval_start_cov + 2) 
+        return (torch.exp(interval_end_cov)
+                * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2)
+                    - (interval_start_cov ** 2 - 2 * interval_start_cov + 2)
                     * torch.exp(-(interval_end_cov - interval_start_cov))
                   )
                 / ((1 + tau ** 2) ** 3)
                 )
     elif order == 3:
-        return (torch.exp(interval_end_cov) 
+        return (torch.exp(interval_end_cov)
                 * ((interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6)
-                   - (interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6) 
+                   - (interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6)
                    * torch.exp(-(interval_end_cov - interval_start_cov))
                   )
                 / ((1 + tau ** 2) ** 4)
@@ -53,7 +53,7 @@ def lagrange_polynomial_coefficient(order, lambda_list):
     if order == 0:
         return [[1.0]]
     elif order == 1:
-        return [[1.0 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])], 
+        return [[1.0 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])],
                 [1.0 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]]
     elif order == 2:
         denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2])
@@ -79,12 +79,12 @@ def lagrange_polynomial_coefficient(order, lambda_list):
                  (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2],
 
                 [1.0 / denominator3,
-                 (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3, 
-                 (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[3]) / denominator3, 
+                 (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3,
+                 (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[3]) / denominator3,
                  (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3],
 
                 [1.0 / denominator4,
-                 (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4, 
+                 (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4,
                  (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[2]) / denominator4,
                  (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4]
                 ]
@@ -122,11 +122,11 @@ def adams_bashforth_update_few_steps(order, x, tau, model_prev_list, sigma_prev_
         # ODE case
         # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
         # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
-        gradient_coefficients[0] += (1.0 * torch.exp((1 + tau ** 2) * lambda_t) 
-                                     * (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2)) 
+        gradient_coefficients[0] += (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
+                                     * (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2))
                                      / (lambda_prev - lambda_list[1])
                                     )
-        gradient_coefficients[1] -= (1.0 * torch.exp((1 + tau ** 2) * lambda_t) 
+        gradient_coefficients[1] -= (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
                                      * (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2))
                                      / (lambda_prev - lambda_list[1])
                                     )
@@ -152,7 +152,7 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li
     lambda_prev = lambda_list[1] if order >= 2 else t_fn(sigma_prev)
     h = lambda_t - lambda_prev
     gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau)
-    
+
     if order == 2:  ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling.
         # The added term is O(h^3). Empirically we find it will slightly improve the image quality.
         # ODE case
@@ -166,7 +166,7 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li
                                      * (h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h)))
                                      / ((1 + tau ** 2) ** 2 * h))
                                     )
-    
+
     for i in range(order):
         gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
     gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
@@ -178,4 +178,4 @@ def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):
     if eta == 0:
         # Pure ODE
         return 0
-    return eta if eta_end_sigma <= sigma <= eta_start_sigma else 0
\ No newline at end of file
+    return eta if eta_end_sigma <= sigma <= eta_start_sigma else 0
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 85617b68..4c8d18f2 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1146,7 +1146,7 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
 def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, pc_mode="PEC", tau_func=None, noise_sampler=None):
     if len(sigmas) <= 1:
         return x
-    
+
     extra_args = {} if extra_args is None else extra_args
     if tau_func is None:
         model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@@ -1172,7 +1172,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
             # Lower order final
             predictor_order_used = min(predictor_order, i, len(sigmas) - i - 1)
             corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1)
-            
+
             tau_val = tau(sigma)
             noise = None if tau_val == 0 else noise_sampler(sigma, sigmas[i + 1])
 
@@ -1183,13 +1183,13 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
 
             # Evaluation step
             denoised = model(x_p, sigma * s_in, **extra_args)
-            model_prev_list.append(denoised) 
+            model_prev_list.append(denoised)
 
             # Corrector step
             if corrector_order_used > 0:
                 x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val,
                                                              model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
-                                                             noise=noise, sigma=sigma)                                     
+                                                             noise=noise, sigma=sigma)
             else:
                 x = x_p
 
@@ -1205,7 +1205,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
             if len(model_prev_list) > max(predictor_order, corrector_order):
                 del model_prev_list[0]
                 del sigma_prev_list[0]
- 
+
         if callback is not None:
             callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
 
diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py
index 03461d6c..3cffdf8a 100644
--- a/comfy_extras/nodes_custom_sampler.py
+++ b/comfy_extras/nodes_custom_sampler.py
@@ -435,7 +435,7 @@ class SamplerSASolver:
         start_sigma = model_sampling.percent_to_sigma(eta_start_percent)
         end_sigma = model_sampling.percent_to_sigma(eta_end_percent)
         tau_func = partial(sa_solver.default_tau_func, eta=eta, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
-        
+
         if pc_mode == 'PEC':
             sampler_name = "sa_solver"
         else: