Add masks to samplers code for inpainting.

This commit is contained in:
comfyanonymous 2023-02-15 01:49:17 -05:00
parent c1d5810020
commit 07db00355f
2 changed files with 48 additions and 15 deletions

View File

@@ -358,7 +358,10 @@ class UniPC:
predict_x0=True,
thresholding=False,
max_val=1.,
variant='bh1'
variant='bh1',
noise_mask=None,
masked_image=None,
noise=None,
):
"""Construct a UniPC.
@@ -370,6 +373,9 @@ class UniPC:
self.predict_x0 = predict_x0
self.thresholding = thresholding
self.max_val = max_val
self.noise_mask = noise_mask
self.masked_image = masked_image
self.noise = noise
def dynamic_thresholding_fn(self, x0, t=None):
"""
@@ -386,7 +392,10 @@ class UniPC:
"""
Return the noise prediction model.
"""
return self.model(x, t)
if self.noise_mask is not None:
return self.model(x, t) * self.noise_mask
else:
return self.model(x, t)
def data_prediction_fn(self, x, t):
"""
@@ -401,6 +410,8 @@ class UniPC:
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
if self.noise_mask is not None:
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
return x0
def model_fn(self, x, t):
@@ -713,6 +724,8 @@ class UniPC:
assert timesteps.shape[0] - 1 == steps
# with torch.no_grad():
for step_index in trange(steps):
if self.noise_mask is not None:
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
if step_index == 0:
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
@@ -820,7 +833,7 @@ def expand_dims(v, dims):
def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None):
def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None):
to_zero = False
if sigmas[-1] == 0:
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
@@ -857,7 +870,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
model_kwargs=extra_args,
)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
if not to_zero:
x /= ns.marginal_alpha(timesteps[-1])

View File

@@ -139,8 +139,17 @@ class CFGDenoiserComplex(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
return sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale)
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask):
if denoise_mask is not None:
latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale)
if denoise_mask is not None:
out *= denoise_mask
if denoise_mask is not None:
out += self.latent_image * latent_mask
return out
def simple_scheduler(model, steps):
sigs = []
@@ -200,8 +209,8 @@ class KSampler:
sampler = self.SAMPLERS[0]
self.scheduler = scheduler
self.sampler = sampler
self.sigma_min=float(self.model_wrap.sigmas[0])
self.sigma_max=float(self.model_wrap.sigmas[-1])
self.sigma_min=float(self.model_wrap.sigma_min)
self.sigma_max=float(self.model_wrap.sigma_max)
self.set_steps(steps, denoise)
def _calculate_sigmas(self, steps):
@@ -235,7 +244,7 @@ class KSampler:
self.sigmas = sigmas[-(steps + 1):]
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False):
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None):
sigmas = self.sigmas
sigma_min = self.sigma_min
@@ -267,17 +276,28 @@ class KSampler:
else:
precision_scope = contextlib.nullcontext
latent_mask = None
if denoise_mask is not None:
latent_mask = (torch.ones_like(denoise_mask) - denoise_mask)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
with precision_scope(self.device):
if self.sampler == "uni_pc":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)
else:
noise *= sigmas[0]
extra_args["denoise_mask"] = denoise_mask
self.model_k.latent_image = latent_image
self.model_k.noise = noise
noise = noise * sigmas[0]
if latent_image is not None:
noise += latent_image
if self.sampler == "sample_dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args)
elif self.sampler == "sample_dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args)
else:
samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args=extra_args)
return samples.to(torch.float32)