Sampling code changes.

apply_model in model_base now returns the denoised output.

This means that sampling_function now operates on the denoised output
instead of the raw model output. This should make behavior more consistent
across current and future models.
comfyanonymous 2023-10-31 17:33:43 -04:00
parent c837a173fa
commit 1777b54d02
3 changed files with 136 additions and 65 deletions
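
For an eps-prediction model, the denoised output and the raw model output are related by denoised = x - sigma * eps, so samplers can convert between the two representations freely. A minimal sketch of that round trip with synthetic tensors (nothing below comes from the repo):

import torch

x = torch.randn(1, 4, 64, 64)    # noisy latent
sigma = torch.tensor([14.6])     # current noise level
eps = torch.randn_like(x)        # stand-in for the raw UNet output

# broadcast sigma to the latent's shape, as calculate_denoised() does below
sigma_b = sigma.view(sigma.shape[:1] + (1,) * (x.ndim - 1))
denoised = x - eps * sigma_b     # what apply_model now returns
eps_back = (x - denoised) / sigma_b
assert torch.allclose(eps, eps_back)   # the conversion is lossless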

View File

@@ -852,6 +852,12 @@ class SigmaConvert:
         log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
         return log_mean_coeff - log_std

+def predict_eps_sigma(model, input, sigma_in, **kwargs):
+    sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
+    input = input * ((sigma ** 2 + 1.0) ** 0.5)
+    return (input - model(input, sigma_in, **kwargs)) / sigma
+
 def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
     timesteps = sigmas.clone()
     if sigmas[-1] == 0:
@@ -874,7 +880,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
         model_type = "noise"

     model_fn = model_wrapper(
-        model.predict_eps_sigma,
+        lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
         ns,
         model_type=model_type,
         guidance_type="uncond",
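
UniPC's model_wrapper still expects a noise-prediction callable, which is why the new predict_eps_sigma helper rescales the input to the model's expected range and converts the denoised output back to eps. A rough standalone check with a dummy model (the lambda is purely illustrative):

import torch

def predict_eps_sigma(model, input, sigma_in, **kwargs):
    # same algebra as the helper added above
    sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
    input = input * ((sigma ** 2 + 1.0) ** 0.5)
    return (input - model(input, sigma_in, **kwargs)) / sigma

dummy = lambda x, s: x * 0.5          # pretend "denoised output" model
eps = predict_eps_sigma(dummy, torch.randn(1, 4, 8, 8), torch.tensor([2.0]))
print(eps.shape)                      # torch.Size([1, 4, 8, 8])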

View File

@@ -13,25 +13,31 @@ class ModelType(Enum):
     EPS = 1
     V_PREDICTION = 2

-class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+#NOTE: all this sampling stuff will be moved
+class EPS:
+    def calculate_input(self, sigma, noise):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+        return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+
+    def calculate_denoised(self, sigma, model_output, model_input):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        return model_input - model_output * sigma
+
+class V_PREDICTION(EPS):
+    def calculate_denoised(self, sigma, model_output, model_input):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+
+class ModelSamplingDiscrete(torch.nn.Module):
+    def __init__(self, model_config):
         super().__init__()
+        self._register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
+        self.sigma_data = 1.0

-        unet_config = model_config.unet_config
-        self.latent_format = model_config.latent_format
-        self.model_config = model_config
-        self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
-        if not unet_config.get("disable_unet_model_creation", False):
-            self.diffusion_model = UNetModel(**unet_config, device=device)
-        self.model_type = model_type
-        self.adm_channels = unet_config.get("adm_in_channels", None)
-        if self.adm_channels is None:
-            self.adm_channels = 0
-        self.inpaint_model = False
-        print("model_type", model_type.name)
-        print("adm", self.adm_channels)
-
-    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+    def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
         if given_betas is not None:
             betas = given_betas
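
The V_PREDICTION.calculate_denoised formula recovers the clean latent from a v-style model output. A quick sanity check of the algebra with sigma_data = 1, building a synthetic v under the common x = x0 + sigma * eps convention (all tensors here are made up):

import torch

sigma_data = 1.0
x0 = torch.randn(4)                          # "clean" latent
sigma = torch.tensor(3.0)
eps = torch.randn(4)
x = x0 + sigma * eps                         # noisy input

# one common v-prediction convention (with sigma_data = 1):
# v = (eps - sigma * x0) / sqrt(sigma^2 + 1)
v = (eps - sigma * x0) / (sigma ** 2 + 1) ** 0.5

denoised = x * sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2) \
    - v * sigma * sigma_data / (sigma ** 2 + sigma_data ** 2) ** 0.5
print(torch.allclose(denoised, x0, atol=1e-5))   # True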
@@ -39,31 +45,94 @@ class BaseModel(torch.nn.Module):
             betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
         alphas = 1. - betas
         alphas_cumprod = np.cumprod(alphas, axis=0)
-        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+        # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)
         self.linear_start = linear_start
         self.linear_end = linear_end

-        self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
-        self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
-        self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
+        # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
+        # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
+        # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
+
+        sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
+        self.register_buffer('sigmas', sigmas)
+        self.register_buffer('log_sigmas', sigmas.log())
+
+    @property
+    def sigma_min(self):
+        return self.sigmas[0]
+
+    @property
+    def sigma_max(self):
+        return self.sigmas[-1]
+
+    def timestep(self, sigma):
+        log_sigma = sigma.log()
+        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
+        return dists.abs().argmin(dim=0).view(sigma.shape)
+
+    def sigma(self, timestep):
+        t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
+        low_idx = t.floor().long()
+        high_idx = t.ceil().long()
+        w = t.frac()
+        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
+        return log_sigma.exp()
+
+def model_sampling(model_config, model_type):
+    if model_type == ModelType.EPS:
+        c = EPS
+    elif model_type == ModelType.V_PREDICTION:
+        c = V_PREDICTION
+
+    s = ModelSamplingDiscrete
+
+    class ModelSampling(s, c):
+        pass
+
+    return ModelSampling(model_config)
+
+class BaseModel(torch.nn.Module):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__()
+
+        unet_config = model_config.unet_config
+        self.latent_format = model_config.latent_format
+        self.model_config = model_config
+
+        if not unet_config.get("disable_unet_model_creation", False):
+            self.diffusion_model = UNetModel(**unet_config, device=device)
+        self.model_type = model_type
+        self.model_sampling = model_sampling(model_config, model_type)
+
+        self.adm_channels = unet_config.get("adm_in_channels", None)
+        if self.adm_channels is None:
+            self.adm_channels = 0
+        self.inpaint_model = False
+        print("model_type", model_type.name)
+        print("adm", self.adm_channels)

     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
+        sigma = t
+        xc = self.model_sampling.calculate_input(sigma, x)
         if c_concat is not None:
-            xc = torch.cat([x] + [c_concat], dim=1)
-        else:
-            xc = x
+            xc = torch.cat([xc] + [c_concat], dim=1)

         context = c_crossattn
         dtype = self.get_dtype()
         xc = xc.to(dtype)
-        t = t.to(dtype)
+        t = self.model_sampling.timestep(t).to(dtype)
         context = context.to(dtype)

         extra_conds = {}
         for o in kwargs:
             extra_conds[o] = kwargs[o].to(dtype)
-        return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
+        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
+        return self.model_sampling.calculate_denoised(sigma, model_output, x)

     def get_dtype(self):
         return self.diffusion_model.dtype
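
model_sampling() builds the per-model sampling object by mixing a schedule class with a prediction-type class, so the denoised math and the sigma schedule stay orthogonal. The same pattern with stub classes (names invented for illustration):

# stub mixin standing in for EPS: interprets the model output
class PredictionEPS:
    def calculate_denoised(self, sigma, model_output, model_input):
        return model_input - model_output * sigma

# stub standing in for ModelSamplingDiscrete: owns schedule-side state
class ScheduleDiscrete:
    sigma_data = 1.0

class ModelSampling(ScheduleDiscrete, PredictionEPS):
    pass

ms = ModelSampling()
print(ms.calculate_denoised(2.0, 0.5, 1.0))   # 1.0 - 0.5 * 2.0 = 0.0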

View File

@@ -13,7 +13,7 @@ import comfy.conds

 #The main sampling function shared by all the samplers
-#Returns predicted noise
+#Returns denoised
 def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
     def get_area_and_mult(conds, x_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
@@ -257,24 +257,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
         else:
             return uncond + (cond - uncond) * cond_scale

-class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
-    def __init__(self, model, quantize=False, device='cpu'):
-        super().__init__(model, model.alphas_cumprod, quantize=quantize)
-
-    def get_v(self, x, t, cond, **kwargs):
-        return self.inner_model.apply_model(x, t, cond, **kwargs)
-
 class CFGNoisePredictor(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
-        self.alphas_cumprod = model.alphas_cumprod
+
     def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
         out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
         return out
+
+    def forward(self, *args, **kwargs):
+        return self.apply_model(*args, **kwargs)

 class KSamplerX0Inpaint(torch.nn.Module):
     def __init__(self, model):
@@ -293,32 +284,40 @@ class KSamplerX0Inpaint(torch.nn.Module):
         return out

 def simple_scheduler(model, steps):
+    s = model.model_sampling
     sigs = []
-    ss = len(model.sigmas) / steps
+    ss = len(s.sigmas) / steps
     for x in range(steps):
-        sigs += [float(model.sigmas[-(1 + int(x * ss))])]
+        sigs += [float(s.sigmas[-(1 + int(x * ss))])]
     sigs += [0.0]
     return torch.FloatTensor(sigs)

 def ddim_scheduler(model, steps):
+    s = model.model_sampling
     sigs = []
-    ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
-    for x in range(len(ddim_timesteps) - 1, -1, -1):
-        ts = ddim_timesteps[x]
-        if ts > 999:
-            ts = 999
-        sigs.append(model.t_to_sigma(torch.tensor(ts)))
+    ss = len(s.sigmas) // steps
+    x = 1
+    while x < len(s.sigmas):
+        sigs += [float(s.sigmas[x])]
+        x += ss
+    sigs = sigs[::-1]
     sigs += [0.0]
     return torch.FloatTensor(sigs)

-def sgm_scheduler(model, steps):
+def normal_scheduler(model, steps, sgm=False, floor=False):
+    s = model.model_sampling
+    start = s.timestep(s.sigma_max)
+    end = s.timestep(s.sigma_min)
+
+    if sgm:
+        timesteps = torch.linspace(start, end, steps + 1)[:-1]
+    else:
+        timesteps = torch.linspace(start, end, steps)
+
     sigs = []
-    timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
     for x in range(len(timesteps)):
         ts = timesteps[x]
-        if ts > 999:
-            ts = 999
-        sigs.append(model.t_to_sigma(torch.tensor(ts)))
+        sigs.append(s.sigma(ts))
     sigs += [0.0]
     return torch.FloatTensor(sigs)
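
The schedulers above now read sigmas from model.model_sampling instead of a k-diffusion wrapper. For an SD1.x-style linear beta schedule, the table they index looks roughly like this (constants from _register_schedule; make_beta_schedule's "linear" branch is inlined here as an assumption):

import torch

betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

print(float(sigmas[0]), float(sigmas[-1]))   # sigma_min ~0.03, sigma_max ~14.6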
@@ -508,7 +507,9 @@ class Sampler:
         pass

     def max_denoise(self, model_wrap, sigmas):
-        return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05)
+        max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
+        sigma = float(sigmas[0])
+        return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma

 class DDIM(Sampler):
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
@@ -592,11 +593,7 @@ def ksampler(sampler_name, extra_options={}):

 def wrap_model(model):
     model_denoise = CFGNoisePredictor(model)
-    if model.model_type == model_base.ModelType.V_PREDICTION:
-        model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
-    else:
-        model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
-    return model_wrap
+    return model_denoise

 def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
     positive = positive[:]
@@ -637,19 +634,18 @@ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "
 SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

 def calculate_sigmas_scheduler(model, scheduler_name, steps):
-    model_wrap = wrap_model(model)
     if scheduler_name == "karras":
-        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
+        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
     elif scheduler_name == "exponential":
-        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
+        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
     elif scheduler_name == "normal":
-        sigmas = model_wrap.get_sigmas(steps)
+        sigmas = normal_scheduler(model, steps)
     elif scheduler_name == "simple":
-        sigmas = simple_scheduler(model_wrap, steps)
+        sigmas = simple_scheduler(model, steps)
     elif scheduler_name == "ddim_uniform":
-        sigmas = ddim_scheduler(model_wrap, steps)
+        sigmas = ddim_scheduler(model, steps)
     elif scheduler_name == "sgm_uniform":
-        sigmas = sgm_scheduler(model_wrap, steps)
+        sigmas = normal_scheduler(model, steps, sgm=True)
     else:
         print("error invalid scheduler", self.scheduler)
     return sigmas
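
A rough usage sketch for the scheduler helpers, assuming simple_scheduler from comfy.samplers is in scope and using a stub model that exposes model_sampling the way the new BaseModel does (stub values are illustrative):

import torch

class _StubSampling:
    sigmas = torch.linspace(0.03, 14.6, 1000)   # ascending, like the real table

class _StubModel:
    model_sampling = _StubSampling()

print(simple_scheduler(_StubModel(), 10))   # 10 descending sigmas, then 0.0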