Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-16 00:23:30 +00:00)

Compare commits: 830d5486ab ... 21df85136d (11 commits)
Commits in this range:
21df85136d, 22ad513c72, ed945a1790, f9207c6936, 8ad7477647, 9c957977d0, 2cf95ed231, d8eae1b241, 82c3afe077, 0e86405198, 7bba21af47
app/output_manager.py (new file, 156 lines)

@@ -0,0 +1,156 @@
from __future__ import annotations

import os
import logging
import folder_paths
import mimetypes
import shutil
import traceback
from aiohttp import web
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from PIL import Image
from typing import Literal


class OutputManager:
    def __init__(self) -> None:
        # Maps a folder path to (listing, mtime) so unchanged folders are not rescanned.
        self.cache: dict[str, tuple[list, float]] = {}
        self.output_uri = folder_paths.get_output_directory()

    def get_cache(self, key: str):
        return self.cache.get(key, ([], 0))

    def set_cache(self, key: str, value: tuple[list, float]):
        self.cache[key] = value

    def rm_cache(self, key: str):
        if key in self.cache:
            del self.cache[key]

    def add_routes(self, routes) -> None:
        @routes.get("/output{pathname:.*}")
        async def get_output_file_or_files(request):
            pathname = request.match_info.get("pathname", None)
            try:
                filepath = self.get_output_filepath(pathname)

                if os.path.isfile(filepath):
                    preview_type = request.query.get("preview_type", None)
                    if not preview_type:
                        return web.FileResponse(filepath)

                    # get image preview
                    if self.assert_file_type(filepath, ["image"]):
                        image_data = self.get_image_preview_data(filepath)
                        return web.Response(body=image_data.getvalue(), content_type="image/webp")

                    # TODO get video cover preview

                elif os.path.isdir(filepath):
                    files = self.get_folder_items(filepath)
                    return web.json_response(files)

                return web.Response(status=404)
            except Exception:
                logging.error(f"File '{pathname}' retrieval failed")
                logging.error(traceback.format_exc())
                return web.Response(status=500)

        @routes.delete("/output{pathname:.*}")
        async def delete_output_file_or_files(request):
            pathname = request.match_info.get("pathname", None)
            try:
                filepath = self.get_output_filepath(pathname)

                if os.path.isfile(filepath):
                    os.remove(filepath)
                elif os.path.isdir(filepath):
                    shutil.rmtree(filepath)
                self.rm_cache(filepath)
                return web.Response(status=200)
            except Exception:
                logging.error(f"File '{pathname}' deletion failed")
                logging.error(traceback.format_exc())
                return web.Response(status=500)

    def get_output_filepath(self, pathname: str):
        return f"{self.output_uri}/{pathname}"

    def get_folder_items(self, folder: str):
        result, m_time = self.get_cache(folder)
        folder_m_time = os.path.getmtime(folder)

        # Serve the cached listing if the folder has not been modified since it was built.
        if folder_m_time == m_time:
            return result

        result = []

        def get_file_info(entry: os.DirEntry[str]):
            filepath = entry.path
            is_dir = entry.is_dir()

            if not is_dir and not self.assert_file_type(filepath, ["image", "video", "audio"]):
                return None

            stat = entry.stat()
            return {
                "name": entry.name,
                "type": "folder" if entry.is_dir() else self.get_file_content_type(filepath),
                "size": 0 if is_dir else stat.st_size,
                "createTime": round(stat.st_ctime_ns / 1000000),
                "modifyTime": round(stat.st_mtime_ns / 1000000),
            }

        # stat() the entries concurrently; the per-entry work is I/O bound.
        with os.scandir(folder) as it, ThreadPoolExecutor() as executor:
            future_to_entry = {executor.submit(get_file_info, entry): entry for entry in it}
            for future in as_completed(future_to_entry):
                file_info = future.result()
                if file_info is None:
                    continue
                result.append(file_info)

        self.set_cache(folder, (result, os.path.getmtime(folder)))
        return result

    def assert_file_type(self, filename: str, content_types: list[Literal["image", "video", "audio"]]):
        content_type = self.get_file_content_type(filename)
        if not content_type:
            return False
        return content_type in content_types

    def get_file_content_type(self, filename: str):
        extension_mimetypes_cache = folder_paths.extension_mimetypes_cache

        extension = filename.split(".")[-1]
        content_type = None
        if extension not in extension_mimetypes_cache:
            mime_type, _ = mimetypes.guess_type(filename, strict=False)
            if mime_type:
                content_type = mime_type.split("/")[0]
                extension_mimetypes_cache[extension] = content_type
        else:
            content_type = extension_mimetypes_cache[extension]

        return content_type

    def get_image_preview_data(self, filename: str):
        with Image.open(filename) as img:
            max_size = 128

            old_width, old_height = img.size
            scale = min(max_size / old_width, max_size / old_height)

            # Only downscale; never upscale small images.
            if scale >= 1:
                new_width, new_height = old_width, old_height
            else:
                new_width = int(old_width * scale)
                new_height = int(old_height * scale)

            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            img_byte_arr = BytesIO()
            img.save(img_byte_arr, format="WEBP")
            img_byte_arr.seek(0)
            return img_byte_arr
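For orientation, here is a minimal sketch of mounting these routes outside ComfyUI's PromptServer. The standalone app, port, and example paths are assumptions for illustration (and it must run inside a ComfyUI checkout so that folder_paths imports); inside ComfyUI the registration happens in server.py, shown further down.

from aiohttp import web
from app.output_manager import OutputManager

routes = web.RouteTableDef()
OutputManager().add_routes(routes)  # registers GET and DELETE for /output{pathname}

app = web.Application()
app.add_routes(routes)

# Example requests once running (port is an assumption):
#   GET    /output/img_0001.png                    -> raw file
#   GET    /output/img_0001.png?preview_type=thumb -> small WebP preview
#   GET    /output/some_folder                     -> JSON listing of media files
#   DELETE /output/img_0001.png                    -> removes the file
web.run_app(app, port=8189)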
comfy/cli_args.py

@@ -101,6 +101,7 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
 cache_group = parser.add_mutually_exclusive_group()
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
+cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")

 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
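Because the three cache flags live in one mutually exclusive group, argparse itself rejects contradictory combinations. A self-contained sketch of that behavior (a standalone parser, not ComfyUI's):

import argparse

parser = argparse.ArgumentParser()
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true")
cache_group.add_argument("--cache-lru", type=int, default=0)
cache_group.add_argument("--cache-none", action="store_true")

print(parser.parse_args(["--cache-lru", "16"]))
# Namespace(cache_classic=False, cache_lru=16, cache_none=False)

# parser.parse_args(["--cache-lru", "16", "--cache-none"])
# would exit with: argument --cache-none: not allowed with argument --cache-lru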
comfy_execution/caching.py

@@ -316,3 +316,156 @@ class LRUCache(BasicCache):
         self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self
+
+
+class DependencyAwareCache(BasicCache):
+    """
+    A cache implementation that tracks dependencies between nodes and manages
+    their execution and caching accordingly. It extends the BasicCache class.
+    Nodes are removed from this cache once all of their descendants have been
+    executed.
+    """
+
+    def __init__(self, key_class):
+        """
+        Initialize the DependencyAwareCache.
+
+        Args:
+            key_class: The class used for generating cache keys.
+        """
+        super().__init__(key_class)
+        self.descendants = {}  # Maps node_id -> set of descendant node_ids
+        self.ancestors = {}  # Maps node_id -> set of ancestor node_ids
+        self.executed_nodes = set()  # Tracks nodes that have been executed
+
+    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+        """
+        Clear the entire cache and rebuild the dependency graph.
+
+        Args:
+            dynprompt: The dynamic prompt object containing node information.
+            node_ids: List of node IDs to initialize the cache for.
+            is_changed_cache: Flag indicating if the cache has changed.
+        """
+        # Clear all existing cache data
+        self.cache.clear()
+        self.subcaches.clear()
+        self.descendants.clear()
+        self.ancestors.clear()
+        self.executed_nodes.clear()
+
+        # Call the parent method to initialize the cache with the new prompt
+        super().set_prompt(dynprompt, node_ids, is_changed_cache)
+
+        # Rebuild the dependency graph
+        self._build_dependency_graph(dynprompt, node_ids)
+
+    def _build_dependency_graph(self, dynprompt, node_ids):
+        """
+        Build the dependency graph for all nodes.
+
+        Args:
+            dynprompt: The dynamic prompt object containing node information.
+            node_ids: List of node IDs to build the graph for.
+        """
+        self.descendants.clear()
+        self.ancestors.clear()
+        for node_id in node_ids:
+            self.descendants[node_id] = set()
+            self.ancestors[node_id] = set()
+
+        for node_id in node_ids:
+            inputs = dynprompt.get_node(node_id)["inputs"]
+            for input_data in inputs.values():
+                if is_link(input_data):  # Check if the input is a link to another node
+                    ancestor_id = input_data[0]
+                    self.descendants[ancestor_id].add(node_id)
+                    self.ancestors[node_id].add(ancestor_id)
+
+    def set(self, node_id, value):
+        """
+        Mark a node as executed and store its value in the cache.
+
+        Args:
+            node_id: The ID of the node to store.
+            value: The value to store for the node.
+        """
+        self._set_immediate(node_id, value)
+        self.executed_nodes.add(node_id)
+        self._cleanup_ancestors(node_id)
+
+    def get(self, node_id):
+        """
+        Retrieve the cached value for a node.
+
+        Args:
+            node_id: The ID of the node to retrieve.
+
+        Returns:
+            The cached value for the node.
+        """
+        return self._get_immediate(node_id)
+
+    def ensure_subcache_for(self, node_id, children_ids):
+        """
+        Ensure a subcache exists for a node and update dependencies.
+
+        Args:
+            node_id: The ID of the parent node.
+            children_ids: List of child node IDs to associate with the parent node.
+
+        Returns:
+            The subcache object for the node.
+        """
+        subcache = super()._ensure_subcache(node_id, children_ids)
+        for child_id in children_ids:
+            self.descendants[node_id].add(child_id)
+            self.ancestors[child_id].add(node_id)
+        return subcache
+
+    def _cleanup_ancestors(self, node_id):
+        """
+        Check if ancestors of a node can be removed from the cache.
+
+        Args:
+            node_id: The ID of the node whose ancestors are to be checked.
+        """
+        for ancestor_id in self.ancestors.get(node_id, []):
+            if ancestor_id in self.executed_nodes:
+                # Remove ancestor if all its descendants have been executed
+                if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
+                    self._remove_node(ancestor_id)
+
+    def _remove_node(self, node_id):
+        """
+        Remove a node from the cache.
+
+        Args:
+            node_id: The ID of the node to remove.
+        """
+        cache_key = self.cache_key_set.get_data_key(node_id)
+        if cache_key in self.cache:
+            del self.cache[cache_key]
+        subcache_key = self.cache_key_set.get_subcache_key(node_id)
+        if subcache_key in self.subcaches:
+            del self.subcaches[subcache_key]
+
+    def clean_unused(self):
+        """
+        Clean up unused nodes. This is a no-op for this cache implementation.
+        """
+        pass
+
+    def recursive_debug_dump(self):
+        """
+        Dump the cache and dependency graph for debugging.
+
+        Returns:
+            A list containing the cache state and dependency graph.
+        """
+        result = super().recursive_debug_dump()
+        result.append({
+            "descendants": self.descendants,
+            "ancestors": self.ancestors,
+            "executed_nodes": list(self.executed_nodes),
+        })
+        return result
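The eviction rule is easiest to see on a toy graph. The following standalone sketch (plain dicts and sets, not ComfyUI code, and simplified to scan every node rather than only the finished node's ancestors) mimics the behavior: a node's result is dropped as soon as every one of its descendants has executed.

descendants = {"load": {"sample", "save"}, "sample": {"save"}, "save": set()}
executed, cache = set(), {}

def finish(node, value):
    cache[node] = value
    executed.add(node)
    # drop any executed node whose descendants have all executed
    for n, deps in descendants.items():
        if n in executed and deps and deps <= executed and n in cache:
            del cache[n]

finish("load", "model")      # cache: load
finish("sample", "latent")   # cache: load, sample ("save" has not run yet)
finish("save", "image")      # "load" and "sample" evicted
print(cache)                 # {'save': 'image'}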
execution.py (53 lines changed)
@@ -15,7 +15,7 @@ import nodes
 import comfy.model_management
 from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
 from comfy_execution.graph_utils import is_link, GraphBuilder
-from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
+from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
 from comfy_execution.validation import validate_node_input

 class ExecutionResult(Enum):

@@ -59,20 +59,27 @@ class IsChangedCache:
         self.is_changed[node_id] = node["is_changed"]
         return self.is_changed[node_id]

-class CacheSet:
-    def __init__(self, lru_size=None):
-        if lru_size is None or lru_size == 0:
-            self.init_classic_cache()
-        else:
-            self.init_lru_cache(lru_size)
-        self.all = [self.outputs, self.ui, self.objects]
-
-    # Useful for those with ample RAM/VRAM -- allows experimenting without
-    # blowing away the cache every time
-    def init_lru_cache(self, cache_size):
-        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.objects = HierarchicalCache(CacheKeySetID)
+class CacheType(Enum):
+    CLASSIC = 0
+    LRU = 1
+    DEPENDENCY_AWARE = 2
+
+
+class CacheSet:
+    def __init__(self, cache_type=None, cache_size=None):
+        if cache_type == CacheType.DEPENDENCY_AWARE:
+            self.init_dependency_aware_cache()
+            logging.info("Disabling intermediate node cache.")
+        elif cache_type == CacheType.LRU:
+            if cache_size is None:
+                cache_size = 0
+            self.init_lru_cache(cache_size)
+            logging.info("Using LRU cache")
+        else:
+            self.init_classic_cache()
+
+        self.all = [self.outputs, self.ui, self.objects]

     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):

@@ -80,6 +87,17 @@ class CacheSet:
         self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)

+    def init_lru_cache(self, cache_size):
+        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
+    # only hold cached items while the descendants have not executed
+    def init_dependency_aware_cache(self):
+        self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
+        self.ui = DependencyAwareCache(CacheKeySetInputSignature)
+        self.objects = DependencyAwareCache(CacheKeySetID)
+
     def recursive_debug_dump(self):
         result = {
             "outputs": self.outputs.recursive_debug_dump(),

@@ -414,13 +432,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
     return (ExecutionResult.SUCCESS, None, None)

 class PromptExecutor:
-    def __init__(self, server, lru_size=None):
-        self.lru_size = lru_size
+    def __init__(self, server, cache_type=False, cache_size=None):
+        self.cache_size = cache_size
+        self.cache_type = cache_type
         self.server = server
         self.reset()

     def reset(self):
-        self.caches = CacheSet(self.lru_size)
+        self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
         self.status_messages = []
         self.success = True
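Taken together, the new constructor surface looks like this. A minimal sketch, assuming it runs inside a ComfyUI checkout where execution.py and its dependencies import cleanly:

from execution import CacheSet, CacheType

classic = CacheSet(cache_type=CacheType.CLASSIC)           # default: aggressive caching
lru = CacheSet(cache_type=CacheType.LRU, cache_size=16)    # keep up to 16 node results
minimal = CacheSet(cache_type=CacheType.DEPENDENCY_AWARE)  # evict once descendants have run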
main.py (8 lines changed)
@@ -156,7 +156,13 @@ def cuda_malloc_warning():

 def prompt_worker(q, server_instance):
     current_time: float = 0.0
-    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
+    cache_type = execution.CacheType.CLASSIC
+    if args.cache_lru > 0:
+        cache_type = execution.CacheType.LRU
+    elif args.cache_none:
+        cache_type = execution.CacheType.DEPENDENCY_AWARE
+
+    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0
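The flag-to-cache-type mapping above is easy to check in isolation. A sketch with a hypothetical pure helper (pick_cache_type and the local CacheType mirror are illustrations, not names from the diff; note --cache-lru takes precedence since the flags are mutually exclusive anyway):

from enum import Enum
from types import SimpleNamespace

class CacheType(Enum):  # mirrors execution.CacheType from this diff
    CLASSIC = 0
    LRU = 1
    DEPENDENCY_AWARE = 2

def pick_cache_type(args) -> CacheType:
    if args.cache_lru > 0:
        return CacheType.LRU
    elif args.cache_none:
        return CacheType.DEPENDENCY_AWARE
    return CacheType.CLASSIC

assert pick_cache_type(SimpleNamespace(cache_lru=16, cache_none=False)) is CacheType.LRU
assert pick_cache_type(SimpleNamespace(cache_lru=0, cache_none=True)) is CacheType.DEPENDENCY_AWARE
assert pick_cache_type(SimpleNamespace(cache_lru=0, cache_none=False)) is CacheType.CLASSIC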
requirements.txt

@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.14.6
+comfyui-frontend-package==1.15.13
 torch
 torchsde
 torchvision
server.py

@@ -31,6 +31,7 @@ from comfyui_version import __version__
 from app.frontend_management import FrontendManager
 from app.user_manager import UserManager
 from app.model_manager import ModelFileManager
+from app.output_manager import OutputManager
 from app.custom_node_manager import CustomNodeManager
 from typing import Optional
 from api_server.routes.internal.internal_routes import InternalRoutes

@@ -48,7 +49,7 @@ async def send_socket_catch_exception(function, message):
 @web.middleware
 async def cache_control(request: web.Request, handler):
     response: web.Response = await handler(request)
-    if request.path.endswith('.js') or request.path.endswith('.css'):
+    if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
        response.headers.setdefault('Cache-Control', 'no-cache')
     return response
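The middleware change is a one-predicate edit: index.json responses now also get "Cache-Control: no-cache". A quick standalone check of that predicate (extracted by hand for illustration; not a function that exists in server.py):

def needs_no_cache(path: str) -> bool:
    # mirrors the updated condition in cache_control
    return path.endswith('.js') or path.endswith('.css') or path.endswith('index.json')

assert needs_no_cache("/assets/app.js")
assert needs_no_cache("/templates/index.json")  # newly covered by this change
assert not needs_no_cache("/output/image.png")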
@@ -155,6 +156,7 @@ class PromptServer():

         self.user_manager = UserManager()
         self.model_file_manager = ModelFileManager()
+        self.output_manager = OutputManager()
         self.custom_node_manager = CustomNodeManager()
         self.internal_routes = InternalRoutes(self)
         self.supports = ["custom_nodes_from_web"]

@@ -715,6 +717,7 @@ class PromptServer():
     def add_routes(self):
         self.user_manager.add_routes(self.routes)
         self.model_file_manager.add_routes(self.routes)
+        self.output_manager.add_routes(self.routes)
         self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
         self.app.add_subapp('/internal', self.internal_routes.get_app())