mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 14:09:36 +00:00
Fix uni_pc sampler math. This changes the images this sampler produces.
This commit is contained in:
parent
f1062be622
commit
4185324a1d
@ -713,8 +713,8 @@ class UniPC:
|
|||||||
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
||||||
atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
|
atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
|
||||||
):
|
):
|
||||||
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
||||||
t_T = self.noise_schedule.T if t_start is None else t_start
|
# t_T = self.noise_schedule.T if t_start is None else t_start
|
||||||
device = x.device
|
device = x.device
|
||||||
steps = len(timesteps) - 1
|
steps = len(timesteps) - 1
|
||||||
if method == 'multistep':
|
if method == 'multistep':
|
||||||
@ -769,8 +769,8 @@ class UniPC:
|
|||||||
callback(step_index, model_prev_list[-1], x, steps)
|
callback(step_index, model_prev_list[-1], x, steps)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
if denoise_to_zero:
|
# if denoise_to_zero:
|
||||||
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
# x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -833,21 +833,33 @@ def expand_dims(v, dims):
|
|||||||
return v[(...,) + (None,)*(dims - 1)]
|
return v[(...,) + (None,)*(dims - 1)]
|
||||||
|
|
||||||
|
|
||||||
|
class SigmaConvert:
|
||||||
|
schedule = ""
|
||||||
|
def marginal_log_mean_coeff(self, sigma):
|
||||||
|
return 0.5 * torch.log(1 / ((sigma * sigma) + 1))
|
||||||
|
|
||||||
|
def marginal_alpha(self, t):
|
||||||
|
return torch.exp(self.marginal_log_mean_coeff(t))
|
||||||
|
|
||||||
|
def marginal_std(self, t):
|
||||||
|
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
|
||||||
|
|
||||||
|
def marginal_lambda(self, t):
|
||||||
|
"""
|
||||||
|
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
||||||
|
"""
|
||||||
|
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
||||||
|
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
||||||
|
return log_mean_coeff - log_std
|
||||||
|
|
||||||
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
||||||
to_zero = False
|
timesteps = sigmas.clone()
|
||||||
if sigmas[-1] == 0:
|
if sigmas[-1] == 0:
|
||||||
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
|
timesteps = sigmas[:]
|
||||||
to_zero = True
|
timesteps[-1] = 0.001
|
||||||
else:
|
else:
|
||||||
timesteps = sigmas.clone()
|
timesteps = sigmas.clone()
|
||||||
|
ns = SigmaConvert()
|
||||||
alphas_cumprod = model.inner_model.alphas_cumprod
|
|
||||||
|
|
||||||
for s in range(timesteps.shape[0]):
|
|
||||||
timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
|
|
||||||
|
|
||||||
ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
|
||||||
|
|
||||||
if image is not None:
|
if image is not None:
|
||||||
img = image * ns.marginal_alpha(timesteps[0])
|
img = image * ns.marginal_alpha(timesteps[0])
|
||||||
@ -859,16 +871,10 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
|||||||
else:
|
else:
|
||||||
img = noise
|
img = noise
|
||||||
|
|
||||||
if to_zero:
|
|
||||||
timesteps[-1] = (1 / len(alphas_cumprod))
|
|
||||||
|
|
||||||
device = noise.device
|
|
||||||
|
|
||||||
|
|
||||||
model_type = "noise"
|
model_type = "noise"
|
||||||
|
|
||||||
model_fn = model_wrapper(
|
model_fn = model_wrapper(
|
||||||
model.predict_eps_discrete_timestep,
|
model.predict_eps_sigma,
|
||||||
ns,
|
ns,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
guidance_type="uncond",
|
guidance_type="uncond",
|
||||||
@ -878,6 +884,5 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
|||||||
order = min(3, len(timesteps) - 1)
|
order = min(3, len(timesteps) - 1)
|
||||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
||||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
||||||
if not to_zero:
|
|
||||||
x /= ns.marginal_alpha(timesteps[-1])
|
x /= ns.marginal_alpha(timesteps[-1])
|
||||||
return x
|
return x
|
||||||
|
@ -97,6 +97,10 @@ class DiscreteSchedule(nn.Module):
|
|||||||
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
|
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
|
||||||
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
|
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
|
||||||
|
|
||||||
|
def predict_eps_sigma(self, input, sigma, **kwargs):
|
||||||
|
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
|
||||||
|
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
|
||||||
|
|
||||||
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
||||||
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
||||||
noise)."""
|
noise)."""
|
||||||
|
@ -739,7 +739,7 @@ class KSampler:
|
|||||||
sigmas = None
|
sigmas = None
|
||||||
|
|
||||||
discard_penultimate_sigma = False
|
discard_penultimate_sigma = False
|
||||||
if self.sampler in ['dpm_2', 'dpm_2_ancestral']:
|
if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']:
|
||||||
steps += 1
|
steps += 1
|
||||||
discard_penultimate_sigma = True
|
discard_penultimate_sigma = True
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user