From 1777b54d0217e77a6a64b0a587b9b11a48e3bf02 Mon Sep 17 00:00:00 2001 From: comfyanonymous <comfyanonymous@protonmail.com> Date: Tue, 31 Oct 2023 17:33:43 -0400 Subject: [PATCH] Sampling code changes. apply_model in model_base now returns the denoised output. This means that sampling_function now computes things on the denoised output instead of the model output. This should make things more consistent across current and future models. --- comfy/extra_samplers/uni_pc.py | 8 ++- comfy/model_base.py | 121 ++++++++++++++++++++++++++------- comfy/samplers.py | 72 +++++++++----------- 3 files changed, 136 insertions(+), 65 deletions(-) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 9d5f0c60..1a7a8392 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -852,6 +852,12 @@ class SigmaConvert: log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) return log_mean_coeff - log_std +def predict_eps_sigma(model, input, sigma_in, **kwargs): + sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1)) + input = input * ((sigma ** 2 + 1.0) ** 0.5) + return (input - model(input, sigma_in, **kwargs)) / sigma + + def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): timesteps = sigmas.clone() if sigmas[-1] == 0: @@ -874,7 +880,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex model_type = "noise" model_fn = model_wrapper( - model.predict_eps_sigma, + lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs), ns, model_type=model_type, guidance_type="uncond", diff --git a/comfy/model_base.py b/comfy/model_base.py index ea3ea61f..b8d04a2c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -13,25 +13,31 @@ class ModelType(Enum): EPS = 1 V_PREDICTION = 2 -class BaseModel(torch.nn.Module): - def __init__(self, model_config, model_type=ModelType.EPS, device=None): + +#NOTE: all this sampling stuff will be moved +class EPS: + def calculate_input(self, sigma, noise): + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + +class V_PREDICTION(EPS): + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + + +class ModelSamplingDiscrete(torch.nn.Module): + def __init__(self, model_config): super().__init__() + self._register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + self.sigma_data = 1.0 - unet_config = model_config.unet_config - self.latent_format = model_config.latent_format - self.model_config = model_config - self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) - if not unet_config.get("disable_unet_model_creation", False): - self.diffusion_model = UNetModel(**unet_config, device=device) - self.model_type = model_type - self.adm_channels = 
unet_config.get("adm_in_channels", None) - if self.adm_channels is None: - self.adm_channels = 0 - self.inpaint_model = False - print("model_type", model_type.name) - print("adm", self.adm_channels) - - def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if given_betas is not None: betas = given_betas @@ -39,31 +45,94 @@ class BaseModel(torch.nn.Module): betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) alphas = 1. - betas alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) timesteps, = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end - self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32)) - self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) - self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) + # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32)) + # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) + # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) + + sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32) + + self.register_buffer('sigmas', sigmas) + self.register_buffer('log_sigmas', sigmas.log()) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + log_sigma = sigma.log() + dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) + + def sigma(self, timestep): + t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1)) + low_idx = t.floor().long() + high_idx = t.ceil().long() + w = t.frac() + log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] + return log_sigma.exp() + +def model_sampling(model_config, model_type): + if model_type == ModelType.EPS: + c = EPS + elif model_type == ModelType.V_PREDICTION: + c = V_PREDICTION + + s = ModelSamplingDiscrete + + class ModelSampling(s, c): + pass + + return ModelSampling(model_config) + + + +class BaseModel(torch.nn.Module): + def __init__(self, model_config, model_type=ModelType.EPS, device=None): + super().__init__() + + unet_config = model_config.unet_config + self.latent_format = model_config.latent_format + self.model_config = model_config + + if not unet_config.get("disable_unet_model_creation", False): + self.diffusion_model = UNetModel(**unet_config, device=device) + self.model_type = model_type + self.model_sampling = model_sampling(model_config, model_type) + + self.adm_channels = unet_config.get("adm_in_channels", None) + if self.adm_channels is None: + self.adm_channels = 0 + self.inpaint_model = False + print("model_type", model_type.name) + print("adm", self.adm_channels) def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): + sigma = t + xc = self.model_sampling.calculate_input(sigma, x) if c_concat is not None: - xc = torch.cat([x] + [c_concat], dim=1) - else: - xc = x + xc = torch.cat([xc] + [c_concat], 
dim=1) + context = c_crossattn dtype = self.get_dtype() xc = xc.to(dtype) - t = t.to(dtype) + t = self.model_sampling.timestep(t).to(dtype) context = context.to(dtype) extra_conds = {} for o in kwargs: extra_conds[o] = kwargs[o].to(dtype) - return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() + return self.model_sampling.calculate_denoised(sigma, model_output, x) def get_dtype(self): return self.diffusion_model.dtype diff --git a/comfy/samplers.py b/comfy/samplers.py index f930aa39..5f9c7455 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -13,7 +13,7 @@ import comfy.conds #The main sampling function shared by all the samplers -#Returns predicted noise +#Returns denoised def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): def get_area_and_mult(conds, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) @@ -257,24 +257,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod else: return uncond + (cond - uncond) * cond_scale - -class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser): - def __init__(self, model, quantize=False, device='cpu'): - super().__init__(model, model.alphas_cumprod, quantize=quantize) - - def get_v(self, x, t, cond, **kwargs): - return self.inner_model.apply_model(x, t, cond, **kwargs) - - class CFGNoisePredictor(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - self.alphas_cumprod = model.alphas_cumprod def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None): out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) return out - + def forward(self, *args, **kwargs): + return self.apply_model(*args, **kwargs) class KSamplerX0Inpaint(torch.nn.Module): def __init__(self, model): @@ -293,32 +284,40 @@ class KSamplerX0Inpaint(torch.nn.Module): return out def simple_scheduler(model, steps): + s = model.model_sampling sigs = [] - ss = len(model.sigmas) / steps + ss = len(s.sigmas) / steps for x in range(steps): - sigs += [float(model.sigmas[-(1 + int(x * ss))])] + sigs += [float(s.sigmas[-(1 + int(x * ss))])] sigs += [0.0] return torch.FloatTensor(sigs) def ddim_scheduler(model, steps): + s = model.model_sampling sigs = [] - ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False) - for x in range(len(ddim_timesteps) - 1, -1, -1): - ts = ddim_timesteps[x] - if ts > 999: - ts = 999 - sigs.append(model.t_to_sigma(torch.tensor(ts))) + ss = len(s.sigmas) // steps + x = 1 + while x < len(s.sigmas): + sigs += [float(s.sigmas[x])] + x += ss + sigs = sigs[::-1] sigs += [0.0] return torch.FloatTensor(sigs) -def sgm_scheduler(model, steps): +def normal_scheduler(model, steps, sgm=False, floor=False): + s = model.model_sampling + start = s.timestep(s.sigma_max) + end = s.timestep(s.sigma_min) + + if sgm: + timesteps = torch.linspace(start, end, steps + 1)[:-1] + else: + timesteps = torch.linspace(start, end, steps) + sigs = [] - timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int) for x in range(len(timesteps)): ts = timesteps[x] - if 
ts > 999: - ts = 999 - sigs.append(model.t_to_sigma(torch.tensor(ts))) + sigs.append(s.sigma(ts)) sigs += [0.0] return torch.FloatTensor(sigs) @@ -508,7 +507,9 @@ class Sampler: pass def max_denoise(self, model_wrap, sigmas): - return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05) + max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma class DDIM(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): @@ -592,11 +593,7 @@ def ksampler(sampler_name, extra_options={}): def wrap_model(model): model_denoise = CFGNoisePredictor(model) - if model.model_type == model_base.ModelType.V_PREDICTION: - model_wrap = CompVisVDenoiser(model_denoise, quantize=True) - else: - model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True) - return model_wrap + return model_denoise def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): positive = positive[:] @@ -637,19 +634,18 @@ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", " SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] def calculate_sigmas_scheduler(model, scheduler_name, steps): - model_wrap = wrap_model(model) if scheduler_name == "karras": - sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) elif scheduler_name == "exponential": - sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) elif scheduler_name == "normal": - sigmas = model_wrap.get_sigmas(steps) + sigmas = normal_scheduler(model, steps) elif scheduler_name == "simple": - sigmas = simple_scheduler(model_wrap, steps) + sigmas = simple_scheduler(model, steps) elif scheduler_name == "ddim_uniform": - sigmas = ddim_scheduler(model_wrap, steps) + sigmas = ddim_scheduler(model, steps) elif scheduler_name == "sgm_uniform": - sigmas = sgm_scheduler(model_wrap, steps) + sigmas = normal_scheduler(model, steps, sgm=True) else: print("error invalid scheduler", self.scheduler) return sigmas
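
Illustration of the new parameterization (a standalone sketch, not part of the diff above; it assumes sigma_data = 1.0 as ModelSamplingDiscrete sets it, and made-up tensor shapes). With an EPS model, apply_model now feeds the UNet x / sqrt(sigma^2 + sigma_data^2) and returns denoised = x - sigma * eps, and predict_eps_sigma in uni_pc.py inverts that so UniPC can keep working on a noise prediction:

    import torch

    sigma_data = 1.0  # ModelSamplingDiscrete sets this for both EPS and V_PREDICTION

    def calculate_input(sigma, noisy):
        # mirrors EPS.calculate_input: what the UNet is actually fed (xc in apply_model)
        sigma = sigma.view(sigma.shape[:1] + (1,) * (noisy.ndim - 1))
        return noisy / (sigma ** 2 + sigma_data ** 2) ** 0.5

    def calculate_denoised(sigma, eps_pred, noisy):
        # mirrors EPS.calculate_denoised: denoised = x - sigma * eps
        sigma = sigma.view(sigma.shape[:1] + (1,) * (eps_pred.ndim - 1))
        return noisy - eps_pred * sigma

    x0 = torch.randn(1, 4, 8, 8)        # "clean" latent (shape is illustrative)
    eps = torch.randn_like(x0)          # the noise an ideal eps model would predict
    sigma = torch.tensor([2.5])
    s = sigma.view(1, 1, 1, 1)
    x = x0 + s * eps                    # noised latent in sigma space

    net_in = calculate_input(sigma, x)              # what goes into the UNet
    denoised = calculate_denoised(sigma, eps, x)    # what apply_model now returns
    print(torch.allclose(denoised, x0, atol=1e-5))  # True for a perfect eps prediction

    # predict_eps_sigma (uni_pc.py) recovers a noise prediction from the denoised
    # output; UniPC hands it x / sqrt(sigma^2 + 1), hence the rescale inside.
    ddpm_x = x / (s ** 2 + 1.0) ** 0.5
    eps_again = (ddpm_x * (s ** 2 + 1.0) ** 0.5 - denoised) / s
    print(torch.allclose(eps_again, eps, atol=1e-5))  # True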
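
The discrete sigma table that replaces the old alphas_cumprod buffers can be reproduced on its own. A sketch assuming the usual SD "linear" beta schedule (a squared linspace over sqrt-betas) and the defaults hard-coded above (linear_start=0.00085, linear_end=0.012, 1000 steps):

    import torch

    linear_start, linear_end, n = 0.00085, 0.012, 1000
    betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n, dtype=torch.float64) ** 2
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    sigmas = (((1 - alphas_cumprod) / alphas_cumprod) ** 0.5).float()
    log_sigmas = sigmas.log()
    print(float(sigmas[0]), float(sigmas[-1]))  # sigma_min / sigma_max, roughly 0.03 and 14.6

    def timestep(sig):
        # nearest discrete timestep in log-sigma space (ModelSamplingDiscrete.timestep)
        dists = sig.log() - log_sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sig.shape)

    def sigma(t):
        # fractional timestep -> sigma by interpolating log-sigmas (ModelSamplingDiscrete.sigma)
        t = torch.clamp(t.float(), min=0, max=(n - 1))
        low, high, w = t.floor().long(), t.ceil().long(), t.frac()
        return ((1 - w) * log_sigmas[low] + w * log_sigmas[high]).exp()

    t = torch.tensor([999.0, 500.0, 0.0])
    print(timestep(sigma(t)))  # tensor([999, 500, 0]): exact round trip on the grid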
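
Usage sketch for the rewritten schedulers (assuming `model` is a comfy.model_base.BaseModel, i.e. anything exposing .model_sampling); with this patch they no longer need the removed CompVisDenoiser / CompVisVDenoiser wrappers:

    import comfy.samplers

    steps = 20
    for name in ("normal", "sgm_uniform", "simple", "ddim_uniform", "karras"):
        sigmas = comfy.samplers.calculate_sigmas_scheduler(model, name, steps)
        print(name, sigmas.shape, float(sigmas[0]), float(sigmas[-1]))
    # Each schedule descends from a high sigma (near model.model_sampling.sigma_max
    # for most of them) to a trailing 0.0. "sgm_uniform" is now just
    # normal_scheduler(model, steps, sgm=True): it takes steps timesteps from
    # torch.linspace(t_max, t_min, steps + 1)[:-1], where t_max/t_min come from
    # timestep(sigma_max)/timestep(sigma_min), and maps each back through
    # model.model_sampling.sigma().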