From 1777b54d0217e77a6a64b0a587b9b11a48e3bf02 Mon Sep 17 00:00:00 2001
From: comfyanonymous <comfyanonymous@protonmail.com>
Date: Tue, 31 Oct 2023 17:33:43 -0400
Subject: [PATCH] Sampling code changes.

apply_model in model_base now returns the denoised output.

This means that sampling_function now computes things on the denoised
output instead of the model output. This should make things more consistent
across current and future models.
---
 comfy/extra_samplers/uni_pc.py |   8 ++-
 comfy/model_base.py            | 121 ++++++++++++++++++++++++++-------
 comfy/samplers.py              |  72 +++++++++-----------
 3 files changed, 136 insertions(+), 65 deletions(-)

diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py
index 9d5f0c60..1a7a8392 100644
--- a/comfy/extra_samplers/uni_pc.py
+++ b/comfy/extra_samplers/uni_pc.py
@@ -852,6 +852,12 @@ class SigmaConvert:
         log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
         return log_mean_coeff - log_std
 
+def predict_eps_sigma(model, input, sigma_in, **kwargs):
+    sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
+    input = input * ((sigma ** 2 + 1.0) ** 0.5)
+    return  (input - model(input, sigma_in, **kwargs)) / sigma
+
+
 def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
         timesteps = sigmas.clone()
         if sigmas[-1] == 0:
@@ -874,7 +880,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
         model_type = "noise"
 
         model_fn = model_wrapper(
-            model.predict_eps_sigma,
+            lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
             ns,
             model_type=model_type,
             guidance_type="uncond",
diff --git a/comfy/model_base.py b/comfy/model_base.py
index ea3ea61f..b8d04a2c 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -13,25 +13,31 @@ class ModelType(Enum):
     EPS = 1
     V_PREDICTION = 2
 
-class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+
+#NOTE: all this sampling stuff will be moved
+class EPS:
+    def calculate_input(self, sigma, noise):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+        return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+
+    def calculate_denoised(self, sigma, model_output, model_input):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        return model_input - model_output * sigma
+
+
+class V_PREDICTION(EPS):
+    def calculate_denoised(self, sigma, model_output, model_input):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+
+
+class ModelSamplingDiscrete(torch.nn.Module):
+    def __init__(self, model_config):
         super().__init__()
+        self._register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
+        self.sigma_data = 1.0
 
-        unet_config = model_config.unet_config
-        self.latent_format = model_config.latent_format
-        self.model_config = model_config
-        self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
-        if not unet_config.get("disable_unet_model_creation", False):
-            self.diffusion_model = UNetModel(**unet_config, device=device)
-        self.model_type = model_type
-        self.adm_channels = unet_config.get("adm_in_channels", None)
-        if self.adm_channels is None:
-            self.adm_channels = 0
-        self.inpaint_model = False
-        print("model_type", model_type.name)
-        print("adm", self.adm_channels)
-
-    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+    def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
         if given_betas is not None:
             betas = given_betas
@@ -39,31 +45,94 @@ class BaseModel(torch.nn.Module):
             betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
         alphas = 1. - betas
         alphas_cumprod = np.cumprod(alphas, axis=0)
-        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+        # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
 
         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)
         self.linear_start = linear_start
         self.linear_end = linear_end
 
-        self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
-        self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
-        self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
+        # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
+        # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
+        # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
+
+        sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
+
+        self.register_buffer('sigmas', sigmas)
+        self.register_buffer('log_sigmas', sigmas.log())
+
+    @property
+    def sigma_min(self):
+        return self.sigmas[0]
+
+    @property
+    def sigma_max(self):
+        return self.sigmas[-1]
+
+    def timestep(self, sigma):
+        log_sigma = sigma.log()
+        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
+        return dists.abs().argmin(dim=0).view(sigma.shape)
+
+    def sigma(self, timestep):
+        t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
+        low_idx = t.floor().long()
+        high_idx = t.ceil().long()
+        w = t.frac()
+        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
+        return log_sigma.exp()
+
+def model_sampling(model_config, model_type):
+    if model_type == ModelType.EPS:
+        c = EPS
+    elif model_type == ModelType.V_PREDICTION:
+        c = V_PREDICTION
+
+    s = ModelSamplingDiscrete
+
+    class ModelSampling(s, c):
+        pass
+
+    return ModelSampling(model_config)
+
+
+
+class BaseModel(torch.nn.Module):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__()
+
+        unet_config = model_config.unet_config
+        self.latent_format = model_config.latent_format
+        self.model_config = model_config
+
+        if not unet_config.get("disable_unet_model_creation", False):
+            self.diffusion_model = UNetModel(**unet_config, device=device)
+        self.model_type = model_type
+        self.model_sampling = model_sampling(model_config, model_type)
+
+        self.adm_channels = unet_config.get("adm_in_channels", None)
+        if self.adm_channels is None:
+            self.adm_channels = 0
+        self.inpaint_model = False
+        print("model_type", model_type.name)
+        print("adm", self.adm_channels)
 
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
+        sigma = t
+        xc = self.model_sampling.calculate_input(sigma, x)
         if c_concat is not None:
-            xc = torch.cat([x] + [c_concat], dim=1)
-        else:
-            xc = x
+            xc = torch.cat([xc] + [c_concat], dim=1)
+
         context = c_crossattn
         dtype = self.get_dtype()
         xc = xc.to(dtype)
-        t = t.to(dtype)
+        t = self.model_sampling.timestep(t).to(dtype)
         context = context.to(dtype)
         extra_conds = {}
         for o in kwargs:
             extra_conds[o] = kwargs[o].to(dtype)
-        return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
+        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
+        return self.model_sampling.calculate_denoised(sigma, model_output, x)
 
     def get_dtype(self):
         return self.diffusion_model.dtype
diff --git a/comfy/samplers.py b/comfy/samplers.py
index f930aa39..5f9c7455 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -13,7 +13,7 @@ import comfy.conds
 
 
 #The main sampling function shared by all the samplers
-#Returns predicted noise
+#Returns denoised
 def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
         def get_area_and_mult(conds, x_in, timestep_in):
             area = (x_in.shape[2], x_in.shape[3], 0, 0)
@@ -257,24 +257,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
         else:
             return uncond + (cond - uncond) * cond_scale
 
-
-class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
-    def __init__(self, model, quantize=False, device='cpu'):
-        super().__init__(model, model.alphas_cumprod, quantize=quantize)
-
-    def get_v(self, x, t, cond, **kwargs):
-        return self.inner_model.apply_model(x, t, cond, **kwargs)
-
-
 class CFGNoisePredictor(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
-        self.alphas_cumprod = model.alphas_cumprod
     def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
         out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
         return out
-
+    def forward(self, *args, **kwargs):
+        return self.apply_model(*args, **kwargs)
 
 class KSamplerX0Inpaint(torch.nn.Module):
     def __init__(self, model):
@@ -293,32 +284,40 @@ class KSamplerX0Inpaint(torch.nn.Module):
         return out
 
 def simple_scheduler(model, steps):
+    s = model.model_sampling
     sigs = []
-    ss = len(model.sigmas) / steps
+    ss = len(s.sigmas) / steps
     for x in range(steps):
-        sigs += [float(model.sigmas[-(1 + int(x * ss))])]
+        sigs += [float(s.sigmas[-(1 + int(x * ss))])]
     sigs += [0.0]
     return torch.FloatTensor(sigs)
 
 def ddim_scheduler(model, steps):
+    s = model.model_sampling
     sigs = []
-    ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
-    for x in range(len(ddim_timesteps) - 1, -1, -1):
-        ts = ddim_timesteps[x]
-        if ts > 999:
-            ts = 999
-        sigs.append(model.t_to_sigma(torch.tensor(ts)))
+    ss = len(s.sigmas) // steps
+    x = 1
+    while x < len(s.sigmas):
+        sigs += [float(s.sigmas[x])]
+        x += ss
+    sigs = sigs[::-1]
     sigs += [0.0]
     return torch.FloatTensor(sigs)
 
-def sgm_scheduler(model, steps):
+def normal_scheduler(model, steps, sgm=False, floor=False):
+    s = model.model_sampling
+    start = s.timestep(s.sigma_max)
+    end = s.timestep(s.sigma_min)
+
+    if sgm:
+        timesteps = torch.linspace(start, end, steps + 1)[:-1]
+    else:
+        timesteps = torch.linspace(start, end, steps)
+
     sigs = []
-    timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
     for x in range(len(timesteps)):
         ts = timesteps[x]
-        if ts > 999:
-            ts = 999
-        sigs.append(model.t_to_sigma(torch.tensor(ts)))
+        sigs.append(s.sigma(ts))
     sigs += [0.0]
     return torch.FloatTensor(sigs)
 
@@ -508,7 +507,9 @@ class Sampler:
         pass
 
     def max_denoise(self, model_wrap, sigmas):
-        return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05)
+        max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
+        sigma = float(sigmas[0])
+        return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
 
 class DDIM(Sampler):
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
@@ -592,11 +593,7 @@ def ksampler(sampler_name, extra_options={}):
 
 def wrap_model(model):
     model_denoise = CFGNoisePredictor(model)
-    if model.model_type == model_base.ModelType.V_PREDICTION:
-        model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
-    else:
-        model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
-    return model_wrap
+    return model_denoise
 
 def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
     positive = positive[:]
@@ -637,19 +634,18 @@ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "
 SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
 
 def calculate_sigmas_scheduler(model, scheduler_name, steps):
-    model_wrap = wrap_model(model)
     if scheduler_name == "karras":
-        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
+        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
     elif scheduler_name == "exponential":
-        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
+        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
     elif scheduler_name == "normal":
-        sigmas = model_wrap.get_sigmas(steps)
+        sigmas = normal_scheduler(model, steps)
     elif scheduler_name == "simple":
-        sigmas = simple_scheduler(model_wrap, steps)
+        sigmas = simple_scheduler(model, steps)
     elif scheduler_name == "ddim_uniform":
-        sigmas = ddim_scheduler(model_wrap, steps)
+        sigmas = ddim_scheduler(model, steps)
     elif scheduler_name == "sgm_uniform":
-        sigmas = sgm_scheduler(model_wrap, steps)
+        sigmas = normal_scheduler(model, steps, sgm=True)
     else:
         print("error invalid scheduler", self.scheduler)
     return sigmas