Add VPScheduler node

This commit is contained in:
comfyanonymous 2023-10-01 03:48:07 -04:00
parent 8ab49dc0a4
commit 2ef459b1d4

View File

@ -80,6 +80,25 @@ class PolyexponentialScheduler:
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
return (sigmas, ) return (sigmas, )
class VPScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), #TODO: fix default values
"beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
"eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "_for_testing/custom_sampling"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, beta_d, beta_min, eps_s):
sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s)
return (sigmas, )
class SplitSigmas: class SplitSigmas:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -217,6 +236,7 @@ NODE_CLASS_MAPPINGS = {
"KarrasScheduler": KarrasScheduler, "KarrasScheduler": KarrasScheduler,
"ExponentialScheduler": ExponentialScheduler, "ExponentialScheduler": ExponentialScheduler,
"PolyexponentialScheduler": PolyexponentialScheduler, "PolyexponentialScheduler": PolyexponentialScheduler,
"VPScheduler": VPScheduler,
"KSamplerSelect": KSamplerSelect, "KSamplerSelect": KSamplerSelect,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"SamplerDPMPP_SDE": SamplerDPMPP_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE,