Compare commits

...

18 Commits

Author SHA1 Message Date
chaObserv
3014c328a9
Merge 7cf6b5101c into 22ad513c72 2025-04-11 09:49:48 -04:00
comfyanonymous
22ad513c72 Refactor node cache code to more easily add other types of cache. 2025-04-11 07:16:52 -04:00
Chargeuk
ed945a1790
Dependency Aware Node Caching for low RAM/VRAM machines (#7509)
* add dependency aware cache that removed a cached node as soon as all of its decendents have executed. This allows users with lower RAM to run workflows they would otherwise not be able to run. The downside is that every workflow will fully run each time even if no nodes have changed.

* remove test code

* tidy code
2025-04-11 06:55:51 -04:00
Chenlei Hu
f9207c6936
Update frontend to 1.15 (#7564) 2025-04-11 06:46:20 -04:00
Christian Byrne
8ad7477647
dont cache templates index (#7569) 2025-04-11 06:06:53 -04:00
chaObserv
7cf6b5101c
Merge branch 'master' into sa_solver 2025-03-16 10:28:26 +08:00
chaObserv
b184c091e2
Merge branch 'master' into sa_solver 2025-01-19 18:31:38 +08:00
chaObserv
8d9dc98fba Remove more space 2025-01-06 19:35:31 +08:00
chaObserv
eb40f9377b Remove space 2025-01-06 19:31:19 +08:00
chaObserv
8a8327fa73 change to direct return 2025-01-06 18:58:27 +08:00
chaObserv
6b68b61644 Use default_noise_sampler instead 2025-01-06 18:55:29 +08:00
chaObserv
812dc34f46
Merge branch 'comfyanonymous:master' into sa_solver 2025-01-06 18:09:03 +08:00
chaObserv
c176ad8f50 Clean up and remove modifying zero sigma 2024-10-30 01:16:34 +08:00
chaObserv
70ff03429c
Merge branch 'comfyanonymous:master' into sa_solver 2024-10-30 00:41:46 +08:00
chaObserv
ce0bff9a4b
Merge branch 'master' into sa_solver 2024-10-13 20:44:03 +08:00
chaObserv
587c93ebff Start lower_final_step in predictor earlier to stabilize 2024-09-22 18:27:13 +08:00
chaObserv
1782896502 Fix lambda_prev in corrector 2024-09-22 17:08:50 +08:00
chaObserv
92a630e737 Add SA-Solver 2024-09-22 17:08:50 +08:00
10 changed files with 500 additions and 21 deletions

View File

@ -101,6 +101,7 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
cache_group = parser.add_mutually_exclusive_group() cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
attn_group = parser.add_mutually_exclusive_group() attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")

View File

@ -0,0 +1,181 @@
# Modify from: https://github.com/scxue/SA-Solver
# MIT license
import torch
def get_coefficients_exponential_positive(order, interval_start, interval_end, tau):
"""
Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end
For calculating the coefficient of gradient terms after the lagrange interpolation,
see Eq.(15) and Eq.(18) in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
For data_prediction formula.
"""
assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"
# after change of variable(cov)
interval_end_cov = (1 + tau ** 2) * interval_end
interval_start_cov = (1 + tau ** 2) * interval_start
if order == 0:
return (torch.exp(interval_end_cov)
* (1 - torch.exp(-(interval_end_cov - interval_start_cov)))
/ ((1 + tau ** 2))
)
elif order == 1:
return (torch.exp(interval_end_cov)
* ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp(-(interval_end_cov - interval_start_cov)))
/ ((1 + tau ** 2) ** 2)
)
elif order == 2:
return (torch.exp(interval_end_cov)
* ((interval_end_cov ** 2 - 2 * interval_end_cov + 2)
- (interval_start_cov ** 2 - 2 * interval_start_cov + 2)
* torch.exp(-(interval_end_cov - interval_start_cov))
)
/ ((1 + tau ** 2) ** 3)
)
elif order == 3:
return (torch.exp(interval_end_cov)
* ((interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6)
- (interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6)
* torch.exp(-(interval_end_cov - interval_start_cov))
)
/ ((1 + tau ** 2) ** 4)
)
def lagrange_polynomial_coefficient(order, lambda_list):
"""
Calculate the coefficient of lagrange polynomial
For lagrange interpolation
"""
assert order in [0, 1, 2, 3]
assert order == len(lambda_list) - 1
if order == 0:
return [[1.0]]
elif order == 1:
return [[1.0 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])],
[1.0 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]]
elif order == 2:
denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2])
denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2])
denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1])
return [[1.0 / denominator1, (-lambda_list[1] - lambda_list[2]) / denominator1, lambda_list[1] * lambda_list[2] / denominator1],
[1.0 / denominator2, (-lambda_list[0] - lambda_list[2]) / denominator2, lambda_list[0] * lambda_list[2] / denominator2],
[1.0 / denominator3, (-lambda_list[0] - lambda_list[1]) / denominator3, lambda_list[0] * lambda_list[1] / denominator3]
]
elif order == 3:
denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * (lambda_list[0] - lambda_list[3])
denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * (lambda_list[1] - lambda_list[3])
denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * (lambda_list[2] - lambda_list[3])
denominator4 = (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * (lambda_list[3] - lambda_list[2])
return [[1.0 / denominator1,
(-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1,
(lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[3]) / denominator1,
(-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1],
[1.0 / denominator2,
(-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2,
(lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[3]) / denominator2,
(-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2],
[1.0 / denominator3,
(-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3,
(lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[3]) / denominator3,
(-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3],
[1.0 / denominator4,
(-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4,
(lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[2]) / denominator4,
(-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4]
]
def get_coefficients_fn(order, interval_start, interval_end, lambda_list, tau):
"""
Calculate the coefficient of gradients.
"""
assert order in [1, 2, 3, 4]
assert order == len(lambda_list), 'the length of lambda list must be equal to the order'
lagrange_coefficient = lagrange_polynomial_coefficient(order - 1, lambda_list)
coefficients = [sum(lagrange_coefficient[i][j] * get_coefficients_exponential_positive(order - 1 - j, interval_start, interval_end, tau)
for j in range(order))
for i in range(order)]
assert len(coefficients) == order, 'the length of coefficients does not match the order'
return coefficients
def adams_bashforth_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
"""
SA-Predictor, with the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
"""
assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4"
t_fn = lambda sigma: sigma.log().neg()
sigma_prev = sigma_prev_list[-1]
gradient_part = torch.zeros_like(x)
lambda_list = [t_fn(sigma_prev_list[-(i + 1)]) for i in range(order)]
lambda_t = t_fn(sigma)
lambda_prev = lambda_list[0]
h = lambda_t - lambda_prev
gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau)
if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to unipc. Note: This is used only for few steps sampling.
# The added term is O(h^3). Empirically we find it will slightly improve the image quality.
# ODE case
# gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
# gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
gradient_coefficients[0] += (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
* (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2))
/ (lambda_prev - lambda_list[1])
)
gradient_coefficients[1] -= (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
* (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2))
/ (lambda_prev - lambda_list[1])
)
for i in range(order):
gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
"""
SA-Corrector, with the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
"""
assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4"
t_fn = lambda sigma: sigma.log().neg()
sigma_prev = sigma_prev_list[-1]
gradient_part = torch.zeros_like(x)
sigma_list = sigma_prev_list + [sigma]
lambda_list = [t_fn(sigma_list[-(i + 1)]) for i in range(order)]
lambda_t = lambda_list[0]
lambda_prev = lambda_list[1] if order >= 2 else t_fn(sigma_prev)
h = lambda_t - lambda_prev
gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau)
if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling.
# The added term is O(h^3). Empirically we find it will slightly improve the image quality.
# ODE case
# gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
# gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
gradient_coefficients[0] += (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
* (h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h)))
/ ((1 + tau ** 2) ** 2 * h))
)
gradient_coefficients[1] -= (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
* (h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h)))
/ ((1 + tau ** 2) ** 2 * h))
)
for i in range(order):
gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
# Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract
def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):
if eta == 0:
# Pure ODE
return 0
return eta if eta_end_sigma <= sigma <= eta_start_sigma else 0

View File

@ -1,4 +1,5 @@
import math import math
from functools import partial
from scipy import integrate from scipy import integrate
import torch import torch
@ -8,6 +9,7 @@ from tqdm.auto import trange, tqdm
from . import utils from . import utils
from . import deis from . import deis
from . import sa_solver
import comfy.model_patcher import comfy.model_patcher
import comfy.model_sampling import comfy.model_sampling
@ -1140,6 +1142,91 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
return x_next return x_next
# Modify from: https://github.com/scxue/SA-Solver
# MIT license
@torch.no_grad()
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, pc_mode="PEC", tau_func=None, noise_sampler=None):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
if tau_func is None:
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
start_sigma = model_sampling.percent_to_sigma(0.2)
end_sigma = model_sampling.percent_to_sigma(0.8)
tau_func = partial(sa_solver.default_tau_func, eta=1.0, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
tau = tau_func
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_prev_list = []
model_prev_list = []
for i in trange(len(sigmas) - 1, disable=disable):
sigma = sigmas[i]
if i == 0:
# Init the initial values
denoised = model(x, sigma * s_in, **extra_args)
model_prev_list.append(denoised)
sigma_prev_list.append(sigma)
else:
# Lower order final
predictor_order_used = min(predictor_order, i, len(sigmas) - i - 1)
corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1)
tau_val = tau(sigma)
noise = None if tau_val == 0 else noise_sampler(sigma, sigmas[i + 1])
# Predictor step
x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=noise, sigma=sigma)
# Evaluation step
denoised = model(x_p, sigma * s_in, **extra_args)
model_prev_list.append(denoised)
# Corrector step
if corrector_order_used > 0:
x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=noise, sigma=sigma)
else:
x = x_p
del noise, x_p
# Evaluation step for PECE
if corrector_order_used > 0 and pc_mode == 'PECE':
del model_prev_list[-1]
denoised = model(x, sigma * s_in, **extra_args)
model_prev_list.append(denoised)
sigma_prev_list.append(sigma)
if len(model_prev_list) > max(predictor_order, corrector_order):
del model_prev_list[0]
del sigma_prev_list[0]
if callback is not None:
callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
if sigmas[-1] == 0:
# Denoising step
return model_prev_list[-1]
return sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=0, sigma=sigmas[-1])
@torch.no_grad()
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
if len(sigmas) <= 1:
return x
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
predictor_order=predictor_order, corrector_order=corrector_order,
pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler,
)
@torch.no_grad() @torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args

View File

@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation", "er_sde"] "gradient_estimation", "er_sde", "sa_solver", "sa_solver_pece"]
class KSAMPLER(Sampler): class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}): def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

View File

@ -316,3 +316,156 @@ class LRUCache(BasicCache):
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self return self
class DependencyAwareCache(BasicCache):
"""
A cache implementation that tracks dependencies between nodes and manages
their execution and caching accordingly. It extends the BasicCache class.
Nodes are removed from this cache once all of their descendants have been
executed.
"""
def __init__(self, key_class):
"""
Initialize the DependencyAwareCache.
Args:
key_class: The class used for generating cache keys.
"""
super().__init__(key_class)
self.descendants = {} # Maps node_id -> set of descendant node_ids
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
self.executed_nodes = set() # Tracks nodes that have been executed
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
"""
Clear the entire cache and rebuild the dependency graph.
Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to initialize the cache for.
is_changed_cache: Flag indicating if the cache has changed.
"""
# Clear all existing cache data
self.cache.clear()
self.subcaches.clear()
self.descendants.clear()
self.ancestors.clear()
self.executed_nodes.clear()
# Call the parent method to initialize the cache with the new prompt
super().set_prompt(dynprompt, node_ids, is_changed_cache)
# Rebuild the dependency graph
self._build_dependency_graph(dynprompt, node_ids)
def _build_dependency_graph(self, dynprompt, node_ids):
"""
Build the dependency graph for all nodes.
Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to build the graph for.
"""
self.descendants.clear()
self.ancestors.clear()
for node_id in node_ids:
self.descendants[node_id] = set()
self.ancestors[node_id] = set()
for node_id in node_ids:
inputs = dynprompt.get_node(node_id)["inputs"]
for input_data in inputs.values():
if is_link(input_data): # Check if the input is a link to another node
ancestor_id = input_data[0]
self.descendants[ancestor_id].add(node_id)
self.ancestors[node_id].add(ancestor_id)
def set(self, node_id, value):
"""
Mark a node as executed and store its value in the cache.
Args:
node_id: The ID of the node to store.
value: The value to store for the node.
"""
self._set_immediate(node_id, value)
self.executed_nodes.add(node_id)
self._cleanup_ancestors(node_id)
def get(self, node_id):
"""
Retrieve the cached value for a node.
Args:
node_id: The ID of the node to retrieve.
Returns:
The cached value for the node.
"""
return self._get_immediate(node_id)
def ensure_subcache_for(self, node_id, children_ids):
"""
Ensure a subcache exists for a node and update dependencies.
Args:
node_id: The ID of the parent node.
children_ids: List of child node IDs to associate with the parent node.
Returns:
The subcache object for the node.
"""
subcache = super()._ensure_subcache(node_id, children_ids)
for child_id in children_ids:
self.descendants[node_id].add(child_id)
self.ancestors[child_id].add(node_id)
return subcache
def _cleanup_ancestors(self, node_id):
"""
Check if ancestors of a node can be removed from the cache.
Args:
node_id: The ID of the node whose ancestors are to be checked.
"""
for ancestor_id in self.ancestors.get(node_id, []):
if ancestor_id in self.executed_nodes:
# Remove ancestor if all its descendants have been executed
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
self._remove_node(ancestor_id)
def _remove_node(self, node_id):
"""
Remove a node from the cache.
Args:
node_id: The ID of the node to remove.
"""
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
del self.cache[cache_key]
subcache_key = self.cache_key_set.get_subcache_key(node_id)
if subcache_key in self.subcaches:
del self.subcaches[subcache_key]
def clean_unused(self):
"""
Clean up unused nodes. This is a no-op for this cache implementation.
"""
pass
def recursive_debug_dump(self):
"""
Dump the cache and dependency graph for debugging.
Returns:
A list containing the cache state and dependency graph.
"""
result = super().recursive_debug_dump()
result.append({
"descendants": self.descendants,
"ancestors": self.ancestors,
"executed_nodes": list(self.executed_nodes),
})
return result

View File

@ -1,6 +1,8 @@
from functools import partial
import comfy.samplers import comfy.samplers
import comfy.sample import comfy.sample
from comfy.k_diffusion import sampling as k_diffusion_sampling from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.k_diffusion import sa_solver
import latent_preview import latent_preview
import torch import torch
import comfy.utils import comfy.utils
@ -430,6 +432,35 @@ class SamplerDPMAdaptative:
"s_noise":s_noise }) "s_noise":s_noise })
return (sampler, ) return (sampler, )
class SamplerSASolver:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"pc_mode": (['PEC', "PECE"],),
"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"eta_start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
"eta_end_percent": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001}),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, model, pc_mode, eta, eta_start_percent, eta_end_percent):
model_sampling = model.get_model_object('model_sampling')
start_sigma = model_sampling.percent_to_sigma(eta_start_percent)
end_sigma = model_sampling.percent_to_sigma(eta_end_percent)
tau_func = partial(sa_solver.default_tau_func, eta=eta, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
if pc_mode == 'PEC':
sampler_name = "sa_solver"
else:
sampler_name = "sa_solver_pece"
sampler = comfy.samplers.ksampler(sampler_name, {"tau_func": tau_func})
return (sampler, )
class Noise_EmptyNoise: class Noise_EmptyNoise:
def __init__(self): def __init__(self):
self.seed = 0 self.seed = 0
@ -731,6 +762,7 @@ NODE_CLASS_MAPPINGS = {
"SamplerDPMPP_SDE": SamplerDPMPP_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral, "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
"SamplerDPMAdaptative": SamplerDPMAdaptative, "SamplerDPMAdaptative": SamplerDPMAdaptative,
"SamplerSASolver": SamplerSASolver,
"SplitSigmas": SplitSigmas, "SplitSigmas": SplitSigmas,
"SplitSigmasDenoise": SplitSigmasDenoise, "SplitSigmasDenoise": SplitSigmasDenoise,
"FlipSigmas": FlipSigmas, "FlipSigmas": FlipSigmas,

View File

@ -15,7 +15,7 @@ import nodes
import comfy.model_management import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input from comfy_execution.validation import validate_node_input
class ExecutionResult(Enum): class ExecutionResult(Enum):
@ -59,20 +59,27 @@ class IsChangedCache:
self.is_changed[node_id] = node["is_changed"] self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id] return self.is_changed[node_id]
class CacheSet:
def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0:
self.init_classic_cache()
else:
self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects]
# Useful for those with ample RAM/VRAM -- allows experimenting without class CacheType(Enum):
# blowing away the cache every time CLASSIC = 0
def init_lru_cache(self, cache_size): LRU = 1
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) DEPENDENCY_AWARE = 2
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
if cache_type == CacheType.DEPENDENCY_AWARE:
self.init_dependency_aware_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
cache_size = 0
self.init_lru_cache(cache_size)
logging.info("Using LRU cache")
else:
self.init_classic_cache()
self.all = [self.outputs, self.ui, self.objects]
# Performs like the old cache -- dump data ASAP # Performs like the old cache -- dump data ASAP
def init_classic_cache(self): def init_classic_cache(self):
@ -80,6 +87,17 @@ class CacheSet:
self.ui = HierarchicalCache(CacheKeySetInputSignature) self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# only hold cached items while the decendents have not executed
def init_dependency_aware_cache(self):
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
self.objects = DependencyAwareCache(CacheKeySetID)
def recursive_debug_dump(self): def recursive_debug_dump(self):
result = { result = {
"outputs": self.outputs.recursive_debug_dump(), "outputs": self.outputs.recursive_debug_dump(),
@ -414,13 +432,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
return (ExecutionResult.SUCCESS, None, None) return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor: class PromptExecutor:
def __init__(self, server, lru_size=None): def __init__(self, server, cache_type=False, cache_size=None):
self.lru_size = lru_size self.cache_size = cache_size
self.cache_type = cache_type
self.server = server self.server = server
self.reset() self.reset()
def reset(self): def reset(self):
self.caches = CacheSet(self.lru_size) self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
self.status_messages = [] self.status_messages = []
self.success = True self.success = True

View File

@ -156,7 +156,13 @@ def cuda_malloc_warning():
def prompt_worker(q, server_instance): def prompt_worker(q, server_instance):
current_time: float = 0.0 current_time: float = 0.0
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru) cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_none:
cache_type = execution.CacheType.DEPENDENCY_AWARE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
last_gc_collect = 0 last_gc_collect = 0
need_gc = False need_gc = False
gc_collect_interval = 10.0 gc_collect_interval = 10.0

View File

@ -1,4 +1,4 @@
comfyui-frontend-package==1.14.6 comfyui-frontend-package==1.15.13
torch torch
torchsde torchsde
torchvision torchvision

View File

@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message):
@web.middleware @web.middleware
async def cache_control(request: web.Request, handler): async def cache_control(request: web.Request, handler):
response: web.Response = await handler(request) response: web.Response = await handler(request)
if request.path.endswith('.js') or request.path.endswith('.css'): if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
response.headers.setdefault('Cache-Control', 'no-cache') response.headers.setdefault('Cache-Control', 'no-cache')
return response return response