mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Clean up and refactor sampler code.
This should make it much easier to write custom nodes with kdiffusion type samplers.
This commit is contained in:
parent
94cc718e9c
commit
420beeeb05
@ -522,42 +522,59 @@ KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral"
|
|||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
||||||
|
|
||||||
|
class KSAMPLER(Sampler):
|
||||||
|
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||||
|
self.sampler_function = sampler_function
|
||||||
|
self.extra_options = extra_options
|
||||||
|
self.inpaint_options = inpaint_options
|
||||||
|
|
||||||
|
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||||
|
extra_args["denoise_mask"] = denoise_mask
|
||||||
|
model_k = KSamplerX0Inpaint(model_wrap)
|
||||||
|
model_k.latent_image = latent_image
|
||||||
|
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
|
||||||
|
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
|
||||||
|
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
|
||||||
|
else:
|
||||||
|
model_k.noise = noise
|
||||||
|
|
||||||
|
if self.max_denoise(model_wrap, sigmas):
|
||||||
|
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||||
|
else:
|
||||||
|
noise = noise * sigmas[0]
|
||||||
|
|
||||||
|
k_callback = None
|
||||||
|
total_steps = len(sigmas) - 1
|
||||||
|
if callback is not None:
|
||||||
|
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||||
|
|
||||||
|
if latent_image is not None:
|
||||||
|
noise += latent_image
|
||||||
|
|
||||||
|
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||||
class KSAMPLER(Sampler):
|
if sampler_name == "dpm_fast":
|
||||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
|
||||||
extra_args["denoise_mask"] = denoise_mask
|
|
||||||
model_k = KSamplerX0Inpaint(model_wrap)
|
|
||||||
model_k.latent_image = latent_image
|
|
||||||
if inpaint_options.get("random", False): #TODO: Should this be the default?
|
|
||||||
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
|
|
||||||
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
|
|
||||||
else:
|
|
||||||
model_k.noise = noise
|
|
||||||
|
|
||||||
if self.max_denoise(model_wrap, sigmas):
|
|
||||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
|
||||||
else:
|
|
||||||
noise = noise * sigmas[0]
|
|
||||||
|
|
||||||
k_callback = None
|
|
||||||
total_steps = len(sigmas) - 1
|
|
||||||
if callback is not None:
|
|
||||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
|
||||||
|
|
||||||
sigma_min = sigmas[-1]
|
sigma_min = sigmas[-1]
|
||||||
if sigma_min == 0:
|
if sigma_min == 0:
|
||||||
sigma_min = sigmas[-2]
|
sigma_min = sigmas[-2]
|
||||||
|
total_steps = len(sigmas) - 1
|
||||||
|
return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
|
||||||
|
sampler_function = dpm_fast_function
|
||||||
|
elif sampler_name == "dpm_adaptive":
|
||||||
|
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
|
||||||
|
sigma_min = sigmas[-1]
|
||||||
|
if sigma_min == 0:
|
||||||
|
sigma_min = sigmas[-2]
|
||||||
|
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
|
||||||
|
sampler_function = dpm_adaptive_function
|
||||||
|
else:
|
||||||
|
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
|
||||||
|
|
||||||
if latent_image is not None:
|
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
||||||
noise += latent_image
|
|
||||||
if sampler_name == "dpm_fast":
|
|
||||||
samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
|
||||||
elif sampler_name == "dpm_adaptive":
|
|
||||||
samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
|
||||||
else:
|
|
||||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options)
|
|
||||||
return samples
|
|
||||||
return KSAMPLER
|
|
||||||
|
|
||||||
def wrap_model(model):
|
def wrap_model(model):
|
||||||
model_denoise = CFGNoisePredictor(model)
|
model_denoise = CFGNoisePredictor(model)
|
||||||
@ -618,11 +635,11 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
|||||||
print("error invalid scheduler", self.scheduler)
|
print("error invalid scheduler", self.scheduler)
|
||||||
return sigmas
|
return sigmas
|
||||||
|
|
||||||
def sampler_class(name):
|
def sampler_object(name):
|
||||||
if name == "uni_pc":
|
if name == "uni_pc":
|
||||||
sampler = UNIPC
|
sampler = UNIPC()
|
||||||
elif name == "uni_pc_bh2":
|
elif name == "uni_pc_bh2":
|
||||||
sampler = UNIPCBH2
|
sampler = UNIPCBH2()
|
||||||
elif name == "ddim":
|
elif name == "ddim":
|
||||||
sampler = ksampler("euler", inpaint_options={"random": True})
|
sampler = ksampler("euler", inpaint_options={"random": True})
|
||||||
else:
|
else:
|
||||||
@ -687,6 +704,6 @@ class KSampler:
|
|||||||
else:
|
else:
|
||||||
return torch.zeros_like(noise)
|
return torch.zeros_like(noise)
|
||||||
|
|
||||||
sampler = sampler_class(self.sampler)
|
sampler = sampler_object(self.sampler)
|
||||||
|
|
||||||
return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||||
|
@ -149,7 +149,7 @@ class KSamplerSelect:
|
|||||||
FUNCTION = "get_sampler"
|
FUNCTION = "get_sampler"
|
||||||
|
|
||||||
def get_sampler(self, sampler_name):
|
def get_sampler(self, sampler_name):
|
||||||
sampler = comfy.samplers.sampler_class(sampler_name)()
|
sampler = comfy.samplers.sampler_object(sampler_name)
|
||||||
return (sampler, )
|
return (sampler, )
|
||||||
|
|
||||||
class SamplerDPMPP_2M_SDE:
|
class SamplerDPMPP_2M_SDE:
|
||||||
@ -172,7 +172,7 @@ class SamplerDPMPP_2M_SDE:
|
|||||||
sampler_name = "dpmpp_2m_sde"
|
sampler_name = "dpmpp_2m_sde"
|
||||||
else:
|
else:
|
||||||
sampler_name = "dpmpp_2m_sde_gpu"
|
sampler_name = "dpmpp_2m_sde_gpu"
|
||||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})()
|
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
|
||||||
return (sampler, )
|
return (sampler, )
|
||||||
|
|
||||||
|
|
||||||
@ -196,7 +196,7 @@ class SamplerDPMPP_SDE:
|
|||||||
sampler_name = "dpmpp_sde"
|
sampler_name = "dpmpp_sde"
|
||||||
else:
|
else:
|
||||||
sampler_name = "dpmpp_sde_gpu"
|
sampler_name = "dpmpp_sde_gpu"
|
||||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})()
|
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
|
||||||
return (sampler, )
|
return (sampler, )
|
||||||
|
|
||||||
class SamplerCustom:
|
class SamplerCustom:
|
||||||
|
Loading…
Reference in New Issue
Block a user