diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index bcc7c0f20..659c5c62f 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -833,7 +833,7 @@ def expand_dims(v, dims): -def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None): +def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None, variant='bh1'): to_zero = False if sigmas[-1] == 0: timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] @@ -870,7 +870,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None model_kwargs=extra_args, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) if not to_zero: x /= ns.marginal_alpha(timesteps[-1]) diff --git a/comfy/samplers.py b/comfy/samplers.py index a5a318111..dd6385b27 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -313,7 +313,7 @@ class KSampler: SCHEDULERS = ["karras", "normal", "simple"] SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral", "sample_lms", "sample_dpm_fast", "sample_dpm_adaptive", "sample_dpmpp_2s_ancestral", "sample_dpmpp_sde", - "sample_dpmpp_2m", "uni_pc"] + "sample_dpmpp_2m", "uni_pc", "uni_pc_bh2"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None): self.model = model @@ -420,6 +420,8 @@ class KSampler: with precision_scope(self.device): if self.sampler == "uni_pc": samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask) + elif self.sampler == "uni_pc_bh2": + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2') else: extra_args["denoise_mask"] = denoise_mask self.model_k.latent_image = latent_image