mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 14:09:36 +00:00
Cleanup uni_pc inpainting.
This causes some small changes to the uni pc inpainting behavior but it seems to improve results slightly.
This commit is contained in:
parent
877a8f7a3c
commit
10847dfafe
@ -358,9 +358,6 @@ class UniPC:
|
|||||||
thresholding=False,
|
thresholding=False,
|
||||||
max_val=1.,
|
max_val=1.,
|
||||||
variant='bh1',
|
variant='bh1',
|
||||||
noise_mask=None,
|
|
||||||
masked_image=None,
|
|
||||||
noise=None,
|
|
||||||
):
|
):
|
||||||
"""Construct a UniPC.
|
"""Construct a UniPC.
|
||||||
|
|
||||||
@ -372,9 +369,6 @@ class UniPC:
|
|||||||
self.predict_x0 = predict_x0
|
self.predict_x0 = predict_x0
|
||||||
self.thresholding = thresholding
|
self.thresholding = thresholding
|
||||||
self.max_val = max_val
|
self.max_val = max_val
|
||||||
self.noise_mask = noise_mask
|
|
||||||
self.masked_image = masked_image
|
|
||||||
self.noise = noise
|
|
||||||
|
|
||||||
def dynamic_thresholding_fn(self, x0, t=None):
|
def dynamic_thresholding_fn(self, x0, t=None):
|
||||||
"""
|
"""
|
||||||
@ -391,10 +385,7 @@ class UniPC:
|
|||||||
"""
|
"""
|
||||||
Return the noise prediction model.
|
Return the noise prediction model.
|
||||||
"""
|
"""
|
||||||
if self.noise_mask is not None:
|
return self.model(x, t)
|
||||||
return self.model(x, t) * self.noise_mask
|
|
||||||
else:
|
|
||||||
return self.model(x, t)
|
|
||||||
|
|
||||||
def data_prediction_fn(self, x, t):
|
def data_prediction_fn(self, x, t):
|
||||||
"""
|
"""
|
||||||
@ -409,8 +400,6 @@ class UniPC:
|
|||||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||||
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
||||||
x0 = torch.clamp(x0, -s, s) / s
|
x0 = torch.clamp(x0, -s, s) / s
|
||||||
if self.noise_mask is not None:
|
|
||||||
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
|
|
||||||
return x0
|
return x0
|
||||||
|
|
||||||
def model_fn(self, x, t):
|
def model_fn(self, x, t):
|
||||||
@ -723,8 +712,6 @@ class UniPC:
|
|||||||
assert timesteps.shape[0] - 1 == steps
|
assert timesteps.shape[0] - 1 == steps
|
||||||
# with torch.no_grad():
|
# with torch.no_grad():
|
||||||
for step_index in trange(steps, disable=disable_pbar):
|
for step_index in trange(steps, disable=disable_pbar):
|
||||||
if self.noise_mask is not None:
|
|
||||||
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
|
|
||||||
if step_index == 0:
|
if step_index == 0:
|
||||||
vec_t = timesteps[0].expand((x.shape[0]))
|
vec_t = timesteps[0].expand((x.shape[0]))
|
||||||
model_prev_list = [self.model_fn(x, vec_t)]
|
model_prev_list = [self.model_fn(x, vec_t)]
|
||||||
@ -766,7 +753,7 @@ class UniPC:
|
|||||||
model_x = self.model_fn(x, vec_t)
|
model_x = self.model_fn(x, vec_t)
|
||||||
model_prev_list[-1] = model_x
|
model_prev_list[-1] = model_x
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback(step_index, model_prev_list[-1], x, steps)
|
callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]})
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
# if denoise_to_zero:
|
# if denoise_to_zero:
|
||||||
@ -858,7 +845,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs):
|
|||||||
return (input - model(input, sigma_in, **kwargs)) / sigma
|
return (input - model(input, sigma_in, **kwargs)) / sigma
|
||||||
|
|
||||||
|
|
||||||
def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
|
||||||
timesteps = sigmas.clone()
|
timesteps = sigmas.clone()
|
||||||
if sigmas[-1] == 0:
|
if sigmas[-1] == 0:
|
||||||
timesteps = sigmas[:]
|
timesteps = sigmas[:]
|
||||||
@ -867,16 +854,7 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
|
|||||||
timesteps = sigmas.clone()
|
timesteps = sigmas.clone()
|
||||||
ns = SigmaConvert()
|
ns = SigmaConvert()
|
||||||
|
|
||||||
if image is not None:
|
noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
|
||||||
img = image * ns.marginal_alpha(timesteps[0])
|
|
||||||
if max_denoise:
|
|
||||||
noise_mult = 1.0
|
|
||||||
else:
|
|
||||||
noise_mult = ns.marginal_std(timesteps[0])
|
|
||||||
img += noise * noise_mult
|
|
||||||
else:
|
|
||||||
img = noise
|
|
||||||
|
|
||||||
model_type = "noise"
|
model_type = "noise"
|
||||||
|
|
||||||
model_fn = model_wrapper(
|
model_fn = model_wrapper(
|
||||||
@ -888,7 +866,10 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
|
|||||||
)
|
)
|
||||||
|
|
||||||
order = min(3, len(timesteps) - 2)
|
order = min(3, len(timesteps) - 2)
|
||||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
|
||||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
||||||
x /= ns.marginal_alpha(timesteps[-1])
|
x /= ns.marginal_alpha(timesteps[-1])
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
|
||||||
|
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
|
@ -513,14 +513,6 @@ class Sampler:
|
|||||||
sigma = float(sigmas[0])
|
sigma = float(sigmas[0])
|
||||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||||
|
|
||||||
class UNIPC(Sampler):
|
|
||||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
|
||||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
|
||||||
|
|
||||||
class UNIPCBH2(Sampler):
|
|
||||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
|
||||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
|
||||||
|
|
||||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
||||||
@ -640,9 +632,9 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
|||||||
|
|
||||||
def sampler_object(name):
|
def sampler_object(name):
|
||||||
if name == "uni_pc":
|
if name == "uni_pc":
|
||||||
sampler = UNIPC()
|
sampler = KSAMPLER(uni_pc.sample_unipc)
|
||||||
elif name == "uni_pc_bh2":
|
elif name == "uni_pc_bh2":
|
||||||
sampler = UNIPCBH2()
|
sampler = KSAMPLER(uni_pc.sample_unipc_bh2)
|
||||||
elif name == "ddim":
|
elif name == "ddim":
|
||||||
sampler = ksampler("euler", inpaint_options={"random": True})
|
sampler = ksampler("euler", inpaint_options={"random": True})
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user