Refactor LCM to support more model types.

This commit is contained in:
comfyanonymous 2023-12-15 15:26:12 -05:00
parent 9cad2f06ff
commit 014c8bf2f2

View File

@ -17,41 +17,19 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input
class ModelSamplingDiscreteDistilled(torch.nn.Module):
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50
def __init__(self):
super().__init__()
self.sigma_data = 1.0
timesteps = 1000
beta_start = 0.00085
beta_end = 0.012
def __init__(self, model_config=None):
super().__init__(model_config)
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
self.skip_steps = self.num_timesteps // self.original_timesteps
self.skip_steps = timesteps // self.original_timesteps
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
for x in range(self.original_timesteps):
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
self.set_sigmas(sigmas_valid)
def timestep(self, sigma):
log_sigma = sigma.log()
@ -66,14 +44,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
@ -154,7 +124,7 @@ class ModelSamplingContinuousEDM:
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_sigma_range(sigma_min, sigma_max)
m.add_object_patch("model_sampling", model_sampling)
return (m, )