From 22ad513c72b891322f7baf6b459aa41858087b3b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 11 Apr 2025 07:16:52 -0400 Subject: [PATCH] Refactor node cache code to more easily add other types of cache. --- execution.py | 30 +++++++++++++++++++++--------- main.py | 8 +++++++- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/execution.py b/execution.py index 7431c100d..9a5e27771 100644 --- a/execution.py +++ b/execution.py @@ -59,14 +59,26 @@ class IsChangedCache: self.is_changed[node_id] = node["is_changed"] return self.is_changed[node_id] + +class CacheType(Enum): + CLASSIC = 0 + LRU = 1 + DEPENDENCY_AWARE = 2 + + class CacheSet: - def __init__(self, lru_size=None, cache_none=False): - if cache_none: + def __init__(self, cache_type=None, cache_size=None): + if cache_type == CacheType.DEPENDENCY_AWARE: self.init_dependency_aware_cache() - elif lru_size is None or lru_size == 0: - self.init_classic_cache() + logging.info("Disabling intermediate node cache.") + elif cache_type == CacheType.LRU: + if cache_size is None: + cache_size = 0 + self.init_lru_cache(cache_size) + logging.info("Using LRU cache") else: - self.init_lru_cache(lru_size) + self.init_classic_cache() + self.all = [self.outputs, self.ui, self.objects] # Performs like the old cache -- dump data ASAP @@ -420,14 +432,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp return (ExecutionResult.SUCCESS, None, None) class PromptExecutor: - def __init__(self, server, lru_size=None, cache_none=False): - self.lru_size = lru_size - self.cache_none = cache_none + def __init__(self, server, cache_type=None, cache_size=None): + self.cache_size = cache_size + self.cache_type = cache_type self.server = server self.reset() def reset(self): - self.caches = CacheSet(self.lru_size, self.cache_none) + self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size) self.status_messages = [] self.success = True diff --git a/main.py b/main.py index 
e72e7c567..4780a9c69 100644 --- a/main.py +++ b/main.py @@ -156,7 +156,13 @@ def cuda_malloc_warning(): def prompt_worker(q, server_instance): current_time: float = 0.0 - e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru, cache_none=args.cache_none) + cache_type = execution.CacheType.CLASSIC + if args.cache_lru > 0: + cache_type = execution.CacheType.LRU + elif args.cache_none: + cache_type = execution.CacheType.DEPENDENCY_AWARE + + e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0