diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 62079e6a7..79ecbd682 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -101,6 +101,7 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
 cache_group = parser.add_mutually_exclusive_group()
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
+cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")

 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py
index 630f280fc..dbb37b89f 100644
--- a/comfy_execution/caching.py
+++ b/comfy_execution/caching.py
@@ -316,3 +316,156 @@ class LRUCache(BasicCache):
             self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))

         return self
+
+class DependencyAwareCache(BasicCache):
+    """
+    A cache implementation that tracks dependencies between nodes and manages
+    their execution and caching accordingly. It extends the BasicCache class.
+    Nodes are removed from this cache once all of their descendants have been
+    executed.
+    """
+
+    def __init__(self, key_class):
+        """
+        Initialize the DependencyAwareCache.
+
+        Args:
+            key_class: The class used for generating cache keys.
+        """
+        super().__init__(key_class)
+        self.descendants = {}  # Maps node_id -> set of descendant node_ids
+        self.ancestors = {}  # Maps node_id -> set of ancestor node_ids
+        self.executed_nodes = set()  # Tracks nodes that have been executed
+
+    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+        """
+        Clear the entire cache and rebuild the dependency graph.
+
+        Args:
+            dynprompt: The dynamic prompt object containing node information.
+            node_ids: List of node IDs to initialize the cache for.
+            is_changed_cache: The IsChangedCache object used when computing cache keys.
+        """
+        # Clear all existing cache data
+        self.cache.clear()
+        self.subcaches.clear()
+        self.descendants.clear()
+        self.ancestors.clear()
+        self.executed_nodes.clear()
+
+        # Call the parent method to initialize the cache with the new prompt
+        super().set_prompt(dynprompt, node_ids, is_changed_cache)
+
+        # Rebuild the dependency graph
+        self._build_dependency_graph(dynprompt, node_ids)
+
+    def _build_dependency_graph(self, dynprompt, node_ids):
+        """
+        Build the dependency graph for all nodes.
+
+        Args:
+            dynprompt: The dynamic prompt object containing node information.
+            node_ids: List of node IDs to build the graph for.
+        """
+        self.descendants.clear()
+        self.ancestors.clear()
+        for node_id in node_ids:
+            self.descendants[node_id] = set()
+            self.ancestors[node_id] = set()
+
+        for node_id in node_ids:
+            inputs = dynprompt.get_node(node_id)["inputs"]
+            for input_data in inputs.values():
+                if is_link(input_data):  # Check if the input is a link to another node
+                    ancestor_id = input_data[0]
+                    self.descendants[ancestor_id].add(node_id)
+                    self.ancestors[node_id].add(ancestor_id)
+
+    def set(self, node_id, value):
+        """
+        Mark a node as executed and store its value in the cache.
+
+        Args:
+            node_id: The ID of the node to store.
+            value: The value to store for the node.
+        """
+        self._set_immediate(node_id, value)
+        self.executed_nodes.add(node_id)
+        self._cleanup_ancestors(node_id)
+
+    def get(self, node_id):
+        """
+        Retrieve the cached value for a node.
+
+        Args:
+            node_id: The ID of the node to retrieve.
+
+        Returns:
+            The cached value for the node.
+        """
+        return self._get_immediate(node_id)
+
+    def ensure_subcache_for(self, node_id, children_ids):
+        """
+        Ensure a subcache exists for a node and update dependencies.
+
+        Args:
+            node_id: The ID of the parent node.
+            children_ids: List of child node IDs to associate with the parent node.
+
+        Returns:
+            The subcache object for the node.
+        """
+        subcache = super()._ensure_subcache(node_id, children_ids)
+        for child_id in children_ids:
+            self.descendants[node_id].add(child_id)
+            self.ancestors[child_id].add(node_id)
+        return subcache
+
+    def _cleanup_ancestors(self, node_id):
+        """
+        Check if ancestors of a node can be removed from the cache.
+
+        Args:
+            node_id: The ID of the node whose ancestors are to be checked.
+        """
+        for ancestor_id in self.ancestors.get(node_id, []):
+            if ancestor_id in self.executed_nodes:
+                # Remove ancestor if all its descendants have been executed
+                if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
+                    self._remove_node(ancestor_id)
+
+    def _remove_node(self, node_id):
+        """
+        Remove a node from the cache.
+
+        Args:
+            node_id: The ID of the node to remove.
+        """
+        cache_key = self.cache_key_set.get_data_key(node_id)
+        if cache_key in self.cache:
+            del self.cache[cache_key]
+        subcache_key = self.cache_key_set.get_subcache_key(node_id)
+        if subcache_key in self.subcaches:
+            del self.subcaches[subcache_key]
+
+    def clean_unused(self):
+        """
+        Clean up unused nodes. This is a no-op for this cache implementation.
+        """
+        pass
+
+    def recursive_debug_dump(self):
+        """
+        Dump the cache and dependency graph for debugging.
+
+        Returns:
+            A list containing the cache state and dependency graph.
+        """
+        result = super().recursive_debug_dump()
+        result.append({
+            "descendants": self.descendants,
+            "ancestors": self.ancestors,
+            "executed_nodes": list(self.executed_nodes),
+        })
+        return result
diff --git a/execution.py b/execution.py
index 41686888f..7431c100d 100644
--- a/execution.py
+++ b/execution.py
@@ -15,7 +15,7 @@ import nodes
 import comfy.model_management
 from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
 from comfy_execution.graph_utils import is_link, GraphBuilder
-from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
+from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
 from comfy_execution.validation import validate_node_input

 class ExecutionResult(Enum):
@@ -60,26 +60,32 @@ class IsChangedCache:
         return self.is_changed[node_id]

 class CacheSet:
-    def __init__(self, lru_size=None):
-        if lru_size is None or lru_size == 0:
+    def __init__(self, lru_size=None, cache_none=False):
+        if cache_none:
+            self.init_dependency_aware_cache()
+        elif lru_size is None or lru_size == 0:
             self.init_classic_cache()
         else:
             self.init_lru_cache(lru_size)
         self.all = [self.outputs, self.ui, self.objects]

-    # Useful for those with ample RAM/VRAM -- allows experimenting without
-    # blowing away the cache every time
-    def init_lru_cache(self, cache_size):
-        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.objects = HierarchicalCache(CacheKeySetID)
-
     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
         self.outputs = HierarchicalCache(CacheKeySetInputSignature)
         self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)

+    def init_lru_cache(self, cache_size):
+        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
+    # Only hold cached items while their descendants have not executed
+    def init_dependency_aware_cache(self):
+        self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
+        self.ui = DependencyAwareCache(CacheKeySetInputSignature)
+        self.objects = DependencyAwareCache(CacheKeySetID)
+
     def recursive_debug_dump(self):
         result = {
             "outputs": self.outputs.recursive_debug_dump(),
@@ -414,13 +420,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
     return (ExecutionResult.SUCCESS, None, None)

 class PromptExecutor:
-    def __init__(self, server, lru_size=None):
+    def __init__(self, server, lru_size=None, cache_none=False):
         self.lru_size = lru_size
+        self.cache_none = cache_none
         self.server = server
         self.reset()

     def reset(self):
-        self.caches = CacheSet(self.lru_size)
+        self.caches = CacheSet(self.lru_size, self.cache_none)
         self.status_messages = []
         self.success = True

diff --git a/main.py b/main.py
index 1b100fa8a..e72e7c567 100644
--- a/main.py
+++ b/main.py
@@ -156,7 +156,7 @@ def cuda_malloc_warning():

 def prompt_worker(q, server_instance):
     current_time: float = 0.0
-    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
+    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru, cache_none=args.cache_none)
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0
diff --git a/requirements.txt b/requirements.txt
index 806fbc751..851db23bd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.14.6
+comfyui-frontend-package==1.15.13
 torch
 torchsde
 torchvision
diff --git a/server.py b/server.py
index 95092d595..62667ce18 100644
--- a/server.py
+++ b/server.py
@@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message):

 @web.middleware
 async def cache_control(request: web.Request, handler):
     response: web.Response = await handler(request)
-    if request.path.endswith('.js') or request.path.endswith('.css'):
+    if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
         response.headers.setdefault('Cache-Control', 'no-cache')
     return response
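
With this patch applied, the new behaviour is opted into by launching ComfyUI with `python main.py --cache-none`. The sketch below illustrates the eviction rule `DependencyAwareCache` applies; it is not the patch code itself: the `TinyDependencyAwareCache` class, node IDs, and values are made up, and ComfyUI's subcache and graph-expansion handling is omitted. The rule: a node's cached result is dropped as soon as every node that consumes it has executed.

```python
# Minimal sketch of dependency-aware eviction (illustrative only; not ComfyUI code).
from collections import defaultdict

class TinyDependencyAwareCache:
    """Drop a node's cached value once all of its consumers have executed."""

    def __init__(self):
        self.values = {}                      # node_id -> cached result
        self.descendants = defaultdict(set)   # node_id -> nodes that consume its output
        self.executed = set()

    def add_edge(self, ancestor, descendant):
        self.descendants[ancestor].add(descendant)

    def set(self, node_id, value):
        self.values[node_id] = value
        self.executed.add(node_id)
        # Evict any executed ancestor whose descendants have all executed.
        for ancestor, descendants in self.descendants.items():
            if ancestor in self.executed and descendants <= self.executed:
                self.values.pop(ancestor, None)

# Hypothetical graph: "1" feeds "2" and "3"; "2" also feeds "3".
cache = TinyDependencyAwareCache()
cache.add_edge("1", "2")
cache.add_edge("1", "3")
cache.add_edge("2", "3")

cache.set("1", "latent")
cache.set("2", "upscaled latent")
print(sorted(cache.values))  # ['1', '2'] -- node "3" still needs both inputs
cache.set("3", "decoded image")
print(sorted(cache.values))  # ['3'] -- ancestors evicted once all consumers ran
```

Compared with `--cache-lru N`, which keeps up to N results around for reuse across runs, this keeps peak RAM/VRAM usage closer to what a single pass through the graph needs, at the cost of re-executing every node on the next run.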