Compare commits

11 Commits

Author SHA1 Message Date
Hayden
21df85136d
Merge 9c957977d0 into 22ad513c72 2025-04-11 09:49:48 -04:00
comfyanonymous
22ad513c72 Refactor node cache code to more easily add other types of cache. 2025-04-11 07:16:52 -04:00
Chargeuk
ed945a1790
Dependency Aware Node Caching for low RAM/VRAM machines (#7509)
* add dependency aware cache that removes a cached node as soon as all of its descendants have executed. This allows users with lower RAM to run workflows they would otherwise not be able to run. The downside is that every workflow will fully run each time even if no nodes have changed.

* remove test code

* tidy code
2025-04-11 06:55:51 -04:00
Chenlei Hu
f9207c6936
Update frontend to 1.15 (#7564) 2025-04-11 06:46:20 -04:00
Christian Byrne
8ad7477647
dont cache templates index (#7569) 2025-04-11 06:06:53 -04:00
hayden
9c957977d0 Repair the file operation abnormality and add a log record 2025-01-24 16:58:40 +08:00
hayden
2cf95ed231 Change property name 2025-01-24 14:37:34 +08:00
hayden
d8eae1b241 Add folder content caching 2025-01-23 13:53:41 +08:00
hayden
82c3afe077 Optimize file scanning performance using multi-threading 2025-01-23 13:09:07 +08:00
hayden
0e86405198 Optimize file scanning using scandir 2025-01-23 12:53:22 +08:00
hayden
7bba21af47 Add output manager 2025-01-22 15:58:33 +08:00
7 changed files with 358 additions and 20 deletions

156
app/output_manager.py Normal file
View File

@@ -0,0 +1,156 @@
from __future__ import annotations

import os
import logging
import folder_paths
import mimetypes
import shutil
import traceback

from aiohttp import web
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from PIL import Image
from typing import Literal


class OutputManager:
    def __init__(self) -> None:
        self.cache: dict[str, tuple[list, float]] = {}
        self.output_uri = folder_paths.get_output_directory()

    def get_cache(self, key: str):
        return self.cache.get(key, ([], 0))

    def set_cache(self, key: str, value: tuple[list, float]):
        self.cache[key] = value

    def rm_cache(self, key: str):
        if key in self.cache:
            del self.cache[key]

    def add_routes(self, routes) -> None:
        @routes.get("/output{pathname:.*}")
        async def get_output_file_or_files(request):
            pathname = request.match_info.get("pathname", None)
            try:
                filepath = self.get_output_filepath(pathname)

                if os.path.isfile(filepath):
                    preview_type = request.query.get("preview_type", None)
                    if not preview_type:
                        return web.FileResponse(filepath)

                    # get image preview
                    if self.assert_file_type(filepath, ["image"]):
                        image_data = self.get_image_preview_data(filepath)
                        return web.Response(body=image_data.getvalue(), content_type="image/webp")

                    # TODO get video cover preview
                elif os.path.isdir(filepath):
                    files = self.get_folder_items(filepath)
                    return web.json_response(files)

                return web.Response(status=404)
            except Exception:
                logging.error(f"File '{pathname}' retrieval failed")
                logging.error(traceback.format_exc())
                return web.Response(status=500)

        @routes.delete("/output{pathname:.*}")
        async def delete_output_file_or_files(request):
            pathname = request.match_info.get("pathname", None)
            try:
                filepath = self.get_output_filepath(pathname)

                if os.path.isfile(filepath):
                    os.remove(filepath)
                elif os.path.isdir(filepath):
                    shutil.rmtree(filepath)
                self.rm_cache(filepath)

                return web.Response(status=200)
            except Exception:
                logging.error(f"File '{pathname}' deletion failed")
                logging.error(traceback.format_exc())
                return web.Response(status=500)

    def get_output_filepath(self, pathname: str):
        return f"{self.output_uri}/{pathname}"

    def get_folder_items(self, folder: str):
        # Serve the cached listing if the folder has not changed since it was built.
        result, m_time = self.get_cache(folder)
        folder_m_time = os.path.getmtime(folder)
        if folder_m_time == m_time:
            return result

        result = []

        def get_file_info(entry: os.DirEntry[str]):
            filepath = entry.path
            is_dir = entry.is_dir()
            if not is_dir and not self.assert_file_type(filepath, ["image", "video", "audio"]):
                return None

            stat = entry.stat()
            return {
                "name": entry.name,
                "type": "folder" if entry.is_dir() else self.get_file_content_type(filepath),
                "size": 0 if is_dir else stat.st_size,
                "createTime": round(stat.st_ctime_ns / 1000000),
                "modifyTime": round(stat.st_mtime_ns / 1000000),
            }

        with os.scandir(folder) as it, ThreadPoolExecutor() as executor:
            future_to_entry = {executor.submit(get_file_info, entry): entry for entry in it}
            for future in as_completed(future_to_entry):
                file_info = future.result()
                if file_info is None:
                    continue
                result.append(file_info)

        self.set_cache(folder, (result, os.path.getmtime(folder)))
        return result

    def assert_file_type(self, filename: str, content_types: list[Literal["image", "video", "audio"]]):
        content_type = self.get_file_content_type(filename)
        if not content_type:
            return False
        return content_type in content_types

    def get_file_content_type(self, filename: str):
        extension_mimetypes_cache = folder_paths.extension_mimetypes_cache

        extension = filename.split(".")[-1]
        content_type = None
        if extension not in extension_mimetypes_cache:
            mime_type, _ = mimetypes.guess_type(filename, strict=False)
            if mime_type:
                content_type = mime_type.split("/")[0]
                extension_mimetypes_cache[extension] = content_type
        else:
            content_type = extension_mimetypes_cache[extension]

        return content_type

    def get_image_preview_data(self, filename: str):
        with Image.open(filename) as img:
            max_size = 128

            old_width, old_height = img.size
            scale = min(max_size / old_width, max_size / old_height)

            if scale >= 1:
                new_width, new_height = old_width, old_height
            else:
                new_width = int(old_width * scale)
                new_height = int(old_height * scale)

            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            img_byte_arr = BytesIO()
            img.save(img_byte_arr, format="WEBP")
            img_byte_arr.seek(0)

            return img_byte_arr
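
For reference, a minimal usage sketch of the two routes registered above, assuming a ComfyUI server on the default local port 8188; the file name example.png and the preview_type value are hypothetical, and the requests library is used purely for brevity.

import requests

BASE = "http://127.0.0.1:8188"  # assumed default ComfyUI address

# JSON listing of the output directory (folders plus image/video/audio files)
listing = requests.get(f"{BASE}/output/").json()

# Any non-empty preview_type returns the 128px WebP preview for image files
preview = requests.get(f"{BASE}/output/example.png", params={"preview_type": "thumbnail"})

# DELETE removes a file, or a whole folder recursively, and drops its cache entry
requests.delete(f"{BASE}/output/example.png")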

View File

@@ -101,6 +101,7 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
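
A standalone sketch of how the new mutually exclusive cache flags behave when parsed; the parser below is a hypothetical stand-in that only mirrors the three arguments from the hunk above, not ComfyUI's full argument parser.

import argparse

# Hypothetical stand-in parser mirroring only the cache flags shown above
parser = argparse.ArgumentParser()
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true")
cache_group.add_argument("--cache-lru", type=int, default=0)
cache_group.add_argument("--cache-none", action="store_true")

args = parser.parse_args(["--cache-none"])
print(args.cache_none)  # True  -> dependency-aware caching is selected downstream
print(args.cache_lru)   # 0     -> LRU caching stays disabled

# Combining two cache flags fails, since they share a mutually exclusive group:
# parser.parse_args(["--cache-none", "--cache-lru", "16"])  -> SystemExit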

View File

@@ -316,3 +316,156 @@ class LRUCache(BasicCache):
            self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
        return self


class DependencyAwareCache(BasicCache):
    """
    A cache implementation that tracks dependencies between nodes and manages
    their execution and caching accordingly. It extends the BasicCache class.
    Nodes are removed from this cache once all of their descendants have been
    executed.
    """

    def __init__(self, key_class):
        """
        Initialize the DependencyAwareCache.

        Args:
            key_class: The class used for generating cache keys.
        """
        super().__init__(key_class)
        self.descendants = {}  # Maps node_id -> set of descendant node_ids
        self.ancestors = {}  # Maps node_id -> set of ancestor node_ids
        self.executed_nodes = set()  # Tracks nodes that have been executed

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        """
        Clear the entire cache and rebuild the dependency graph.

        Args:
            dynprompt: The dynamic prompt object containing node information.
            node_ids: List of node IDs to initialize the cache for.
            is_changed_cache: The IsChangedCache instance for this prompt.
        """
        # Clear all existing cache data
        self.cache.clear()
        self.subcaches.clear()
        self.descendants.clear()
        self.ancestors.clear()
        self.executed_nodes.clear()

        # Call the parent method to initialize the cache with the new prompt
        super().set_prompt(dynprompt, node_ids, is_changed_cache)

        # Rebuild the dependency graph
        self._build_dependency_graph(dynprompt, node_ids)

    def _build_dependency_graph(self, dynprompt, node_ids):
        """
        Build the dependency graph for all nodes.

        Args:
            dynprompt: The dynamic prompt object containing node information.
            node_ids: List of node IDs to build the graph for.
        """
        self.descendants.clear()
        self.ancestors.clear()
        for node_id in node_ids:
            self.descendants[node_id] = set()
            self.ancestors[node_id] = set()

        for node_id in node_ids:
            inputs = dynprompt.get_node(node_id)["inputs"]
            for input_data in inputs.values():
                if is_link(input_data):  # Check if the input is a link to another node
                    ancestor_id = input_data[0]
                    self.descendants[ancestor_id].add(node_id)
                    self.ancestors[node_id].add(ancestor_id)

    def set(self, node_id, value):
        """
        Mark a node as executed and store its value in the cache.

        Args:
            node_id: The ID of the node to store.
            value: The value to store for the node.
        """
        self._set_immediate(node_id, value)
        self.executed_nodes.add(node_id)
        self._cleanup_ancestors(node_id)

    def get(self, node_id):
        """
        Retrieve the cached value for a node.

        Args:
            node_id: The ID of the node to retrieve.

        Returns:
            The cached value for the node.
        """
        return self._get_immediate(node_id)

    def ensure_subcache_for(self, node_id, children_ids):
        """
        Ensure a subcache exists for a node and update dependencies.

        Args:
            node_id: The ID of the parent node.
            children_ids: List of child node IDs to associate with the parent node.

        Returns:
            The subcache object for the node.
        """
        subcache = super()._ensure_subcache(node_id, children_ids)
        for child_id in children_ids:
            self.descendants[node_id].add(child_id)
            self.ancestors[child_id].add(node_id)
        return subcache

    def _cleanup_ancestors(self, node_id):
        """
        Check if ancestors of a node can be removed from the cache.

        Args:
            node_id: The ID of the node whose ancestors are to be checked.
        """
        for ancestor_id in self.ancestors.get(node_id, []):
            if ancestor_id in self.executed_nodes:
                # Remove ancestor if all its descendants have been executed
                if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
                    self._remove_node(ancestor_id)

    def _remove_node(self, node_id):
        """
        Remove a node from the cache.

        Args:
            node_id: The ID of the node to remove.
        """
        cache_key = self.cache_key_set.get_data_key(node_id)
        if cache_key in self.cache:
            del self.cache[cache_key]
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        if subcache_key in self.subcaches:
            del self.subcaches[subcache_key]

    def clean_unused(self):
        """
        Clean up unused nodes. This is a no-op for this cache implementation.
        """
        pass

    def recursive_debug_dump(self):
        """
        Dump the cache and dependency graph for debugging.

        Returns:
            A list containing the cache state and dependency graph.
        """
        result = super().recursive_debug_dump()
        result.append({
            "descendants": self.descendants,
            "ancestors": self.ancestors,
            "executed_nodes": list(self.executed_nodes),
        })
        return result
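
To make the eviction rule concrete, a small self-contained sketch of the behaviour described in the docstring above; TinyDepCache and the node names are hypothetical and do not use the repository's BasicCache machinery.

class TinyDepCache:
    def __init__(self, edges):
        # edges: node_id -> list of node_ids that consume its output
        self.descendants = {n: set(d) for n, d in edges.items()}
        self.values = {}
        self.executed = set()

    def set(self, node_id, value):
        self.values[node_id] = value
        self.executed.add(node_id)
        # Drop any cached ancestor once every one of its consumers has executed
        for ancestor, consumers in self.descendants.items():
            if ancestor in self.values and consumers and consumers <= self.executed:
                del self.values[ancestor]

cache = TinyDepCache({"load": ["sample"], "sample": ["save"], "save": []})
cache.set("load", "checkpoint")
cache.set("sample", "latent")   # "load" is evicted: its only consumer has run
cache.set("save", "image")      # "sample" is evicted for the same reason
assert "load" not in cache.values and "sample" not in cache.values
assert "save" in cache.values   # terminal outputs stay cached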

View File

@@ -15,7 +15,7 @@ import nodes
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input

class ExecutionResult(Enum):

@@ -59,20 +59,27 @@ class IsChangedCache:
            self.is_changed[node_id] = node["is_changed"]
        return self.is_changed[node_id]

class CacheSet:
    def __init__(self, lru_size=None):
        if lru_size is None or lru_size == 0:
            self.init_classic_cache()
        else:
            self.init_lru_cache(lru_size)
        self.all = [self.outputs, self.ui, self.objects]

    # Useful for those with ample RAM/VRAM -- allows experimenting without
    # blowing away the cache every time
    def init_lru_cache(self, cache_size):
        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
        self.objects = HierarchicalCache(CacheKeySetID)

class CacheType(Enum):
    CLASSIC = 0
    LRU = 1
    DEPENDENCY_AWARE = 2

class CacheSet:
    def __init__(self, cache_type=None, cache_size=None):
        if cache_type == CacheType.DEPENDENCY_AWARE:
            self.init_dependency_aware_cache()
            logging.info("Disabling intermediate node cache.")
        elif cache_type == CacheType.LRU:
            if cache_size is None:
                cache_size = 0
            self.init_lru_cache(cache_size)
            logging.info("Using LRU cache")
        else:
            self.init_classic_cache()
        self.all = [self.outputs, self.ui, self.objects]

    # Performs like the old cache -- dump data ASAP
    def init_classic_cache(self):

@@ -80,6 +87,17 @@ class CacheSet:
        self.ui = HierarchicalCache(CacheKeySetInputSignature)
        self.objects = HierarchicalCache(CacheKeySetID)

    def init_lru_cache(self, cache_size):
        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
        self.objects = HierarchicalCache(CacheKeySetID)

    # only hold cached items while the descendants have not executed
    def init_dependency_aware_cache(self):
        self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
        self.ui = DependencyAwareCache(CacheKeySetInputSignature)
        self.objects = DependencyAwareCache(CacheKeySetID)

    def recursive_debug_dump(self):
        result = {
            "outputs": self.outputs.recursive_debug_dump(),

@@ -414,13 +432,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
        return (ExecutionResult.SUCCESS, None, None)

class PromptExecutor:
    def __init__(self, server, lru_size=None):
        self.lru_size = lru_size
    def __init__(self, server, cache_type=False, cache_size=None):
        self.cache_size = cache_size
        self.cache_type = cache_type
        self.server = server
        self.reset()

    def reset(self):
        self.caches = CacheSet(self.lru_size)
        self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
        self.status_messages = []
        self.success = True
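
A short sketch of how the reworked constructor selects a cache implementation; it assumes ComfyUI's execution module is importable, so treat it as illustrative rather than a drop-in snippet.

from execution import CacheSet, CacheType  # assumes a working ComfyUI checkout

classic = CacheSet()                                       # HierarchicalCache for everything
lru = CacheSet(cache_type=CacheType.LRU, cache_size=16)    # keep up to 16 node results
low_mem = CacheSet(cache_type=CacheType.DEPENDENCY_AWARE)  # evict as soon as descendants ran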

View File

@@ -156,7 +156,13 @@ def cuda_malloc_warning():

def prompt_worker(q, server_instance):
    current_time: float = 0.0
    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
    cache_type = execution.CacheType.CLASSIC
    if args.cache_lru > 0:
        cache_type = execution.CacheType.LRU
    elif args.cache_none:
        cache_type = execution.CacheType.DEPENDENCY_AWARE

    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
    last_gc_collect = 0
    need_gc = False
    gc_collect_interval = 10.0

View File

@@ -1,4 +1,4 @@
comfyui-frontend-package==1.14.6
comfyui-frontend-package==1.15.13
torch
torchsde
torchvision

View File

@@ -31,6 +31,7 @@ from comfyui_version import __version__
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.output_manager import OutputManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes

@@ -48,7 +49,7 @@ async def send_socket_catch_exception(function, message):

@web.middleware
async def cache_control(request: web.Request, handler):
    response: web.Response = await handler(request)
    if request.path.endswith('.js') or request.path.endswith('.css'):
    if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
        response.headers.setdefault('Cache-Control', 'no-cache')
    return response

@@ -155,6 +156,7 @@ class PromptServer():
        self.user_manager = UserManager()
        self.model_file_manager = ModelFileManager()
        self.output_manager = OutputManager()
        self.custom_node_manager = CustomNodeManager()
        self.internal_routes = InternalRoutes(self)
        self.supports = ["custom_nodes_from_web"]

@@ -715,6 +717,7 @@ class PromptServer():
    def add_routes(self):
        self.user_manager.add_routes(self.routes)
        self.model_file_manager.add_routes(self.routes)
        self.output_manager.add_routes(self.routes)
        self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
        self.app.add_subapp('/internal', self.internal_routes.get_app())