mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
This is cleaner this way.
This commit is contained in:
parent
36acce58e7
commit
7983b3a975
@ -400,38 +400,6 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
|
||||
|
||||
return conds
|
||||
|
||||
def calculate_sigmas(model, steps, scheduler, sampler):
|
||||
"""
|
||||
Returns a tensor containing the sigmas corresponding to the given model, number of steps, scheduler type and sample technique
|
||||
"""
|
||||
if not (isinstance(model, CompVisVDenoiser) or isinstance(model, k_diffusion_external.CompVisDenoiser)):
|
||||
model = CFGNoisePredictor(model)
|
||||
if model.inner_model.parameterization == "v":
|
||||
model = CompVisVDenoiser(model, quantize=True)
|
||||
else:
|
||||
model = k_diffusion_external.CompVisDenoiser(model, quantize=True)
|
||||
|
||||
sigmas = None
|
||||
|
||||
discard_penultimate_sigma = False
|
||||
if sampler in ['dpm_2', 'dpm_2_ancestral']:
|
||||
steps += 1
|
||||
discard_penultimate_sigma = True
|
||||
|
||||
if scheduler == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.sigma_min), sigma_max=float(model.sigma_max))
|
||||
elif scheduler == "normal":
|
||||
sigmas = model.get_sigmas(steps)
|
||||
elif scheduler == "simple":
|
||||
sigmas = simple_scheduler(model, steps)
|
||||
elif scheduler == "ddim_uniform":
|
||||
sigmas = ddim_scheduler(model, steps)
|
||||
else:
|
||||
print("error invalid scheduler", scheduler)
|
||||
|
||||
if discard_penultimate_sigma:
|
||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||
return sigmas
|
||||
|
||||
class KSampler:
|
||||
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
||||
@ -461,13 +429,36 @@ class KSampler:
|
||||
self.denoise = denoise
|
||||
self.model_options = model_options
|
||||
|
||||
def calculate_sigmas(self, steps):
|
||||
sigmas = None
|
||||
|
||||
discard_penultimate_sigma = False
|
||||
if self.sampler in ['dpm_2', 'dpm_2_ancestral']:
|
||||
steps += 1
|
||||
discard_penultimate_sigma = True
|
||||
|
||||
if self.scheduler == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
||||
elif self.scheduler == "normal":
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
elif self.scheduler == "simple":
|
||||
sigmas = simple_scheduler(self.model_wrap, steps)
|
||||
elif self.scheduler == "ddim_uniform":
|
||||
sigmas = ddim_scheduler(self.model_wrap, steps)
|
||||
else:
|
||||
print("error invalid scheduler", self.scheduler)
|
||||
|
||||
if discard_penultimate_sigma:
|
||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||
return sigmas
|
||||
|
||||
def set_steps(self, steps, denoise=None):
|
||||
self.steps = steps
|
||||
if denoise is None or denoise > 0.9999:
|
||||
self.sigmas = calculate_sigmas(self.model_wrap, steps, self.scheduler, self.sampler).to(self.device)
|
||||
self.sigmas = self.calculate_sigmas(steps).to(self.device)
|
||||
else:
|
||||
new_steps = int(steps/denoise)
|
||||
sigmas = calculate_sigmas(self.model_wrap, new_steps, self.scheduler, self.sampler).to(self.device)
|
||||
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
||||
self.sigmas = sigmas[-(steps + 1):]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user