Rename kl_optimal_schedule to kl_optimal_scheduler to be more consistent

This commit is contained in:
blepping 2024-12-24 10:50:08 -07:00
parent 2992402051
commit 7be71a8142

View File

@ -468,7 +468,7 @@ def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, line
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
# Referenced from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
def kl_optimal_schedule(n: int, sigma_min: float, sigma_max: float) -> torch.Tensor:
def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Tensor:
adj_idxs = torch.arange(n, dtype=torch.float).div_(n - 1)
sigmas = adj_idxs.new_zeros(n + 1)
sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_()
@ -939,7 +939,7 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
elif scheduler_name == "linear_quadratic":
sigmas = linear_quadratic_schedule(model_sampling, steps)
elif scheduler_name == "kl_optimal":
sigmas = kl_optimal_schedule(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
sigmas = kl_optimal_scheduler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
else:
logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas