Add node to extend sigmas (#7901)

* Add ExpandSigmas node

* Rename, add interpolation functions

Co-authored-by: liesen <liesen.dev@gmail.com>

* Move computed interpolation outside loop

* Add type hints

---------

Co-authored-by: liesen <liesen.dev@gmail.com>
This commit is contained in:
catboxanon 2025-05-02 05:28:05 -04:00 committed by GitHub
parent ff99861650
commit 551fe8dcee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,3 +1,4 @@
import math
import comfy.samplers
import comfy.sample
from comfy.k_diffusion import sampling as k_diffusion_sampling
@ -249,6 +250,55 @@ class SetFirstSigma:
sigmas[0] = sigma
return (sigmas, )
class ExtendIntermediateSigmas:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"sigmas": ("SIGMAS", ),
"steps": ("INT", {"default": 2, "min": 1, "max": 100}),
"start_at_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20000.0, "step": 0.01, "round": False}),
"end_at_sigma": ("FLOAT", {"default": 12.0, "min": 0.0, "max": 20000.0, "step": 0.01, "round": False}),
"spacing": (['linear', 'cosine', 'sine'],),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "extend"
def extend(self, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str):
if start_at_sigma < 0:
start_at_sigma = float("inf")
interpolator = {
'linear': lambda x: x,
'cosine': lambda x: torch.sin(x*math.pi/2),
'sine': lambda x: 1 - torch.cos(x*math.pi/2)
}[spacing]
# linear space for our interpolation function
x = torch.linspace(0, 1, steps + 1, device=sigmas.device)[1:-1]
computed_spacing = interpolator(x)
extended_sigmas = []
for i in range(len(sigmas) - 1):
sigma_current = sigmas[i]
sigma_next = sigmas[i+1]
extended_sigmas.append(sigma_current)
if end_at_sigma <= sigma_current <= start_at_sigma:
interpolated_steps = computed_spacing * (sigma_next - sigma_current) + sigma_current
extended_sigmas.extend(interpolated_steps.tolist())
# Add the last sigma value
if len(sigmas) > 0:
extended_sigmas.append(sigmas[-1])
extended_sigmas = torch.FloatTensor(extended_sigmas)
return (extended_sigmas,)
class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
@ -735,6 +785,7 @@ NODE_CLASS_MAPPINGS = {
"SplitSigmasDenoise": SplitSigmasDenoise,
"FlipSigmas": FlipSigmas,
"SetFirstSigma": SetFirstSigma,
"ExtendIntermediateSigmas": ExtendIntermediateSigmas,
"CFGGuider": CFGGuider,
"DualCFGGuider": DualCFGGuider,