mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 14:09:36 +00:00
Sampling code changes.
apply_model in model_base now returns the denoised output. This means that sampling_function now computes things on the denoised output instead of the model output. This should make things more consistent across current and future models.
This commit is contained in:
parent
c837a173fa
commit
1777b54d02
comfy
@ -852,6 +852,12 @@ class SigmaConvert:
|
||||
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
||||
return log_mean_coeff - log_std
|
||||
|
||||
def predict_eps_sigma(model, input, sigma_in, **kwargs):
|
||||
sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
|
||||
input = input * ((sigma ** 2 + 1.0) ** 0.5)
|
||||
return (input - model(input, sigma_in, **kwargs)) / sigma
|
||||
|
||||
|
||||
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
||||
timesteps = sigmas.clone()
|
||||
if sigmas[-1] == 0:
|
||||
@ -874,7 +880,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
||||
model_type = "noise"
|
||||
|
||||
model_fn = model_wrapper(
|
||||
model.predict_eps_sigma,
|
||||
lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
|
||||
ns,
|
||||
model_type=model_type,
|
||||
guidance_type="uncond",
|
||||
|
@ -13,25 +13,31 @@ class ModelType(Enum):
|
||||
EPS = 1
|
||||
V_PREDICTION = 2
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
|
||||
#NOTE: all this sampling stuff will be moved
|
||||
class EPS:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
|
||||
class V_PREDICTION(EPS):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
|
||||
class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def __init__(self, model_config):
|
||||
super().__init__()
|
||||
self._register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||
self.sigma_data = 1.0
|
||||
|
||||
unet_config = model_config.unet_config
|
||||
self.latent_format = model_config.latent_format
|
||||
self.model_config = model_config
|
||||
self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device)
|
||||
self.model_type = model_type
|
||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||
if self.adm_channels is None:
|
||||
self.adm_channels = 0
|
||||
self.inpaint_model = False
|
||||
print("model_type", model_type.name)
|
||||
print("adm", self.adm_channels)
|
||||
|
||||
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if given_betas is not None:
|
||||
betas = given_betas
|
||||
@ -39,31 +45,94 @@ class BaseModel(torch.nn.Module):
|
||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
|
||||
self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
|
||||
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
|
||||
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
|
||||
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||
|
||||
def sigma(self, timestep):
|
||||
t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
|
||||
low_idx = t.floor().long()
|
||||
high_idx = t.ceil().long()
|
||||
w = t.frac()
|
||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||
return log_sigma.exp()
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
if model_type == ModelType.EPS:
|
||||
c = EPS
|
||||
elif model_type == ModelType.V_PREDICTION:
|
||||
c = V_PREDICTION
|
||||
|
||||
s = ModelSamplingDiscrete
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__()
|
||||
|
||||
unet_config = model_config.unet_config
|
||||
self.latent_format = model_config.latent_format
|
||||
self.model_config = model_config
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device)
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||
if self.adm_channels is None:
|
||||
self.adm_channels = 0
|
||||
self.inpaint_model = False
|
||||
print("model_type", model_type.name)
|
||||
print("adm", self.adm_channels)
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
xc = self.model_sampling.calculate_input(sigma, x)
|
||||
if c_concat is not None:
|
||||
xc = torch.cat([x] + [c_concat], dim=1)
|
||||
else:
|
||||
xc = x
|
||||
xc = torch.cat([xc] + [c_concat], dim=1)
|
||||
|
||||
context = c_crossattn
|
||||
dtype = self.get_dtype()
|
||||
xc = xc.to(dtype)
|
||||
t = t.to(dtype)
|
||||
t = self.model_sampling.timestep(t).to(dtype)
|
||||
context = context.to(dtype)
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra_conds[o] = kwargs[o].to(dtype)
|
||||
return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
|
@ -13,7 +13,7 @@ import comfy.conds
|
||||
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns predicted noise
|
||||
#Returns denoised
|
||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
@ -257,24 +257,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
else:
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
|
||||
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_v(self, x, t, cond, **kwargs):
|
||||
return self.inner_model.apply_model(x, t, cond, **kwargs)
|
||||
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.alphas_cumprod = model.alphas_cumprod
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||
return out
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
|
||||
class KSamplerX0Inpaint(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
@ -293,32 +284,40 @@ class KSamplerX0Inpaint(torch.nn.Module):
|
||||
return out
|
||||
|
||||
def simple_scheduler(model, steps):
|
||||
s = model.model_sampling
|
||||
sigs = []
|
||||
ss = len(model.sigmas) / steps
|
||||
ss = len(s.sigmas) / steps
|
||||
for x in range(steps):
|
||||
sigs += [float(model.sigmas[-(1 + int(x * ss))])]
|
||||
sigs += [float(s.sigmas[-(1 + int(x * ss))])]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def ddim_scheduler(model, steps):
|
||||
s = model.model_sampling
|
||||
sigs = []
|
||||
ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
|
||||
for x in range(len(ddim_timesteps) - 1, -1, -1):
|
||||
ts = ddim_timesteps[x]
|
||||
if ts > 999:
|
||||
ts = 999
|
||||
sigs.append(model.t_to_sigma(torch.tensor(ts)))
|
||||
ss = len(s.sigmas) // steps
|
||||
x = 1
|
||||
while x < len(s.sigmas):
|
||||
sigs += [float(s.sigmas[x])]
|
||||
x += ss
|
||||
sigs = sigs[::-1]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def sgm_scheduler(model, steps):
|
||||
def normal_scheduler(model, steps, sgm=False, floor=False):
|
||||
s = model.model_sampling
|
||||
start = s.timestep(s.sigma_max)
|
||||
end = s.timestep(s.sigma_min)
|
||||
|
||||
if sgm:
|
||||
timesteps = torch.linspace(start, end, steps + 1)[:-1]
|
||||
else:
|
||||
timesteps = torch.linspace(start, end, steps)
|
||||
|
||||
sigs = []
|
||||
timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
|
||||
for x in range(len(timesteps)):
|
||||
ts = timesteps[x]
|
||||
if ts > 999:
|
||||
ts = 999
|
||||
sigs.append(model.t_to_sigma(torch.tensor(ts)))
|
||||
sigs.append(s.sigma(ts))
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
@ -508,7 +507,9 @@ class Sampler:
|
||||
pass
|
||||
|
||||
def max_denoise(self, model_wrap, sigmas):
|
||||
return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05)
|
||||
max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
|
||||
sigma = float(sigmas[0])
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
class DDIM(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
@ -592,11 +593,7 @@ def ksampler(sampler_name, extra_options={}):
|
||||
|
||||
def wrap_model(model):
|
||||
model_denoise = CFGNoisePredictor(model)
|
||||
if model.model_type == model_base.ModelType.V_PREDICTION:
|
||||
model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
|
||||
else:
|
||||
model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
|
||||
return model_wrap
|
||||
return model_denoise
|
||||
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
positive = positive[:]
|
||||
@ -637,19 +634,18 @@ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
||||
model_wrap = wrap_model(model)
|
||||
if scheduler_name == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
elif scheduler_name == "exponential":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
elif scheduler_name == "normal":
|
||||
sigmas = model_wrap.get_sigmas(steps)
|
||||
sigmas = normal_scheduler(model, steps)
|
||||
elif scheduler_name == "simple":
|
||||
sigmas = simple_scheduler(model_wrap, steps)
|
||||
sigmas = simple_scheduler(model, steps)
|
||||
elif scheduler_name == "ddim_uniform":
|
||||
sigmas = ddim_scheduler(model_wrap, steps)
|
||||
sigmas = ddim_scheduler(model, steps)
|
||||
elif scheduler_name == "sgm_uniform":
|
||||
sigmas = sgm_scheduler(model_wrap, steps)
|
||||
sigmas = normal_scheduler(model, steps, sgm=True)
|
||||
else:
|
||||
print("error invalid scheduler", self.scheduler)
|
||||
return sigmas
|
||||
|
Loading…
Reference in New Issue
Block a user