mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 05:57:20 +00:00
Add masks to samplers code for inpainting.
This commit is contained in:
parent
c1d5810020
commit
07db00355f
@ -358,7 +358,10 @@ class UniPC:
|
|||||||
predict_x0=True,
|
predict_x0=True,
|
||||||
thresholding=False,
|
thresholding=False,
|
||||||
max_val=1.,
|
max_val=1.,
|
||||||
variant='bh1'
|
variant='bh1',
|
||||||
|
noise_mask=None,
|
||||||
|
masked_image=None,
|
||||||
|
noise=None,
|
||||||
):
|
):
|
||||||
"""Construct a UniPC.
|
"""Construct a UniPC.
|
||||||
|
|
||||||
@ -370,6 +373,9 @@ class UniPC:
|
|||||||
self.predict_x0 = predict_x0
|
self.predict_x0 = predict_x0
|
||||||
self.thresholding = thresholding
|
self.thresholding = thresholding
|
||||||
self.max_val = max_val
|
self.max_val = max_val
|
||||||
|
self.noise_mask = noise_mask
|
||||||
|
self.masked_image = masked_image
|
||||||
|
self.noise = noise
|
||||||
|
|
||||||
def dynamic_thresholding_fn(self, x0, t=None):
|
def dynamic_thresholding_fn(self, x0, t=None):
|
||||||
"""
|
"""
|
||||||
@ -386,6 +392,9 @@ class UniPC:
|
|||||||
"""
|
"""
|
||||||
Return the noise prediction model.
|
Return the noise prediction model.
|
||||||
"""
|
"""
|
||||||
|
if self.noise_mask is not None:
|
||||||
|
return self.model(x, t) * self.noise_mask
|
||||||
|
else:
|
||||||
return self.model(x, t)
|
return self.model(x, t)
|
||||||
|
|
||||||
def data_prediction_fn(self, x, t):
|
def data_prediction_fn(self, x, t):
|
||||||
@ -401,6 +410,8 @@ class UniPC:
|
|||||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||||
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
||||||
x0 = torch.clamp(x0, -s, s) / s
|
x0 = torch.clamp(x0, -s, s) / s
|
||||||
|
if self.noise_mask is not None:
|
||||||
|
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
|
||||||
return x0
|
return x0
|
||||||
|
|
||||||
def model_fn(self, x, t):
|
def model_fn(self, x, t):
|
||||||
@ -713,6 +724,8 @@ class UniPC:
|
|||||||
assert timesteps.shape[0] - 1 == steps
|
assert timesteps.shape[0] - 1 == steps
|
||||||
# with torch.no_grad():
|
# with torch.no_grad():
|
||||||
for step_index in trange(steps):
|
for step_index in trange(steps):
|
||||||
|
if self.noise_mask is not None:
|
||||||
|
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
|
||||||
if step_index == 0:
|
if step_index == 0:
|
||||||
vec_t = timesteps[0].expand((x.shape[0]))
|
vec_t = timesteps[0].expand((x.shape[0]))
|
||||||
model_prev_list = [self.model_fn(x, vec_t)]
|
model_prev_list = [self.model_fn(x, vec_t)]
|
||||||
@ -820,7 +833,7 @@ def expand_dims(v, dims):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None):
|
def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None):
|
||||||
to_zero = False
|
to_zero = False
|
||||||
if sigmas[-1] == 0:
|
if sigmas[-1] == 0:
|
||||||
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
|
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
|
||||||
@ -857,7 +870,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
|
|||||||
model_kwargs=extra_args,
|
model_kwargs=extra_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False)
|
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise)
|
||||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
|
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
|
||||||
if not to_zero:
|
if not to_zero:
|
||||||
x /= ns.marginal_alpha(timesteps[-1])
|
x /= ns.marginal_alpha(timesteps[-1])
|
||||||
|
@ -139,8 +139,17 @@ class CFGDenoiserComplex(torch.nn.Module):
|
|||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask):
|
||||||
return sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale)
|
if denoise_mask is not None:
|
||||||
|
latent_mask = 1. - denoise_mask
|
||||||
|
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
|
||||||
|
out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale)
|
||||||
|
if denoise_mask is not None:
|
||||||
|
out *= denoise_mask
|
||||||
|
|
||||||
|
if denoise_mask is not None:
|
||||||
|
out += self.latent_image * latent_mask
|
||||||
|
return out
|
||||||
|
|
||||||
def simple_scheduler(model, steps):
|
def simple_scheduler(model, steps):
|
||||||
sigs = []
|
sigs = []
|
||||||
@ -200,8 +209,8 @@ class KSampler:
|
|||||||
sampler = self.SAMPLERS[0]
|
sampler = self.SAMPLERS[0]
|
||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
self.sampler = sampler
|
self.sampler = sampler
|
||||||
self.sigma_min=float(self.model_wrap.sigmas[0])
|
self.sigma_min=float(self.model_wrap.sigma_min)
|
||||||
self.sigma_max=float(self.model_wrap.sigmas[-1])
|
self.sigma_max=float(self.model_wrap.sigma_max)
|
||||||
self.set_steps(steps, denoise)
|
self.set_steps(steps, denoise)
|
||||||
|
|
||||||
def _calculate_sigmas(self, steps):
|
def _calculate_sigmas(self, steps):
|
||||||
@ -235,7 +244,7 @@ class KSampler:
|
|||||||
self.sigmas = sigmas[-(steps + 1):]
|
self.sigmas = sigmas[-(steps + 1):]
|
||||||
|
|
||||||
|
|
||||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False):
|
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None):
|
||||||
sigmas = self.sigmas
|
sigmas = self.sigmas
|
||||||
sigma_min = self.sigma_min
|
sigma_min = self.sigma_min
|
||||||
|
|
||||||
@ -267,17 +276,28 @@ class KSampler:
|
|||||||
else:
|
else:
|
||||||
precision_scope = contextlib.nullcontext
|
precision_scope = contextlib.nullcontext
|
||||||
|
|
||||||
|
latent_mask = None
|
||||||
|
if denoise_mask is not None:
|
||||||
|
latent_mask = (torch.ones_like(denoise_mask) - denoise_mask)
|
||||||
|
|
||||||
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
|
||||||
with precision_scope(self.device):
|
with precision_scope(self.device):
|
||||||
if self.sampler == "uni_pc":
|
if self.sampler == "uni_pc":
|
||||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
|
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)
|
||||||
else:
|
else:
|
||||||
noise *= sigmas[0]
|
extra_args["denoise_mask"] = denoise_mask
|
||||||
|
self.model_k.latent_image = latent_image
|
||||||
|
self.model_k.noise = noise
|
||||||
|
|
||||||
|
noise = noise * sigmas[0]
|
||||||
|
|
||||||
if latent_image is not None:
|
if latent_image is not None:
|
||||||
noise += latent_image
|
noise += latent_image
|
||||||
if self.sampler == "sample_dpm_fast":
|
if self.sampler == "sample_dpm_fast":
|
||||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
|
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args)
|
||||||
elif self.sampler == "sample_dpm_adaptive":
|
elif self.sampler == "sample_dpm_adaptive":
|
||||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
|
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args)
|
||||||
else:
|
else:
|
||||||
samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
|
samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args=extra_args)
|
||||||
|
|
||||||
return samples.to(torch.float32)
|
return samples.to(torch.float32)
|
||||||
|
Loading…
Reference in New Issue
Block a user