diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 659c5c62..cc3153bf 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -833,7 +833,7 @@ def expand_dims(v, dims): -def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None, variant='bh1'): +def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=None, noise_mask=None, variant='bh1'): to_zero = False if sigmas[-1] == 0: timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] @@ -847,7 +847,12 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None ns = NoiseScheduleVP('discrete', alphas_cumprod=model.inner_model.alphas_cumprod) if image is not None: - img = image * ns.marginal_alpha(timesteps[0]) + noise * ns.marginal_std(timesteps[0]) + img = image * ns.marginal_alpha(timesteps[0]) + if max_denoise: + noise_mult = 1.0 + else: + noise_mult = ns.marginal_std(timesteps[0]) + img += noise * noise_mult else: img = noise diff --git a/comfy/samplers.py b/comfy/samplers.py index dd6385b2..2dc5a531 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -334,6 +334,7 @@ class KSampler: self.sigma_min=float(self.model_wrap.sigma_min) self.sigma_max=float(self.model_wrap.sigma_max) self.set_steps(steps, denoise) + self.denoise = denoise def _calculate_sigmas(self, steps): sigmas = None @@ -417,11 +418,16 @@ class KSampler: cond_concat.append(blank_inpaint_image_like(noise)) extra_args["cond_concat"] = cond_concat + if sigmas[0] != self.sigmas[0] or (self.denoise is not None and self.denoise < 1.0): + max_denoise = False + else: + max_denoise = True + with precision_scope(self.device): if self.sampler == "uni_pc": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask) + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask) elif self.sampler == "uni_pc_bh2": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2') + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2') else: extra_args["denoise_mask"] = denoise_mask self.model_k.latent_image = latent_image