From 3a150bad1590d6e23cfcf3e621d575dba1ca8c2a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Jul 2023 10:11:08 -0400 Subject: [PATCH] Only calculate randn in some samplers when it's actually being used. --- comfy/k_diffusion/sampling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 4cc2534f3..3b4e99315 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -131,9 +131,9 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. - eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: + eps = torch.randn_like(x) * s_noise x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) @@ -172,9 +172,9 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. - eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: + eps = torch.randn_like(x) * s_noise x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) @@ -201,9 +201,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. - eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: + eps = torch.randn_like(x) * s_noise x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised)