mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add uni_pc bh2 variant.
This commit is contained in:
parent
928087184e
commit
a7328e4945
@ -833,7 +833,7 @@ def expand_dims(v, dims):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None):
|
def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None, variant='bh1'):
|
||||||
to_zero = False
|
to_zero = False
|
||||||
if sigmas[-1] == 0:
|
if sigmas[-1] == 0:
|
||||||
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
|
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
|
||||||
@ -870,7 +870,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
|
|||||||
model_kwargs=extra_args,
|
model_kwargs=extra_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise)
|
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
||||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
|
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
|
||||||
if not to_zero:
|
if not to_zero:
|
||||||
x /= ns.marginal_alpha(timesteps[-1])
|
x /= ns.marginal_alpha(timesteps[-1])
|
||||||
|
@ -313,7 +313,7 @@ class KSampler:
|
|||||||
SCHEDULERS = ["karras", "normal", "simple"]
|
SCHEDULERS = ["karras", "normal", "simple"]
|
||||||
SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
|
SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
|
||||||
"sample_lms", "sample_dpm_fast", "sample_dpm_adaptive", "sample_dpmpp_2s_ancestral", "sample_dpmpp_sde",
|
"sample_lms", "sample_dpm_fast", "sample_dpm_adaptive", "sample_dpmpp_2s_ancestral", "sample_dpmpp_sde",
|
||||||
"sample_dpmpp_2m", "uni_pc"]
|
"sample_dpmpp_2m", "uni_pc", "uni_pc_bh2"]
|
||||||
|
|
||||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
|
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -420,6 +420,8 @@ class KSampler:
|
|||||||
with precision_scope(self.device):
|
with precision_scope(self.device):
|
||||||
if self.sampler == "uni_pc":
|
if self.sampler == "uni_pc":
|
||||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)
|
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)
|
||||||
|
elif self.sampler == "uni_pc_bh2":
|
||||||
|
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2')
|
||||||
else:
|
else:
|
||||||
extra_args["denoise_mask"] = denoise_mask
|
extra_args["denoise_mask"] = denoise_mask
|
||||||
self.model_k.latent_image = latent_image
|
self.model_k.latent_image = latent_image
|
||||||
|
Loading…
Reference in New Issue
Block a user