From 2a18e98ccf083f7e8d54ef712610aa31adb570d0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 11 Nov 2024 04:55:56 -0500 Subject: [PATCH] Refactor so that zsnr can be set in the sampling_settings. --- comfy/model_sampling.py | 31 +++++++++++++++++++++++++--- comfy_extras/nodes_model_advanced.py | 23 +-------------------- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 4a0f2db6..8b4e095d 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -2,6 +2,25 @@ import torch from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule import math +def rescale_zero_terminal_snr_sigmas(sigmas): + alphas_cumprod = 1 / ((sigmas * sigmas) + 1) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return ((1 - alphas_bar) / alphas_bar) ** 0.5 + class EPS: def calculate_input(self, sigma, noise): sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) @@ -48,7 +67,7 @@ class CONST: return latent / (1.0 - sigma) class ModelSamplingDiscrete(torch.nn.Module): - def __init__(self, model_config=None): + def __init__(self, model_config=None, zsnr=None): super().__init__() if model_config is not None: @@ -61,11 +80,14 @@ class ModelSamplingDiscrete(torch.nn.Module): linear_end = sampling_settings.get("linear_end", 0.012) timesteps = sampling_settings.get("timesteps", 1000) - self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3) + if zsnr is None: + zsnr = sampling_settings.get("zsnr", False) + + self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr) self.sigma_data = 1.0 def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False): if given_betas is not None: betas = given_betas else: @@ -83,6 +105,9 @@ class ModelSamplingDiscrete(torch.nn.Module): # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + if zsnr: + sigmas = rescale_zero_terminal_snr_sigmas(sigmas) + self.set_sigmas(sigmas) def set_sigmas(self, sigmas): diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 918e6085..ed14b61a 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -51,25 +51,6 @@ class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete) return log_sigma.exp().to(timestep.device) -def rescale_zero_terminal_snr_sigmas(sigmas): - alphas_cumprod = 1 / ((sigmas * sigmas) + 1) - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= (alphas_bar_sqrt_T) - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas_bar[-1] = 4.8973451890853435e-08 - return ((1 - alphas_bar) / alphas_bar) ** 0.5 - class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): @@ -100,9 +81,7 @@ class ModelSamplingDiscrete: class ModelSamplingAdvanced(sampling_base, sampling_type): pass - model_sampling = ModelSamplingAdvanced(model.model.model_config) - if zsnr: - model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) + model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr) m.add_object_patch("model_sampling", model_sampling) return (m, )