diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 52930652b..771fc1655 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -102,6 +102,7 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
 cache_group = parser.add_mutually_exclusive_group()
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
+cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
 
 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py
index 1b71208d4..3535966fb 100644
--- a/comfy/comfy_types/node_typing.py
+++ b/comfy/comfy_types/node_typing.py
@@ -102,9 +102,13 @@ class InputTypeOptions(TypedDict):
     default: bool | str | float | int | list | tuple
     """The default value of the widget"""
     defaultInput: bool
-    """Defaults to an input slot rather than a widget"""
+    """@deprecated in the v1.16 frontend, which allows an input socket and a widget to co-exist.
+    - defaultInput on required inputs should be dropped.
+    - defaultInput on optional inputs should be replaced with forceInput.
+    Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
+    """
     forceInput: bool
-    """`defaultInput` and also don't allow converting to a widget"""
+    """Forces the input to be an input slot rather than a widget, even if a widget is available for the input type."""
     lazy: bool
     """Declares that this input uses lazy evaluation"""
     rawLink: bool
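For custom node authors, a minimal sketch of what this deprecation implies for an `INPUT_TYPES` declaration (the node and its input names are hypothetical, for illustration only):

```python
class ExampleNode:  # hypothetical custom node
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # Pre-v1.16 this might have carried {"defaultInput": True};
                # since socket and widget now co-exist, the flag is simply dropped.
                "image": ("IMAGE", {}),
            },
            "optional": {
                # Pre-v1.16: {"defaultInput": True}. Per the deprecation note,
                # optional inputs that must stay sockets use forceInput instead.
                "mask": ("MASK", {"forceInput": True}),
            },
        }
```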
+ """ + # Clear all existing cache data + self.cache.clear() + self.subcaches.clear() + self.descendants.clear() + self.ancestors.clear() + self.executed_nodes.clear() + + # Call the parent method to initialize the cache with the new prompt + super().set_prompt(dynprompt, node_ids, is_changed_cache) + + # Rebuild the dependency graph + self._build_dependency_graph(dynprompt, node_ids) + + def _build_dependency_graph(self, dynprompt, node_ids): + """ + Build the dependency graph for all nodes. + + Args: + dynprompt: The dynamic prompt object containing node information. + node_ids: List of node IDs to build the graph for. + """ + self.descendants.clear() + self.ancestors.clear() + for node_id in node_ids: + self.descendants[node_id] = set() + self.ancestors[node_id] = set() + + for node_id in node_ids: + inputs = dynprompt.get_node(node_id)["inputs"] + for input_data in inputs.values(): + if is_link(input_data): # Check if the input is a link to another node + ancestor_id = input_data[0] + self.descendants[ancestor_id].add(node_id) + self.ancestors[node_id].add(ancestor_id) + + def set(self, node_id, value): + """ + Mark a node as executed and store its value in the cache. + + Args: + node_id: The ID of the node to store. + value: The value to store for the node. + """ + self._set_immediate(node_id, value) + self.executed_nodes.add(node_id) + self._cleanup_ancestors(node_id) + + def get(self, node_id): + """ + Retrieve the cached value for a node. + + Args: + node_id: The ID of the node to retrieve. + + Returns: + The cached value for the node. + """ + return self._get_immediate(node_id) + + def ensure_subcache_for(self, node_id, children_ids): + """ + Ensure a subcache exists for a node and update dependencies. + + Args: + node_id: The ID of the parent node. + children_ids: List of child node IDs to associate with the parent node. + + Returns: + The subcache object for the node. + """ + subcache = super()._ensure_subcache(node_id, children_ids) + for child_id in children_ids: + self.descendants[node_id].add(child_id) + self.ancestors[child_id].add(node_id) + return subcache + + def _cleanup_ancestors(self, node_id): + """ + Check if ancestors of a node can be removed from the cache. + + Args: + node_id: The ID of the node whose ancestors are to be checked. + """ + for ancestor_id in self.ancestors.get(node_id, []): + if ancestor_id in self.executed_nodes: + # Remove ancestor if all its descendants have been executed + if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]): + self._remove_node(ancestor_id) + + def _remove_node(self, node_id): + """ + Remove a node from the cache. + + Args: + node_id: The ID of the node to remove. + """ + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key in self.cache: + del self.cache[cache_key] + subcache_key = self.cache_key_set.get_subcache_key(node_id) + if subcache_key in self.subcaches: + del self.subcaches[subcache_key] + + def clean_unused(self): + """ + Clean up unused nodes. This is a no-op for this cache implementation. + """ + pass + + def recursive_debug_dump(self): + """ + Dump the cache and dependency graph for debugging. + + Returns: + A list containing the cache state and dependency graph. 
+ """ + result = super().recursive_debug_dump() + result.append({ + "descendants": self.descendants, + "ancestors": self.ancestors, + "executed_nodes": list(self.executed_nodes), + }) + return result diff --git a/execution.py b/execution.py index 41686888f..9a5e27771 100644 --- a/execution.py +++ b/execution.py @@ -15,7 +15,7 @@ import nodes import comfy.model_management from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy_execution.graph_utils import is_link, GraphBuilder -from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID +from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID from comfy_execution.validation import validate_node_input class ExecutionResult(Enum): @@ -59,20 +59,27 @@ class IsChangedCache: self.is_changed[node_id] = node["is_changed"] return self.is_changed[node_id] -class CacheSet: - def __init__(self, lru_size=None): - if lru_size is None or lru_size == 0: - self.init_classic_cache() - else: - self.init_lru_cache(lru_size) - self.all = [self.outputs, self.ui, self.objects] - # Useful for those with ample RAM/VRAM -- allows experimenting without - # blowing away the cache every time - def init_lru_cache(self, cache_size): - self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.objects = HierarchicalCache(CacheKeySetID) +class CacheType(Enum): + CLASSIC = 0 + LRU = 1 + DEPENDENCY_AWARE = 2 + + +class CacheSet: + def __init__(self, cache_type=None, cache_size=None): + if cache_type == CacheType.DEPENDENCY_AWARE: + self.init_dependency_aware_cache() + logging.info("Disabling intermediate node cache.") + elif cache_type == CacheType.LRU: + if cache_size is None: + cache_size = 0 + self.init_lru_cache(cache_size) + logging.info("Using LRU cache") + else: + self.init_classic_cache() + + self.all = [self.outputs, self.ui, self.objects] # Performs like the old cache -- dump data ASAP def init_classic_cache(self): @@ -80,6 +87,17 @@ class CacheSet: self.ui = HierarchicalCache(CacheKeySetInputSignature) self.objects = HierarchicalCache(CacheKeySetID) + def init_lru_cache(self, cache_size): + self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.objects = HierarchicalCache(CacheKeySetID) + + # only hold cached items while the decendents have not executed + def init_dependency_aware_cache(self): + self.outputs = DependencyAwareCache(CacheKeySetInputSignature) + self.ui = DependencyAwareCache(CacheKeySetInputSignature) + self.objects = DependencyAwareCache(CacheKeySetID) + def recursive_debug_dump(self): result = { "outputs": self.outputs.recursive_debug_dump(), @@ -414,13 +432,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp return (ExecutionResult.SUCCESS, None, None) class PromptExecutor: - def __init__(self, server, lru_size=None): - self.lru_size = lru_size + def __init__(self, server, cache_type=False, cache_size=None): + self.cache_size = cache_size + self.cache_type = cache_type self.server = server self.reset() def reset(self): - self.caches = CacheSet(self.lru_size) + self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size) self.status_messages = [] self.success = True diff --git a/main.py b/main.py index 3ab31f414..7ca04f4b1 100644 --- a/main.py +++ 
diff --git a/main.py b/main.py
index 3ab31f414..7ca04f4b1 100644
--- a/main.py
+++ b/main.py
@@ -159,7 +159,13 @@ def cuda_malloc_warning():
 
 def prompt_worker(q, server_instance):
     current_time: float = 0.0
-    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
+    cache_type = execution.CacheType.CLASSIC
+    if args.cache_lru > 0:
+        cache_type = execution.CacheType.LRU
+    elif args.cache_none:
+        cache_type = execution.CacheType.DEPENDENCY_AWARE
+
+    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0
diff --git a/nodes.py b/nodes.py
index f63e8cb5e..8c1720c1a 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2136,7 +2136,7 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
         module_name = sp[0]
         sys_module_name = module_name
     elif os.path.isdir(module_path):
-        sys_module_name = module_path
+        sys_module_name = module_path.replace(".", "_x_")
 
     try:
         logging.debug("Trying to load custom node {}".format(module_path))
diff --git a/requirements.txt b/requirements.txt
index 806fbc751..851db23bd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.14.6
+comfyui-frontend-package==1.15.13
 torch
 torchsde
 torchvision
diff --git a/server.py b/server.py
index 95092d595..62667ce18 100644
--- a/server.py
+++ b/server.py
@@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message):
 
 @web.middleware
 async def cache_control(request: web.Request, handler):
     response: web.Response = await handler(request)
-    if request.path.endswith('.js') or request.path.endswith('.css'):
+    if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
         response.headers.setdefault('Cache-Control', 'no-cache')
     return response
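One note on the nodes.py hunk: the diff itself does not state the motivation, but a plausible reading is that a dotted directory name produces a dotted sys.modules key, which Python's import machinery treats as a package separator. A tiny sketch with a made-up path:

```python
# Hypothetical directory-based custom node pack whose folder name contains dots.
module_path = "custom_nodes/comfyui.fancy.pack"

# Registering the raw path under sys.modules would create a dotted key;
# mangling the dots avoids accidental package-separator semantics and
# collisions between distinct packs.
sys_module_name = module_path.replace(".", "_x_")
print(sys_module_name)  # custom_nodes/comfyui_x_fancy_x_pack
```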