From 07db00355f890c095a1137feed103e360914e7bf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Feb 2023 01:49:17 -0500 Subject: [PATCH] Add masks to samplers code for inpainting. --- comfy/extra_samplers/uni_pc.py | 23 ++++++++++++++----- comfy/samplers.py | 40 +++++++++++++++++++++++++--------- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index ae3a544ef..cfd7225b3 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -358,7 +358,10 @@ class UniPC: predict_x0=True, thresholding=False, max_val=1., - variant='bh1' + variant='bh1', + noise_mask=None, + masked_image=None, + noise=None, ): """Construct a UniPC. @@ -370,7 +373,10 @@ class UniPC: self.predict_x0 = predict_x0 self.thresholding = thresholding self.max_val = max_val - + self.noise_mask = noise_mask + self.masked_image = masked_image + self.noise = noise + def dynamic_thresholding_fn(self, x0, t=None): """ The dynamic thresholding method. @@ -386,7 +392,10 @@ class UniPC: """ Return the noise prediction model. """ - return self.model(x, t) + if self.noise_mask is not None: + return self.model(x, t) * self.noise_mask + else: + return self.model(x, t) def data_prediction_fn(self, x, t): """ @@ -401,6 +410,8 @@ class UniPC: s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) x0 = torch.clamp(x0, -s, s) / s + if self.noise_mask is not None: + x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image return x0 def model_fn(self, x, t): @@ -713,6 +724,8 @@ class UniPC: assert timesteps.shape[0] - 1 == steps # with torch.no_grad(): for step_index in trange(steps): + if self.noise_mask is not None: + x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index])) if step_index == 0: vec_t = timesteps[0].expand((x.shape[0])) model_prev_list = [self.model_fn(x, vec_t)] @@ -820,7 +833,7 @@ def expand_dims(v, dims): -def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None): +def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None): to_zero = False if sigmas[-1] == 0: timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] @@ -857,7 +870,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None model_kwargs=extra_args, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise) x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) if not to_zero: x /= ns.marginal_alpha(timesteps[-1]) diff --git a/comfy/samplers.py b/comfy/samplers.py index 7f6dc972a..b806381ea 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -139,8 +139,17 @@ class CFGDenoiserComplex(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - def forward(self, x, sigma, uncond, cond, cond_scale): - return sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale) + def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask): + if denoise_mask is not None: + latent_mask = 1. - denoise_mask + x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask + out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale) + if denoise_mask is not None: + out *= denoise_mask + + if denoise_mask is not None: + out += self.latent_image * latent_mask + return out def simple_scheduler(model, steps): sigs = [] @@ -200,8 +209,8 @@ class KSampler: sampler = self.SAMPLERS[0] self.scheduler = scheduler self.sampler = sampler - self.sigma_min=float(self.model_wrap.sigmas[0]) - self.sigma_max=float(self.model_wrap.sigmas[-1]) + self.sigma_min=float(self.model_wrap.sigma_min) + self.sigma_max=float(self.model_wrap.sigma_max) self.set_steps(steps, denoise) def _calculate_sigmas(self, steps): @@ -235,7 +244,7 @@ class KSampler: self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False): + def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None): sigmas = self.sigmas sigma_min = self.sigma_min @@ -267,17 +276,28 @@ class KSampler: else: precision_scope = contextlib.nullcontext + latent_mask = None + if denoise_mask is not None: + latent_mask = (torch.ones_like(denoise_mask) - denoise_mask) + + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg} with precision_scope(self.device): if self.sampler == "uni_pc": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask) else: - noise *= sigmas[0] + extra_args["denoise_mask"] = denoise_mask + self.model_k.latent_image = latent_image + self.model_k.noise = noise + + noise = noise * sigmas[0] + if latent_image is not None: noise += latent_image if self.sampler == "sample_dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) + samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args) elif self.sampler == "sample_dpm_adaptive": - samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) + samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args) else: - samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) + samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args=extra_args) + return samples.to(torch.float32)