mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Implement shift schedule for cascade stage C.
This commit is contained in:
parent
929e266f3e
commit
5b40e7a5ed
@ -136,9 +136,16 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
|
||||
class StableCascadeSampling(ModelSamplingDiscrete):
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__()
|
||||
|
||||
if model_config is not None:
|
||||
sampling_settings = model_config.sampling_settings
|
||||
else:
|
||||
sampling_settings = {}
|
||||
|
||||
self.num_timesteps = 1000
|
||||
self.shift = sampling_settings.get("shift", 1.0)
|
||||
cosine_s=8e-3
|
||||
self.cosine_s = torch.tensor([cosine_s])
|
||||
self.cosine_s = torch.tensor(cosine_s)
|
||||
sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
|
||||
self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
|
||||
for x in range(self.num_timesteps):
|
||||
@ -148,11 +155,23 @@ class StableCascadeSampling(ModelSamplingDiscrete):
|
||||
self.set_sigmas(sigmas)
|
||||
|
||||
def sigma(self, timestep):
|
||||
alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod).clamp(0.0001, 0.9999)
|
||||
alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod)
|
||||
|
||||
if self.shift != 1.0:
|
||||
var = alpha_cumprod
|
||||
logSNR = (var/(1-var)).log()
|
||||
logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift))
|
||||
alpha_cumprod = logSNR.sigmoid()
|
||||
|
||||
alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999)
|
||||
return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
|
||||
|
||||
def timestep(self, sigma):
|
||||
return super().timestep(sigma) / 1000.0
|
||||
var = 1 / ((sigma * sigma) + 1)
|
||||
var = var.clamp(0, 1.0)
|
||||
s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
|
||||
t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
|
||||
return t
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
|
@ -316,6 +316,10 @@ class Stable_Cascade_C(supported_models_base.BASE):
|
||||
latent_format = latent_formats.SC_Prior
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 2.0,
|
||||
}
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
key_list = list(state_dict.keys())
|
||||
for y in ["weight", "bias"]:
|
||||
@ -348,6 +352,10 @@ class Stable_Cascade_B(Stable_Cascade_C):
|
||||
latent_format = latent_formats.SC_B
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.StableCascade_B(self, device=device)
|
||||
return out
|
||||
|
Loading…
Reference in New Issue
Block a user