mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-22 15:37:18 +08:00

* Support for async execution functions This commit adds support for node execution functions defined as async. When a node's execution function is defined as async, we can continue executing other nodes while it is processing. Standard uses of `await` should "just work", but people will still have to be careful if they spawn actual threads. Because torch doesn't really have async/await versions of functions, this won't particularly help with most locally-executing nodes, but it does work for e.g. web requests to other machines. In addition to the execute function, the `VALIDATE_INPUTS` and `check_lazy_status` functions can also be defined as async, though we'll only resolve one node at a time right now for those. * Add the execution model tests to CI * Add a missing file It looks like this got caught by .gitignore? There's probably a better place to put it, but I'm not sure what that is. * Add the websocket library for automated tests * Add additional tests for async error cases Also fixes one bug that was found when an async function throws an error after being scheduled on a task. * Add a feature flags message to reduce bandwidth We now only send 1 preview message of the latest type the client can support. We'll add a console warning when the client fails to send a feature flags message at some point in the future. * Add async tests to CI * Don't actually add new tests in this PR Will do it in a separate PR * Resolve unit test in GPU-less runner * Just remove the tests that GHA can't handle * Change line endings to UNIX-style * Avoid loading model_management.py so early Because model_management.py has a top-level `logging.info`, we have to be careful not to import that file before we call `setup_logging`. If we do, we end up having the default logging handler registered in addition to our custom one.
361 lines
14 KiB
Python
361 lines
14 KiB
Python
import comfy.options
|
|
comfy.options.enable_args_parsing()
|
|
|
|
import os
|
|
import importlib.util
|
|
import folder_paths
|
|
import time
|
|
from comfy.cli_args import args
|
|
from app.logger import setup_logger
|
|
import itertools
|
|
import utils.extra_config
|
|
import logging
|
|
import sys
|
|
from comfy_execution.progress import get_progress_state
|
|
from comfy_execution.utils import get_executing_context
|
|
from comfy_api import feature_flags
|
|
|
|
if __name__ == "__main__":
    # NOTE: These variables have no effect on core ComfyUI; they are exported
    # only for the benefit of custom nodes.
    os.environ.update(HF_HUB_DISABLE_TELEMETRY='1', DO_NOT_TRACK='1')

# Install ComfyUI's logging handlers before any module that logs at import time
# is loaded, so the default handler is not registered alongside ours.
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
|
|
|
def apply_custom_paths():
    """Load extra model-path configs and apply directory overrides from the CLI.

    Order of operations:
      1. ``extra_model_paths.yaml`` next to this file, if it exists.
      2. Every file passed via ``--extra-model-paths-config``.
      3. ``--output-directory``, then registration of model folders under the
         (possibly overridden) output directory.
      4. ``--input-directory`` and ``--user-directory``.
    """
    this_dir = os.path.dirname(os.path.realpath(__file__))
    bundled_config = os.path.join(this_dir, "extra_model_paths.yaml")
    if os.path.isfile(bundled_config):
        utils.extra_config.load_extra_path_config(bundled_config)

    # The option value is an iterable of iterables of paths; flatten it.
    for config_path in itertools.chain(*(args.extra_model_paths_config or ())):
        utils.extra_config.load_extra_path_config(config_path)

    # --output-directory, --input-directory, --user-directory
    if args.output_directory:
        output_dir = os.path.abspath(args.output_directory)
        logging.info(f"Setting output directory to: {output_dir}")
        folder_paths.set_output_directory(output_dir)

    # Default folders that checkpoints, clip, vae, ... models are saved to by
    # CheckpointSave-style nodes; register them as model folders too.
    output_root = folder_paths.get_output_directory()
    for model_kind in ("checkpoints", "clip", "vae", "diffusion_models", "loras"):
        folder_paths.add_model_folder_path(model_kind, os.path.join(output_root, model_kind))

    directory_overrides = (
        (args.input_directory, "input", folder_paths.set_input_directory),
        (args.user_directory, "user", folder_paths.set_user_directory),
    )
    for requested, label, apply_override in directory_overrides:
        if not requested:
            continue
        resolved = os.path.abspath(requested)
        logging.info(f"Setting {label} directory to: {resolved}")
        apply_override(resolved)
|
|
|
|
|
def execute_prestartup_script():
    """Run each custom node's ``prestartup_script.py`` before the rest of ComfyUI loads.

    Every sub-directory of every ``custom_nodes`` folder path is checked for a
    ``prestartup_script.py``. Each script found is executed, unless all custom
    nodes are disabled and that node is not whitelisted. A failing script is
    logged and reported in the timing summary but does not abort startup.
    The function returns immediately when custom nodes are disabled and the
    whitelist is empty.
    """
    if args.disable_all_custom_nodes and len(args.whitelist_custom_nodes) == 0:
        return

    def execute_script(script_path):
        """Execute ``script_path`` as a module; return True on success, False on any exception."""
        module_name = os.path.splitext(script_path)[0]
        try:
            spec = importlib.util.spec_from_file_location(module_name, script_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            return True
        except Exception as e:
            logging.error(f"Failed to execute startup-script: {script_path} / {e}")
        return False

    # Collected across *all* custom_nodes roots. Previously the list was reset
    # for each root, so only the last root was summarised, and the list was
    # unbound when there were no roots at all.
    node_prestartup_times = []
    for custom_node_path in folder_paths.get_folder_paths("custom_nodes"):
        for possible_module in os.listdir(custom_node_path):
            module_path = os.path.join(custom_node_path, possible_module)
            # Compare the entry *name*: the joined path can never equal "__pycache__".
            if os.path.isfile(module_path) or module_path.endswith(".disabled") or possible_module == "__pycache__":
                continue

            script_path = os.path.join(module_path, "prestartup_script.py")
            if not os.path.exists(script_path):
                continue

            if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
                logging.info(f"Prestartup Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
                continue

            time_before = time.perf_counter()
            success = execute_script(script_path)
            node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))

    if node_prestartup_times:
        logging.info("\nPrestartup times for custom nodes:")
        # Tuples sort by elapsed time first, so the slowest scripts are listed last.
        for elapsed, script_dir, success in sorted(node_prestartup_times):
            import_message = "" if success else " (PRESTARTUP FAILED)"
            logging.info("{:6.1f} seconds{}: {}".format(elapsed, import_message, script_dir))
        logging.info("")
|
|
|
|
# Apply path configuration first. The prestartup scan below iterates
# folder_paths.get_folder_paths("custom_nodes"), so that scan must see the
# final path set.
# NOTE(review): presumably the extra_model_paths configs loaded by
# apply_custom_paths() can add custom_nodes roots — confirm in utils.extra_config.
apply_custom_paths()
execute_prestartup_script()
|
|
|
|
|
|
# Main code
|
|
import asyncio
|
|
import shutil
|
|
import threading
|
|
import gc
|
|
|
|
|
|
# On Windows, drop xformers' "A matching Triton is not available" log records.
if os.name == "nt":
    logging.getLogger("xformers").addFilter(
        lambda log_record: "A matching Triton is not available" not in log_record.getMessage()
    )
|
|
|
|
if __name__ == "__main__":
    # Export device-selection variables before anything imports torch.
    if args.cuda_device is not None:
        os.environ.update(
            dict.fromkeys(('CUDA_VISIBLE_DEVICES', 'HIP_VISIBLE_DEVICES'), str(args.cuda_device))
        )
        logging.info(f"Set cuda device to: {args.cuda_device}")

    if args.oneapi_device_selector is not None:
        os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
        logging.info(f"Set oneapi device selector to: {args.oneapi_device_selector}")

    if args.deterministic:
        # Keep any value the user already exported; only fill in the default.
        os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ":4096:8")

    import cuda_malloc
|
|
|
|
# Sanity check: as the warning text says, torch must not be imported before
# this point in startup.
if "torch" in sys.modules:
    logging.warning(
        "WARNING: Potential Error in code: Torch already imported, "
        "torch should never be imported before this point."
    )
|
|
|
|
import comfy.utils
|
|
|
|
import execution
|
|
import server
|
|
from protocol import BinaryEventTypes
|
|
import nodes
|
|
import comfy.model_management
|
|
import comfyui_version
|
|
import app.logger
|
|
import hook_breaker_ac10a0
|
|
|
|
def cuda_malloc_warning():
    """Log a warning when cudaMallocAsync is in use on a device matching ``cuda_malloc.blacklist``.

    The check is a substring test of each blacklist entry against the torch
    device name reported by ``comfy.model_management``.
    """
    torch_device = comfy.model_management.get_torch_device()
    name = comfy.model_management.get_torch_device_name(torch_device)
    if "cudaMallocAsync" in name and any(entry in name for entry in cuda_malloc.blacklist):
        logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
|
|
|
|
|
def prompt_worker(q, server_instance):
    """Consume prompts from ``q`` forever; this is the background worker thread's target.

    A single long-lived ``execution.PromptExecutor`` executes every queued
    prompt. After each prompt, the result is reported through ``q.task_done``
    and the elapsed time is logged. After every queue poll, the queue's
    ``unload_models`` / ``free_memory`` flags are honoured. Once work has been
    done, ``gc.collect()`` plus ``soft_empty_cache()`` run at most once per
    ``gc_collect_interval`` seconds. The loop never returns.

    Args:
        q: the server's prompt queue (``get(timeout=...)``, ``task_done(...)``, ``get_flags()``).
        server_instance: the PromptServer passed to the executor and used to send messages.
    """
    # Timestamp (perf_counter) used to compute how long to wait for a pending GC.
    current_time: float = 0.0
    # --cache-lru N selects the LRU cache, --cache-none the DEPENDENCY_AWARE one; CLASSIC otherwise.
    cache_type = execution.CacheType.CLASSIC
    if args.cache_lru > 0:
        cache_type = execution.CacheType.LRU
    elif args.cache_none:
        cache_type = execution.CacheType.DEPENDENCY_AWARE

    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
    last_gc_collect = 0
    need_gc = False
    gc_collect_interval = 10.0

    while True:
        # Block for a long time when idle; when a GC is pending, wake up no later
        # than when the GC interval elapses.
        timeout = 1000.0
        if need_gc:
            timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)

        # NOTE(review): q.get appears to return None on timeout (see the check below).
        queue_item = q.get(timeout=timeout)
        if queue_item is not None:
            item, item_id = queue_item
            execution_start_time = time.perf_counter()
            # NOTE(review): item layout assumed to be
            # (number, prompt_id, prompt, extra_data, outputs_to_execute) — confirm against PromptQueue.
            prompt_id = item[1]
            server_instance.last_prompt_id = prompt_id

            e.execute(item[2], prompt_id, item[3], item[4])
            need_gc = True
            q.task_done(item_id,
                        e.history_result,
                        status=execution.PromptQueue.ExecutionStatus(
                            status_str='success' if e.success else 'error',
                            completed=e.success,
                            messages=e.status_messages))
            # Tell the client that no node is executing anymore for this prompt.
            if server_instance.client_id is not None:
                server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)

            current_time = time.perf_counter()
            execution_time = current_time - execution_start_time

            # Log Time in a more readable way after 10 minutes
            if execution_time > 600:
                execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time))
                logging.info(f"Prompt executed in {execution_time}")
            else:
                logging.info("Prompt executed in {:.2f} seconds".format(execution_time))

        flags = q.get_flags()
        free_memory = flags.get("free_memory", False)

        # "unload_models" defaults to the value of "free_memory".
        if flags.get("unload_models", free_memory):
            comfy.model_management.unload_all_models()
            need_gc = True
            # Reset so the interval check below is (normally) satisfied immediately.
            last_gc_collect = 0

        if free_memory:
            e.reset()
            need_gc = True
            last_gc_collect = 0

        if need_gc:
            current_time = time.perf_counter()
            if (current_time - last_gc_collect) > gc_collect_interval:
                gc.collect()
                comfy.model_management.soft_empty_cache()
                last_gc_collect = current_time
                need_gc = False
                # NOTE(review): presumably re-applies the functions saved by hook_breaker
                # (undoing patches made while nodes ran) — confirm in hook_breaker_ac10a0.
                hook_breaker_ac10a0.restore_functions()
|
|
|
|
|
|
async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
    """Start listening on every address in ``address`` and run the publish loop concurrently.

    Args:
        server_instance: object providing ``start_multi_address(addresses, call_on_start, verbose)``
            and ``publish_loop()`` coroutines.
        address: a single host or a comma-separated list of hosts. Whitespace
            around each entry is ignored, so ``"127.0.0.1, ::1"`` works.
        port: port used for every address.
        verbose: forwarded to ``start_multi_address``.
        call_on_start: optional callback forwarded to ``start_multi_address``.
    """
    # Strip each entry: "a, b" previously yielded the host " b", which cannot be bound.
    addresses = [(addr.strip(), port) for addr in address.split(",")]
    await asyncio.gather(
        server_instance.start_multi_address(addresses, call_on_start, verbose),
        server_instance.publish_loop(),
    )
|
|
|
|
def hijack_progress(server_instance):
    """Install a global progress-bar hook that forwards progress to ``server_instance``.

    The hook performs these steps:
      1. Fills in a missing ``prompt_id`` / ``node_id`` from the current
         executing context, and failing that from ``server_instance.last_prompt_id``
         / ``last_node_id``.
      2. Calls ``throw_exception_if_processing_interrupted()`` between those two
         fallbacks.
      3. Records the progress in the progress state.
      4. Sends a ``"progress"`` event to the current client.
      5. Sends the preview image as a raw binary event only to clients that do
         not advertise ``supports_preview_metadata``.
    """
    def hook(value, total, preview_image, prompt_id=None, node_id=None):
        context = get_executing_context()
        if context is not None:
            prompt_id = context.prompt_id if prompt_id is None else prompt_id
            node_id = context.node_id if node_id is None else node_id

        comfy.model_management.throw_exception_if_processing_interrupted()

        if prompt_id is None:
            prompt_id = server_instance.last_prompt_id
        if node_id is None:
            node_id = server_instance.last_node_id

        get_progress_state().update_progress(node_id, value, total, preview_image)
        server_instance.send_sync(
            "progress",
            {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id},
            server_instance.client_id,
        )

        if preview_image is None:
            return

        # Only send old method if client doesn't support preview metadata
        client_has_metadata_support = feature_flags.supports_feature(
            server_instance.sockets_metadata,
            server_instance.client_id,
            "supports_preview_metadata",
        )
        if not client_has_metadata_support:
            server_instance.send_sync(
                BinaryEventTypes.UNENCODED_PREVIEW_IMAGE,
                preview_image,
                server_instance.client_id,
            )

    comfy.utils.set_progress_bar_global_hook(hook)
|
|
|
|
|
|
def cleanup_temp():
    """Delete the configured temp directory tree (errors ignored); do nothing if it is absent."""
    temp_root = folder_paths.get_temp_directory()
    if not os.path.exists(temp_root):
        return
    shutil.rmtree(temp_root, ignore_errors=True)
|
|
|
|
|
|
def setup_database():
    """Initialise the app database when its optional dependencies are available.

    Any failure, including a failure to import ``app.database.db``, is logged
    as an error and swallowed, so startup continues without a database.
    """
    try:
        from app.database.db import init_db, dependencies_available
        if not dependencies_available():
            return
        init_db()
    except Exception as e:
        logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
|
|
|
|
|
def start_comfyui(asyncio_loop=None):
    """
    Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
    Returns the event loop, server instance, and a function to start the server asynchronously.

    Before returning, this function also performs the following setup:
      - prepares the temp directory;
      - loads extra and custom nodes;
      - initialises the database;
      - registers HTTP routes and the progress hook;
      - starts the prompt worker on a daemon thread.
    With ``--quick-test-for-ci`` the process exits with status 0 right after that setup.

    Args:
        asyncio_loop: event loop to use; when falsy, a new loop is created and set as current.

    Returns:
        ``(asyncio_loop, prompt_server, start_all)``, where ``start_all()`` returns a
        coroutine that sets up the server and serves until cancelled.
    """
    if args.temp_directory:
        temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
        logging.info(f"Setting temp directory to: {temp_dir}")
        folder_paths.set_temp_directory(temp_dir)
    # Start from an empty temp directory (removes leftovers from earlier runs).
    cleanup_temp()

    if args.windows_standalone_build:
        # Best effort: failing to refresh the bundled updater must not block startup.
        # Catch Exception (not a bare except) so Ctrl+C / SystemExit still propagate.
        try:
            import new_updater
            new_updater.update_windows_updater()
        except Exception:
            logging.debug("Failed to update the Windows standalone updater", exc_info=True)

    if not asyncio_loop:
        asyncio_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(asyncio_loop)
    prompt_server = server.PromptServer(asyncio_loop)

    # NOTE(review): save/restore presumably undoes patches that node packages apply
    # while being imported — confirm in hook_breaker_ac10a0.
    hook_breaker_ac10a0.save_functions()
    nodes.init_extra_nodes(
        init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
        init_api_nodes=not args.disable_api_nodes
    )
    hook_breaker_ac10a0.restore_functions()

    cuda_malloc_warning()
    setup_database()

    prompt_server.add_routes()
    hijack_progress(prompt_server)

    # Daemon thread: it does not keep the interpreter alive on shutdown.
    threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()

    if args.quick_test_for_ci:
        # sys.exit is always available; the builtin exit() only exists when the
        # site module is loaded (e.g. not under `python -S` or some embedded builds).
        sys.exit(0)

    os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
    call_on_start = None
    if args.auto_launch:
        def startup_server(scheme, address, port):
            """Open the web UI in the default browser once the server is listening."""
            import webbrowser
            if os.name == 'nt' and address == '0.0.0.0':
                address = '127.0.0.1'
            # IPv6 literals must be bracketed in URLs.
            if ':' in address:
                address = "[{}]".format(address)
            webbrowser.open(f"{scheme}://{address}:{port}")
        call_on_start = startup_server

    async def start_all():
        await prompt_server.setup()
        await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)

    # Returning these so that other code can integrate with the ComfyUI loop and server
    return asyncio_loop, prompt_server, start_all
|
|
|
|
|
|
if __name__ == "__main__":
    # Running directly, just start ComfyUI.
    logging.info("Python version: {}".format(sys.version))
    logging.info("ComfyUI version: {}".format(comfyui_version.__version__))

    if sys.version_info < (3, 10):
        logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")

    event_loop, _, start_all_func = start_comfyui()
    try:
        x = start_all_func()
        app.logger.print_startup_warnings()
        event_loop.run_until_complete(x)
    except KeyboardInterrupt:
        logging.info("\nStopped server")
    finally:
        # Clean the temp directory on every exit path. Previously an unexpected
        # exception skipped this cleanup.
        cleanup_temp()
|