From 3ded1a3a04f5e9b0d916ed8fb3354391f0689d30 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 17 Jul 2023 01:22:12 -0400
Subject: [PATCH] Refactor of sampler code to deal more easily with different
 model types.

---
 comfy/extra_samplers/uni_pc.py     | 22 ++++++++++------------
 comfy/k_diffusion/external.py      | 15 ++++++++++++---
 comfy/ldm/models/diffusion/ddim.py |  9 +++++----
 comfy/model_base.py                | 32 ++++++++++++++++----------------
 comfy/samplers.py                  |  9 +++++----
 comfy/sd.py                        | 10 +++++-----
 comfy/supported_models.py          | 14 ++++++++++----
 comfy/supported_models_base.py     | 10 +++++-----
 8 files changed, 68 insertions(+), 53 deletions(-)

diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py
index 2ff10caf1..7eaf6ff62 100644
--- a/comfy/extra_samplers/uni_pc.py
+++ b/comfy/extra_samplers/uni_pc.py
@@ -180,7 +180,6 @@ class NoiseScheduleVP:
 
 def model_wrapper(
     model,
-    sampling_function,
     noise_schedule,
     model_type="noise",
     model_kwargs={},
@@ -295,7 +294,7 @@ def model_wrapper(
         if t_continuous.reshape((-1,)).shape[0] == 1:
             t_continuous = t_continuous.expand((x.shape[0]))
         t_input = get_model_input_time(t_continuous)
-        output = sampling_function(model, x, t_input, **model_kwargs)
+        output = model(x, t_input, **model_kwargs)
         if model_type == "noise":
             return output
         elif model_type == "x_start":
@@ -843,10 +842,12 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
         else:
             timesteps = sigmas.clone()
 
-        for s in range(timesteps.shape[0]):
-            timesteps[s] = (model.sigma_to_t(timesteps[s]) / 1000) + (1 / len(model.sigmas))
+        alphas_cumprod = model.inner_model.alphas_cumprod
 
-        ns = NoiseScheduleVP('discrete', alphas_cumprod=model.inner_model.alphas_cumprod)
+        for s in range(timesteps.shape[0]):
+            timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
+
+        ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
 
         if image is not None:
             img = image * ns.marginal_alpha(timesteps[0])
@@ -859,18 +860,15 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
             img = noise
 
         if to_zero:
-            timesteps[-1] = (1 / len(model.sigmas))
+            timesteps[-1] = (1 / len(alphas_cumprod))
 
 
         device = noise.device
-        if model.parameterization == "v":
-            model_type = "v"
-        else:
-            model_type = "noise"
+
+        model_type = "noise"
 
         model_fn = model_wrapper(
-            model.inner_model.inner_model.apply_model,
-            sampling_function,
+            model.predict_eps_discrete_timestep,
             ns,
             model_type=model_type,
             guidance_type="uncond",
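Context for the uni_pc.py hunks above: model_wrapper no longer takes a separate
sampling_function; the callable passed in (now model.predict_eps_discrete_timestep,
added in comfy/k_diffusion/external.py below) is invoked directly. The timestep
loop first snaps each sigma to its nearest discrete timestep, then rescales into
the continuous time in (0, 1] that NoiseScheduleVP('discrete', ...) expects. A
minimal sketch of that rescaling, assuming a 1000-entry alphas_cumprod (the
helper name here is illustrative, not part of the patch):

    import torch

    def to_unipc_time(discrete_t, num_timesteps=1000):
        # index 999 -> 1.0, index 0 -> 1/1000: NoiseScheduleVP('discrete')
        # works on continuous time in (0, 1], not raw timestep indices.
        return discrete_t / num_timesteps + 1.0 / num_timesteps

    print(to_unipc_time(torch.tensor(999.0)))  # tensor(1.)
    print(to_unipc_time(torch.tensor(0.0)))    # tensor(0.0010)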
diff --git a/comfy/k_diffusion/external.py b/comfy/k_diffusion/external.py
index 49ce5ae39..680a3568c 100644
--- a/comfy/k_diffusion/external.py
+++ b/comfy/k_diffusion/external.py
@@ -63,12 +63,17 @@ class DiscreteSchedule(nn.Module):
         t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
         return sampling.append_zero(self.t_to_sigma(t))
 
-    def sigma_to_t(self, sigma, quantize=None):
-        quantize = self.quantize if quantize is None else quantize
+    def sigma_to_discrete_timestep(self, sigma):
         log_sigma = sigma.log()
         dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
+        return dists.abs().argmin(dim=0).view(sigma.shape)
+
+    def sigma_to_t(self, sigma, quantize=None):
+        quantize = self.quantize if quantize is None else quantize
         if quantize:
-            return dists.abs().argmin(dim=0).view(sigma.shape)
+            return self.sigma_to_discrete_timestep(sigma)
+        log_sigma = sigma.log()
+        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
         low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
         high_idx = low_idx + 1
         low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
@@ -85,6 +90,10 @@ class DiscreteSchedule(nn.Module):
         log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
         return log_sigma.exp()
 
+    def predict_eps_discrete_timestep(self, input, t, **kwargs):
+        sigma = self.t_to_sigma(t.round())
+        input = input * ((sigma ** 2 + 1.0) ** 0.5)
+        return (input - self(input, sigma, **kwargs)) / sigma
 
 class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
     """A wrapper for discrete schedule DDPM models that output eps (the predicted
diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py
index 108fce1cf..139c8e01e 100644
--- a/comfy/ldm/models/diffusion/ddim.py
+++ b/comfy/ldm/models/diffusion/ddim.py
@@ -14,6 +14,7 @@ class DDIMSampler(object):
         self.ddpm_num_timesteps = model.num_timesteps
         self.schedule = schedule
         self.device = device
+        self.parameterization = kwargs.get("parameterization", "eps")
 
     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
@@ -261,7 +262,7 @@ class DDIMSampler(object):
         b, *_, device = *x.shape, x.device
 
         if denoise_function is not None:
-            model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
+            model_output = denoise_function(x, t, **extra_args)
         elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
             model_output = self.model.apply_model(x, t, c)
         else:
@@ -289,13 +290,13 @@ class DDIMSampler(object):
             model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
             model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
 
-        if self.model.parameterization == "v":
+        if self.parameterization == "v":
             e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
         else:
             e_t = model_output
 
         if score_corrector is not None:
-            assert self.model.parameterization == "eps", 'not implemented'
+            assert self.parameterization == "eps", 'not implemented'
             e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
 
         alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
@@ -309,7 +310,7 @@ class DDIMSampler(object):
         sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
 
         # current prediction for x_0
-        if self.model.parameterization != "v":
+        if self.parameterization != "v":
             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
         else:
             pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
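Context for the external.py and ddim.py hunks above: predict_eps_discrete_timestep
is the new bridge between samplers that think in discrete timesteps (DDIM, UniPC)
and the k-diffusion denoiser, replacing the old sampling_function/apply_model
plumbing. It converts t to sigma, rescales the input by sqrt(sigma^2 + 1) (for a
discrete VP schedule sigma_t = sqrt(1 - a_t)/sqrt(a_t), so sigma^2 + 1 = 1/a_t and
the scaling maps DDIM's x_t = sqrt(a)*x0 + sqrt(1-a)*eps onto k-diffusion's
x0 + sigma*eps convention), then recovers eps from the denoised x0 prediction.
A toy sketch of that final eps-recovery step, with a stand-in denoiser that
returns the true x0 (illustration only, not the patch's code):

    import torch

    sigma = torch.tensor(2.0)
    x0, eps = torch.randn(4), torch.randn(4)
    noisy = x0 + sigma * eps              # k-diffusion noising convention

    def perfect_denoiser(x, s):
        # stand-in for self(input, sigma, **kwargs), which returns an x0 estimate
        return x0

    recovered = (noisy - perfect_denoiser(noisy, sigma)) / sigma
    assert torch.allclose(recovered, eps)

DDIMSampler now also caches parameterization from kwargs instead of reading it
off the model, since BaseModel no longer exposes that string.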
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 9197dc4b9..c73f2aa07 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -4,10 +4,15 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
 from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import numpy as np
+from enum import Enum
 from . import utils
 
+class ModelType(Enum):
+    EPS = 1
+    V_PREDICTION = 2
+
 class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, v_prediction=False):
+    def __init__(self, model_config, model_type=ModelType.EPS):
         super().__init__()
 
         unet_config = model_config.unet_config
@@ -15,16 +20,11 @@ class BaseModel(torch.nn.Module):
         self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
         self.diffusion_model = UNetModel(**unet_config)
-        self.v_prediction = v_prediction
-        if self.v_prediction:
-            self.parameterization = "v"
-        else:
-            self.parameterization = "eps"
-
+        self.model_type = model_type
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
             self.adm_channels = 0
-        print("v_prediction", v_prediction)
+        print("model_type", model_type.name)
         print("adm", self.adm_channels)
 
     def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
@@ -103,8 +103,8 @@ class BaseModel(torch.nn.Module):
 
 
 class SD21UNCLIP(BaseModel):
-    def __init__(self, model_config, noise_aug_config, v_prediction=True):
-        super().__init__(model_config, v_prediction)
+    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION):
+        super().__init__(model_config, model_type)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
 
     def encode_adm(self, **kwargs):
@@ -139,13 +139,13 @@ class SD21UNCLIP(BaseModel):
         return adm_out
 
 class SDInpaint(BaseModel):
-    def __init__(self, model_config, v_prediction=False):
-        super().__init__(model_config, v_prediction)
+    def __init__(self, model_config, model_type=ModelType.EPS):
+        super().__init__(model_config, model_type)
         self.concat_keys = ("mask", "masked_image")
 
 class SDXLRefiner(BaseModel):
-    def __init__(self, model_config, v_prediction=False):
-        super().__init__(model_config, v_prediction)
+    def __init__(self, model_config, model_type=ModelType.EPS):
+        super().__init__(model_config, model_type)
         self.embedder = Timestep(256)
 
     def encode_adm(self, **kwargs):
@@ -171,8 +171,8 @@ class SDXLRefiner(BaseModel):
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
 
 class SDXL(BaseModel):
-    def __init__(self, model_config, v_prediction=False):
-        super().__init__(model_config, v_prediction)
+    def __init__(self, model_config, model_type=ModelType.EPS):
+        super().__init__(model_config, model_type)
         self.embedder = Timestep(256)
 
     def encode_adm(self, **kwargs):
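Context for the model_base.py diff above: the boolean v_prediction flag becomes a
ModelType enum, so a future prediction objective means a new enum member rather
than another keyword argument threaded through every constructor. A minimal
sketch of how callers branch on it (mirroring the KSampler change in
comfy/samplers.py below; the returned strings stand in for the real denoiser
classes):

    from enum import Enum

    class ModelType(Enum):
        EPS = 1           # network predicts the noise
        V_PREDICTION = 2  # network predicts v = sqrt(a)*eps - sqrt(1-a)*x0

    def pick_denoiser(model_type):
        # mirrors KSampler.__init__ in comfy/samplers.py below
        if model_type == ModelType.V_PREDICTION:
            return "CompVisVDenoiser"
        return "CompVisDenoiser"

    assert pick_denoiser(ModelType.EPS) == "CompVisDenoiser"
    assert pick_denoiser(ModelType.V_PREDICTION) == "CompVisVDenoiser"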
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 81d1facd8..50fda016d 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -6,6 +6,7 @@ from comfy import model_management
 from .ldm.models.diffusion.ddim import DDIMSampler
 from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
 import math
+from comfy import model_base
 
 def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
     return abs(a*b) // math.gcd(a, b)
@@ -488,11 +489,11 @@ class KSampler:
     def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
         self.model = model
         self.model_denoise = CFGNoisePredictor(self.model)
-        if self.model.parameterization == "v":
+        if self.model.model_type == model_base.ModelType.V_PREDICTION:
             self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
         else:
             self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
-        self.model_wrap.parameterization = self.model.parameterization
+
         self.model_k = KSamplerX0Inpaint(self.model_wrap)
         self.device = device
         if scheduler not in self.SCHEDULERS:
@@ -614,7 +615,7 @@ class KSampler:
         elif self.sampler == "ddim":
             timesteps = []
             for s in range(sigmas.shape[0]):
-                timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
+                timesteps.insert(0, self.model_wrap.sigma_to_discrete_timestep(sigmas[s]))
             noise_mask = None
             if denoise_mask is not None:
                 noise_mask = 1.0 - denoise_mask
@@ -638,7 +639,7 @@ class KSampler:
                 x_T=z_enc,
                 x0=latent_image,
                 img_callback=ddim_callback,
-                denoise_function=sampling_function,
+                denoise_function=self.model_wrap.predict_eps_discrete_timestep,
                 extra_args=extra_args,
                 mask=noise_mask,
                 to_zero=sigmas[-1]==0,
diff --git a/comfy/sd.py b/comfy/sd.py
index 9a96dbe8c..a7887a82b 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1008,11 +1008,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     if "noise_aug_config" in model_config_params:
         noise_aug_config = model_config_params["noise_aug_config"]
 
-    v_prediction = False
+    model_type = model_base.ModelType.EPS
 
     if "parameterization" in model_config_params:
         if model_config_params["parameterization"] == "v":
-            v_prediction = True
+            model_type = model_base.ModelType.V_PREDICTION
 
     clip = None
     vae = None
@@ -1032,11 +1032,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
         model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
 
         if config['model']["target"].endswith("LatentInpaintDiffusion"):
-            model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
+            model = model_base.SDInpaint(model_config, model_type=model_type)
         elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
-            model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
+            model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
         else:
-            model = model_base.BaseModel(model_config, v_prediction=v_prediction)
+            model = model_base.BaseModel(model_config, model_type=model_type)
 
         if fp16:
             model = model.half()
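Context for the sd.py diff above and the supported_models.py diff below:
config-based loading maps the YAML "parameterization: v" onto
ModelType.V_PREDICTION, while config-free loading has to guess from the
checkpoint itself -- SD2.x via the empirical spread of one norm bias tensor,
SDXL via the presence of a "v_pred" key. A simplified sketch of the SD2.x
heuristic (the state dict here is a zero-filled stand-in; the real method also
requires unet_config["in_channels"] == 4 so inpainting models always resolve to
EPS):

    import torch

    def sd20_model_type(state_dict, prefix=""):
        # 0.09 is the threshold from the patch; v-prediction SD2.x checkpoints
        # empirically show a larger spread in this particular norm bias.
        k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
        if torch.std(state_dict[k], unbiased=False) > 0.09:
            return "V_PREDICTION"
        return "EPS"

    fake = {"output_blocks.11.1.transformer_blocks.0.norm1.bias": torch.zeros(1280)}
    print(sd20_model_type(fake))  # EPS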
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index b7fdfe9fe..915214081 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -53,13 +53,13 @@ class SD20(supported_models_base.BASE):
 
     latent_format = latent_formats.SD15
 
-    def v_prediction(self, state_dict, prefix=""):
+    def model_type(self, state_dict, prefix=""):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
             k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
             out = state_dict[k]
             if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
-                return True
-        return False
+                return model_base.ModelType.V_PREDICTION
+        return model_base.ModelType.EPS
 
     def process_clip_state_dict(self, state_dict):
         state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
@@ -145,8 +145,14 @@ class SDXL(supported_models_base.BASE):
 
     latent_format = latent_formats.SDXL
 
+    def model_type(self, state_dict, prefix=""):
+        if "v_pred" in state_dict:
+            return model_base.ModelType.V_PREDICTION
+        else:
+            return model_base.ModelType.EPS
+
     def get_model(self, state_dict, prefix=""):
-        return model_base.SDXL(self)
+        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix))
 
     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 86dc67068..c5db66274 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -41,8 +41,8 @@ class BASE:
                 return False
         return True
 
-    def v_prediction(self, state_dict, prefix=""):
-        return False
+    def model_type(self, state_dict, prefix=""):
+        return model_base.ModelType.EPS
 
     def inpaint_model(self):
         return self.unet_config["in_channels"] > 4
@@ -55,11 +55,11 @@ class BASE:
 
     def get_model(self, state_dict, prefix=""):
         if self.inpaint_model():
-            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
+            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix))
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix))
         else:
-            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))
+            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix))
 
     def process_clip_state_dict(self, state_dict):
         return state_dict
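Context for the supported_models_base.py diff above: BASE.model_type is the new
template-method hook -- default EPS, overridden per model family -- and
get_model threads its result into every model constructor. A compact sketch of
the pattern (simplified stand-in classes and strings, not the real ones):

    class Base:
        def model_type(self, state_dict, prefix=""):
            return "EPS"                      # default objective

        def get_model(self, state_dict, prefix=""):
            # every construction path threads the detected type through
            return ("BaseModel", self.model_type(state_dict, prefix))

    class SDXLLike(Base):
        def model_type(self, state_dict, prefix=""):
            # per-family override inspects the checkpoint, as SDXL does above
            return "V_PREDICTION" if "v_pred" in state_dict else "EPS"

    print(SDXLLike().get_model({"v_pred": True}))  # ('BaseModel', 'V_PREDICTION')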