This is cleaner this way.

This commit is contained in:
comfyanonymous 2023-04-24 22:45:35 -04:00
parent 36acce58e7
commit 7983b3a975

View File

@ -400,38 +400,6 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
return conds
def calculate_sigmas(model, steps, scheduler, sampler):
"""
Returns a tensor containing the sigmas corresponding to the given model, number of steps, scheduler type and sample technique
"""
if not (isinstance(model, CompVisVDenoiser) or isinstance(model, k_diffusion_external.CompVisDenoiser)):
model = CFGNoisePredictor(model)
if model.inner_model.parameterization == "v":
model = CompVisVDenoiser(model, quantize=True)
else:
model = k_diffusion_external.CompVisDenoiser(model, quantize=True)
sigmas = None
discard_penultimate_sigma = False
if sampler in ['dpm_2', 'dpm_2_ancestral']:
steps += 1
discard_penultimate_sigma = True
if scheduler == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.sigma_min), sigma_max=float(model.sigma_max))
elif scheduler == "normal":
sigmas = model.get_sigmas(steps)
elif scheduler == "simple":
sigmas = simple_scheduler(model, steps)
elif scheduler == "ddim_uniform":
sigmas = ddim_scheduler(model, steps)
else:
print("error invalid scheduler", scheduler)
if discard_penultimate_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
return sigmas
class KSampler:
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
@ -461,13 +429,36 @@ class KSampler:
self.denoise = denoise
self.model_options = model_options
def calculate_sigmas(self, steps):
sigmas = None
discard_penultimate_sigma = False
if self.sampler in ['dpm_2', 'dpm_2_ancestral']:
steps += 1
discard_penultimate_sigma = True
if self.scheduler == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "normal":
sigmas = self.model_wrap.get_sigmas(steps)
elif self.scheduler == "simple":
sigmas = simple_scheduler(self.model_wrap, steps)
elif self.scheduler == "ddim_uniform":
sigmas = ddim_scheduler(self.model_wrap, steps)
else:
print("error invalid scheduler", self.scheduler)
if discard_penultimate_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
return sigmas
def set_steps(self, steps, denoise=None):
self.steps = steps
if denoise is None or denoise > 0.9999:
self.sigmas = calculate_sigmas(self.model_wrap, steps, self.scheduler, self.sampler).to(self.device)
self.sigmas = self.calculate_sigmas(steps).to(self.device)
else:
new_steps = int(steps/denoise)
sigmas = calculate_sigmas(self.model_wrap, new_steps, self.scheduler, self.sampler).to(self.device)
sigmas = self.calculate_sigmas(new_steps).to(self.device)
self.sigmas = sigmas[-(steps + 1):]