Refactor node cache code to more easily add other types of cache.

This commit is contained in:
comfyanonymous 2025-04-11 07:16:52 -04:00
parent ed945a1790
commit 22ad513c72
2 changed files with 28 additions and 10 deletions

View File

@ -59,14 +59,26 @@ class IsChangedCache:
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
class CacheType(Enum):
CLASSIC = 0
LRU = 1
DEPENDENCY_AWARE = 2
class CacheSet:
def __init__(self, lru_size=None, cache_none=False):
if cache_none:
def __init__(self, cache_type=None, cache_size=None):
if cache_type == CacheType.DEPENDENCY_AWARE:
self.init_dependency_aware_cache()
elif lru_size is None or lru_size == 0:
self.init_classic_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
cache_size = 0
self.init_lru_cache(cache_size)
logging.info("Using LRU cache")
else:
self.init_lru_cache(lru_size)
self.init_classic_cache()
self.all = [self.outputs, self.ui, self.objects]
# Performs like the old cache -- dump data ASAP
@ -420,14 +432,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor:
def __init__(self, server, lru_size=None, cache_none=False):
self.lru_size = lru_size
self.cache_none = cache_none
def __init__(self, server, cache_type=False, cache_size=None):
self.cache_size = cache_size
self.cache_type = cache_type
self.server = server
self.reset()
def reset(self):
self.caches = CacheSet(self.lru_size, self.cache_none)
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
self.status_messages = []
self.success = True

View File

@ -156,7 +156,13 @@ def cuda_malloc_warning():
def prompt_worker(q, server_instance):
current_time: float = 0.0
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru, cache_none=args.cache_none)
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_none:
cache_type = execution.CacheType.DEPENDENCY_AWARE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0