Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-17 17:43:30 +00:00)

Compare commits: 630545d273 ... 3014c328a9 (18 commits)
Commits (SHA1):
3014c328a9
22ad513c72
ed945a1790
f9207c6936
8ad7477647
7cf6b5101c
b184c091e2
8d9dc98fba
eb40f9377b
8a8327fa73
6b68b61644
812dc34f46
c176ad8f50
70ff03429c
ce0bff9a4b
587c93ebff
1782896502
92a630e737
comfy/cli_args.py

@@ -101,6 +101,7 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
 cache_group = parser.add_mutually_exclusive_group()
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
+cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
 
 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
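Because all three flags sit in one mutually exclusive group, argparse rejects any combination at startup. A behavior sketch (standard argparse output, not from this diff):

python main.py --cache-lru 10 --cache-none
# error: argument --cache-none: not allowed with argument --cache-lru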
comfy/k_diffusion/sa_solver.py (new file, 181 lines)
@@ -0,0 +1,181 @@
# Modify from: https://github.com/scxue/SA-Solver
# MIT license

import torch


def get_coefficients_exponential_positive(order, interval_start, interval_end, tau):
    """
    Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end
    For calculating the coefficient of gradient terms after the lagrange interpolation,
    see Eq.(15) and Eq.(18) in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
    For data_prediction formula.
    """
    assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"

    # after change of variable(cov)
    interval_end_cov = (1 + tau ** 2) * interval_end
    interval_start_cov = (1 + tau ** 2) * interval_start

    if order == 0:
        return (torch.exp(interval_end_cov)
                * (1 - torch.exp(-(interval_end_cov - interval_start_cov)))
                / ((1 + tau ** 2))
                )
    elif order == 1:
        return (torch.exp(interval_end_cov)
                * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp(-(interval_end_cov - interval_start_cov)))
                / ((1 + tau ** 2) ** 2)
                )
    elif order == 2:
        return (torch.exp(interval_end_cov)
                * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2)
                   - (interval_start_cov ** 2 - 2 * interval_start_cov + 2)
                   * torch.exp(-(interval_end_cov - interval_start_cov))
                   )
                / ((1 + tau ** 2) ** 3)
                )
    elif order == 3:
        return (torch.exp(interval_end_cov)
                * ((interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6)
                   - (interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6)
                   * torch.exp(-(interval_end_cov - interval_start_cov))
                   )
                / ((1 + tau ** 2) ** 4)
                )


def lagrange_polynomial_coefficient(order, lambda_list):
    """
    Calculate the coefficient of lagrange polynomial
    For lagrange interpolation
    """
    assert order in [0, 1, 2, 3]
    assert order == len(lambda_list) - 1
    if order == 0:
        return [[1.0]]
    elif order == 1:
        return [[1.0 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])],
                [1.0 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]]
    elif order == 2:
        denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2])
        denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2])
        denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1])
        return [[1.0 / denominator1, (-lambda_list[1] - lambda_list[2]) / denominator1, lambda_list[1] * lambda_list[2] / denominator1],
                [1.0 / denominator2, (-lambda_list[0] - lambda_list[2]) / denominator2, lambda_list[0] * lambda_list[2] / denominator2],
                [1.0 / denominator3, (-lambda_list[0] - lambda_list[1]) / denominator3, lambda_list[0] * lambda_list[1] / denominator3]
                ]
    elif order == 3:
        denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * (lambda_list[0] - lambda_list[3])
        denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * (lambda_list[1] - lambda_list[3])
        denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * (lambda_list[2] - lambda_list[3])
        denominator4 = (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * (lambda_list[3] - lambda_list[2])
        return [[1.0 / denominator1,
                 (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1,
                 (lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[3]) / denominator1,
                 (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1],

                [1.0 / denominator2,
                 (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2,
                 (lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[3]) / denominator2,
                 (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2],

                [1.0 / denominator3,
                 (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3,
                 (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[3]) / denominator3,
                 (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3],

                [1.0 / denominator4,
                 (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4,
                 (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[2]) / denominator4,
                 (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4]
                ]


def get_coefficients_fn(order, interval_start, interval_end, lambda_list, tau):
    """
    Calculate the coefficient of gradients.
    """
    assert order in [1, 2, 3, 4]
    assert order == len(lambda_list), 'the length of lambda list must be equal to the order'
    lagrange_coefficient = lagrange_polynomial_coefficient(order - 1, lambda_list)
    coefficients = [sum(lagrange_coefficient[i][j] * get_coefficients_exponential_positive(order - 1 - j, interval_start, interval_end, tau)
                        for j in range(order))
                    for i in range(order)]
    assert len(coefficients) == order, 'the length of coefficients does not match the order'
    return coefficients


def adams_bashforth_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
    """
    SA-Predictor, with the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
    """

    assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4"
    t_fn = lambda sigma: sigma.log().neg()
    sigma_prev = sigma_prev_list[-1]
    gradient_part = torch.zeros_like(x)
    lambda_list = [t_fn(sigma_prev_list[-(i + 1)]) for i in range(order)]
    lambda_t = t_fn(sigma)
    lambda_prev = lambda_list[0]
    h = lambda_t - lambda_prev
    gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau)

    if order == 2:  ## if order = 2 we do a modification that does not influence the convergence order similar to unipc. Note: This is used only for few steps sampling.
        # The added term is O(h^3). Empirically we find it will slightly improve the image quality.
        # ODE case
        # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
        # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
        gradient_coefficients[0] += (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
                                     * (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2))
                                     / (lambda_prev - lambda_list[1])
                                     )
        gradient_coefficients[1] -= (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
                                     * (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2))
                                     / (lambda_prev - lambda_list[1])
                                     )

    for i in range(order):
        gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
    gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
    noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
    return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part


def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
    """
    SA-Corrector, with the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
    """

    assert order in [1, 2, 3, 4], "order of stochastic adams moulton method is only supported for 1, 2, 3 and 4"
    t_fn = lambda sigma: sigma.log().neg()
    sigma_prev = sigma_prev_list[-1]
    gradient_part = torch.zeros_like(x)
    sigma_list = sigma_prev_list + [sigma]
    lambda_list = [t_fn(sigma_list[-(i + 1)]) for i in range(order)]
    lambda_t = lambda_list[0]
    lambda_prev = lambda_list[1] if order >= 2 else t_fn(sigma_prev)
    h = lambda_t - lambda_prev
    gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau)

    if order == 2:  ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling.
        # The added term is O(h^3). Empirically we find it will slightly improve the image quality.
        # ODE case
        # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
        # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
        gradient_coefficients[0] += (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
                                     * (h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h)))
                                        / ((1 + tau ** 2) ** 2 * h))
                                     )
        gradient_coefficients[1] -= (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
                                     * (h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h)))
                                        / ((1 + tau ** 2) ** 2 * h))
                                     )

    for i in range(order):
        gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
    gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
    noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
    return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part


# Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract
def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):
    if eta == 0:
        # Pure ODE
        return 0
    return eta if eta_end_sigma <= sigma <= eta_start_sigma else 0
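The closed forms above are easy to sanity-check numerically. A minimal sketch (not part of the commit; it assumes a ComfyUI checkout where comfy.k_diffusion.sa_solver is importable) compares each order against trapezoidal quadrature:

# Illustrative check: the returned coefficient should equal the integral
# of exp(x * (1 + tau^2)) * x^order over [a, b].
import torch
from comfy.k_diffusion import sa_solver

tau = 0.5
a = torch.tensor(-1.0, dtype=torch.float64)
b = torch.tensor(0.5, dtype=torch.float64)

for order in range(4):
    closed = sa_solver.get_coefficients_exponential_positive(order, a, b, tau)
    x = torch.linspace(a.item(), b.item(), 200001, dtype=torch.float64)
    numeric = torch.trapz(torch.exp(x * (1 + tau ** 2)) * x ** order, x)
    assert torch.allclose(closed, numeric, atol=1e-8), (order, closed, numeric)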
comfy/k_diffusion/sampling.py

@@ -1,4 +1,5 @@
 import math
+from functools import partial
 
 from scipy import integrate
 import torch
@@ -8,6 +9,7 @@ from tqdm.auto import trange, tqdm
 
 from . import utils
 from . import deis
+from . import sa_solver
 import comfy.model_patcher
 import comfy.model_sampling
@@ -1140,6 +1142,91 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
     return x_next
 
+
+# Modify from: https://github.com/scxue/SA-Solver
+# MIT license
+@torch.no_grad()
+def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, pc_mode="PEC", tau_func=None, noise_sampler=None):
+    if len(sigmas) <= 1:
+        return x
+
+    extra_args = {} if extra_args is None else extra_args
+    if tau_func is None:
+        model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+        start_sigma = model_sampling.percent_to_sigma(0.2)
+        end_sigma = model_sampling.percent_to_sigma(0.8)
+        tau_func = partial(sa_solver.default_tau_func, eta=1.0, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
+    tau = tau_func
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+
+    sigma_prev_list = []
+    model_prev_list = []
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        sigma = sigmas[i]
+        if i == 0:
+            # Init the initial values
+            denoised = model(x, sigma * s_in, **extra_args)
+            model_prev_list.append(denoised)
+            sigma_prev_list.append(sigma)
+        else:
+            # Lower order final
+            predictor_order_used = min(predictor_order, i, len(sigmas) - i - 1)
+            corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1)
+
+            tau_val = tau(sigma)
+            noise = None if tau_val == 0 else noise_sampler(sigma, sigmas[i + 1])
+
+            # Predictor step
+            x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val,
+                                                             model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
+                                                             noise=noise, sigma=sigma)
+
+            # Evaluation step
+            denoised = model(x_p, sigma * s_in, **extra_args)
+            model_prev_list.append(denoised)
+
+            # Corrector step
+            if corrector_order_used > 0:
+                x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val,
+                                                             model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
+                                                             noise=noise, sigma=sigma)
+            else:
+                x = x_p
+
+            del noise, x_p
+
+            # Evaluation step for PECE
+            if corrector_order_used > 0 and pc_mode == 'PECE':
+                del model_prev_list[-1]
+                denoised = model(x, sigma * s_in, **extra_args)
+                model_prev_list.append(denoised)
+
+            sigma_prev_list.append(sigma)
+            if len(model_prev_list) > max(predictor_order, corrector_order):
+                del model_prev_list[0]
+                del sigma_prev_list[0]
+
+        if callback is not None:
+            callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
+
+    if sigmas[-1] == 0:
+        # Denoising step
+        return model_prev_list[-1]
+    return sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0,
+                                                      model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
+                                                      noise=0, sigma=sigmas[-1])
+
+
+@torch.no_grad()
+def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
+    if len(sigmas) <= 1:
+        return x
+    return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
+                            predictor_order=predictor_order, corrector_order=corrector_order,
+                            pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler,
+                            )
+
+
 @torch.no_grad()
 def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
     extra_args = {} if extra_args is None else extra_args
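When no tau_func is supplied, the sampler builds one from the model's sigma schedule: stochasticity (eta) is active only between the sigmas at 20% and 80% of the schedule, and the solve is a pure ODE elsewhere. A toy illustration of that band (illustrative values, not from the commit):

# default_tau_func returns eta inside [eta_end_sigma, eta_start_sigma], else 0
from comfy.k_diffusion.sa_solver import default_tau_func

for sigma in [14.0, 8.0, 1.0, 0.1]:
    print(sigma, default_tau_func(sigma, eta=1.0, eta_start_sigma=10.0, eta_end_sigma=0.5))
# 14.0 -> 0 (before the band), 8.0 -> 1.0, 1.0 -> 1.0, 0.1 -> 0 (after the band)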
comfy/samplers.py

@@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                   "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
-                  "gradient_estimation", "er_sde"]
+                  "gradient_estimation", "er_sde", "sa_solver", "sa_solver_pece"]
 
 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
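Adding the two names here is the registration step: comfy.samplers.ksampler resolves a registered name to the matching sample_* function in k_diffusion.sampling, which is why defining sample_sa_solver and sample_sa_solver_pece above is sufficient. A rough sketch of that lookup (an assumption about the upstream helper, not code from this diff):

# assumed resolution path inside comfy.samplers.ksampler (paraphrase):
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format("sa_solver"))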
comfy_execution/caching.py

@@ -316,3 +316,156 @@ class LRUCache(BasicCache):
         self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self
+
+
+class DependencyAwareCache(BasicCache):
+    """
+    A cache implementation that tracks dependencies between nodes and manages
+    their execution and caching accordingly. It extends the BasicCache class.
+    Nodes are removed from this cache once all of their descendants have been
+    executed.
+    """
+
+    def __init__(self, key_class):
+        """
+        Initialize the DependencyAwareCache.
+
+        Args:
+            key_class: The class used for generating cache keys.
+        """
+        super().__init__(key_class)
+        self.descendants = {}  # Maps node_id -> set of descendant node_ids
+        self.ancestors = {}  # Maps node_id -> set of ancestor node_ids
+        self.executed_nodes = set()  # Tracks nodes that have been executed
+
+    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+        """
+        Clear the entire cache and rebuild the dependency graph.
+
+        Args:
+            dynprompt: The dynamic prompt object containing node information.
+            node_ids: List of node IDs to initialize the cache for.
+            is_changed_cache: Flag indicating if the cache has changed.
+        """
+        # Clear all existing cache data
+        self.cache.clear()
+        self.subcaches.clear()
+        self.descendants.clear()
+        self.ancestors.clear()
+        self.executed_nodes.clear()
+
+        # Call the parent method to initialize the cache with the new prompt
+        super().set_prompt(dynprompt, node_ids, is_changed_cache)
+
+        # Rebuild the dependency graph
+        self._build_dependency_graph(dynprompt, node_ids)
+
+    def _build_dependency_graph(self, dynprompt, node_ids):
+        """
+        Build the dependency graph for all nodes.
+
+        Args:
+            dynprompt: The dynamic prompt object containing node information.
+            node_ids: List of node IDs to build the graph for.
+        """
+        self.descendants.clear()
+        self.ancestors.clear()
+        for node_id in node_ids:
+            self.descendants[node_id] = set()
+            self.ancestors[node_id] = set()
+
+        for node_id in node_ids:
+            inputs = dynprompt.get_node(node_id)["inputs"]
+            for input_data in inputs.values():
+                if is_link(input_data):  # Check if the input is a link to another node
+                    ancestor_id = input_data[0]
+                    self.descendants[ancestor_id].add(node_id)
+                    self.ancestors[node_id].add(ancestor_id)
+
+    def set(self, node_id, value):
+        """
+        Mark a node as executed and store its value in the cache.
+
+        Args:
+            node_id: The ID of the node to store.
+            value: The value to store for the node.
+        """
+        self._set_immediate(node_id, value)
+        self.executed_nodes.add(node_id)
+        self._cleanup_ancestors(node_id)
+
+    def get(self, node_id):
+        """
+        Retrieve the cached value for a node.
+
+        Args:
+            node_id: The ID of the node to retrieve.
+
+        Returns:
+            The cached value for the node.
+        """
+        return self._get_immediate(node_id)
+
+    def ensure_subcache_for(self, node_id, children_ids):
+        """
+        Ensure a subcache exists for a node and update dependencies.
+
+        Args:
+            node_id: The ID of the parent node.
+            children_ids: List of child node IDs to associate with the parent node.
+
+        Returns:
+            The subcache object for the node.
+        """
+        subcache = super()._ensure_subcache(node_id, children_ids)
+        for child_id in children_ids:
+            self.descendants[node_id].add(child_id)
+            self.ancestors[child_id].add(node_id)
+        return subcache
+
+    def _cleanup_ancestors(self, node_id):
+        """
+        Check if ancestors of a node can be removed from the cache.
+
+        Args:
+            node_id: The ID of the node whose ancestors are to be checked.
+        """
+        for ancestor_id in self.ancestors.get(node_id, []):
+            if ancestor_id in self.executed_nodes:
+                # Remove ancestor if all its descendants have been executed
+                if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
+                    self._remove_node(ancestor_id)
+
+    def _remove_node(self, node_id):
+        """
+        Remove a node from the cache.
+
+        Args:
+            node_id: The ID of the node to remove.
+        """
+        cache_key = self.cache_key_set.get_data_key(node_id)
+        if cache_key in self.cache:
+            del self.cache[cache_key]
+        subcache_key = self.cache_key_set.get_subcache_key(node_id)
+        if subcache_key in self.subcaches:
+            del self.subcaches[subcache_key]
+
+    def clean_unused(self):
+        """
+        Clean up unused nodes. This is a no-op for this cache implementation.
+        """
+        pass
+
+    def recursive_debug_dump(self):
+        """
+        Dump the cache and dependency graph for debugging.
+
+        Returns:
+            A list containing the cache state and dependency graph.
+        """
+        result = super().recursive_debug_dump()
+        result.append({
+            "descendants": self.descendants,
+            "ancestors": self.ancestors,
+            "executed_nodes": list(self.executed_nodes),
+        })
+        return result
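The eviction rule is easiest to see on a toy graph. The self-contained sketch below (illustrative only, not the ComfyUI classes) mimics the same bookkeeping: a node's cached output is dropped as soon as every node that consumes it has executed.

# Toy model of the eviction rule (illustrative; not the ComfyUI classes).
# Graph: a -> b, a -> c, b -> d. "a" is evictable once b and c have run.
descendants = {"a": {"b", "c"}, "b": {"d"}, "c": set(), "d": set()}
cache, executed = {}, set()

def finish(node, value):
    cache[node] = value
    executed.add(node)
    # Drop any cached node whose consumers have all executed;
    # leaf nodes (no descendants) are kept, like workflow outputs.
    for n, deps in descendants.items():
        if n in cache and deps and deps <= executed:
            del cache[n]

for node in ["a", "b", "c", "d"]:
    finish(node, f"output-{node}")
    print(node, sorted(cache))  # a ['a'] / b ['a', 'b'] / c ['b', 'c'] / d ['c', 'd']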
comfy_extras/nodes_custom_sampler.py

@@ -1,6 +1,8 @@
+from functools import partial
 import comfy.samplers
 import comfy.sample
 from comfy.k_diffusion import sampling as k_diffusion_sampling
+from comfy.k_diffusion import sa_solver
 import latent_preview
 import torch
 import comfy.utils
@@ -430,6 +432,35 @@ class SamplerDPMAdaptative:
                                                      "s_noise":s_noise })
         return (sampler, )
 
+
+class SamplerSASolver:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {"model": ("MODEL",),
+                     "pc_mode": (['PEC', "PECE"],),
+                     "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}),
+                     "eta_start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
+                     "eta_end_percent": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001}),
+                     }
+                }
+
+    RETURN_TYPES = ("SAMPLER",)
+    CATEGORY = "sampling/custom_sampling/samplers"
+
+    FUNCTION = "get_sampler"
+
+    def get_sampler(self, model, pc_mode, eta, eta_start_percent, eta_end_percent):
+        model_sampling = model.get_model_object('model_sampling')
+        start_sigma = model_sampling.percent_to_sigma(eta_start_percent)
+        end_sigma = model_sampling.percent_to_sigma(eta_end_percent)
+        tau_func = partial(sa_solver.default_tau_func, eta=eta, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
+
+        if pc_mode == 'PEC':
+            sampler_name = "sa_solver"
+        else:
+            sampler_name = "sa_solver_pece"
+        sampler = comfy.samplers.ksampler(sampler_name, {"tau_func": tau_func})
+        return (sampler, )
+
 class Noise_EmptyNoise:
     def __init__(self):
         self.seed = 0
@@ -731,6 +762,7 @@ NODE_CLASS_MAPPINGS = {
     "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
     "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
     "SamplerDPMAdaptative": SamplerDPMAdaptative,
+    "SamplerSASolver": SamplerSASolver,
     "SplitSigmas": SplitSigmas,
     "SplitSigmasDenoise": SplitSigmasDenoise,
     "FlipSigmas": FlipSigmas,
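For script use, the same SAMPLER object the node produces can be built directly. A sketch mirroring get_sampler above (`model` is assumed to be an already-loaded ModelPatcher, e.g. from a checkpoint loader):

from functools import partial
import comfy.samplers
from comfy.k_diffusion import sa_solver

ms = model.get_model_object('model_sampling')
tau_func = partial(sa_solver.default_tau_func, eta=1.0,
                   eta_start_sigma=ms.percent_to_sigma(0.2),
                   eta_end_sigma=ms.percent_to_sigma(0.8))
sampler = comfy.samplers.ksampler("sa_solver", {"tau_func": tau_func})
# `sampler` is a SAMPLER, usable wherever the node's output would be wired in.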
execution.py (53 lines changed)
@@ -15,7 +15,7 @@ import nodes
 import comfy.model_management
 from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
 from comfy_execution.graph_utils import is_link, GraphBuilder
-from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
+from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
 from comfy_execution.validation import validate_node_input
 
 class ExecutionResult(Enum):
@@ -59,20 +59,27 @@ class IsChangedCache:
         self.is_changed[node_id] = node["is_changed"]
         return self.is_changed[node_id]
 
-class CacheSet:
-    def __init__(self, lru_size=None):
-        if lru_size is None or lru_size == 0:
-            self.init_classic_cache()
-        else:
-            self.init_lru_cache(lru_size)
-        self.all = [self.outputs, self.ui, self.objects]
-
-    # Useful for those with ample RAM/VRAM -- allows experimenting without
-    # blowing away the cache every time
-    def init_lru_cache(self, cache_size):
-        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.objects = HierarchicalCache(CacheKeySetID)
+class CacheType(Enum):
+    CLASSIC = 0
+    LRU = 1
+    DEPENDENCY_AWARE = 2
+
+
+class CacheSet:
+    def __init__(self, cache_type=None, cache_size=None):
+        if cache_type == CacheType.DEPENDENCY_AWARE:
+            self.init_dependency_aware_cache()
+            logging.info("Disabling intermediate node cache.")
+        elif cache_type == CacheType.LRU:
+            if cache_size is None:
+                cache_size = 0
+            self.init_lru_cache(cache_size)
+            logging.info("Using LRU cache")
+        else:
+            self.init_classic_cache()
+
+        self.all = [self.outputs, self.ui, self.objects]
 
     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
@@ -80,6 +87,17 @@ class CacheSet:
         self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)
 
+    def init_lru_cache(self, cache_size):
+        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
+    # only hold cached items while the descendants have not executed
+    def init_dependency_aware_cache(self):
+        self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
+        self.ui = DependencyAwareCache(CacheKeySetInputSignature)
+        self.objects = DependencyAwareCache(CacheKeySetID)
+
     def recursive_debug_dump(self):
         result = {
             "outputs": self.outputs.recursive_debug_dump(),
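A minimal sketch of selecting the three modes (assumes running inside a ComfyUI checkout where execution imports cleanly):

from execution import CacheSet, CacheType

caches = CacheSet(cache_type=CacheType.LRU, cache_size=16)  # keep up to 16 node results
caches = CacheSet(cache_type=CacheType.DEPENDENCY_AWARE)    # free outputs once consumers ran
caches = CacheSet()                                         # default: classic aggressive caching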
@@ -414,13 +432,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
         return (ExecutionResult.SUCCESS, None, None)
 
 class PromptExecutor:
-    def __init__(self, server, lru_size=None):
-        self.lru_size = lru_size
+    def __init__(self, server, cache_type=False, cache_size=None):
+        self.cache_size = cache_size
+        self.cache_type = cache_type
         self.server = server
         self.reset()
 
     def reset(self):
-        self.caches = CacheSet(self.lru_size)
+        self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
         self.status_messages = []
         self.success = True
main.py (8 lines changed)
@@ -156,7 +156,13 @@ def cuda_malloc_warning():
 
 def prompt_worker(q, server_instance):
     current_time: float = 0.0
-    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
+    cache_type = execution.CacheType.CLASSIC
+    if args.cache_lru > 0:
+        cache_type = execution.CacheType.LRU
+    elif args.cache_none:
+        cache_type = execution.CacheType.DEPENDENCY_AWARE
+
+    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0
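Combined with the --cache-none flag added in cli_args above, cache selection from the command line works like this (usage sketch):

python main.py                  # classic aggressive caching (default)
python main.py --cache-lru 20   # LRU cache holding up to 20 node results
python main.py --cache-none     # dependency-aware cache, lowest RAM/VRAM use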
requirements.txt

@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.14.6
+comfyui-frontend-package==1.15.13
 torch
 torchsde
 torchvision
server.py

@@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message):
 @web.middleware
 async def cache_control(request: web.Request, handler):
     response: web.Response = await handler(request)
-    if request.path.endswith('.js') or request.path.endswith('.css'):
+    if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
         response.headers.setdefault('Cache-Control', 'no-cache')
     return response