Enable External Event Loop Integration for ComfyUI [refactor] (#6114)

* Refactor main.py to support external event loop integration

* added optional "asyncio_loop" argument to allow using existing event loop

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
This commit is contained in:
Alexander Piskun 2024-12-24 14:38:52 +03:00 committed by GitHub
parent bc6dac4327
commit 26e0ba8f8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

60
main.py
View File

@ -150,9 +150,10 @@ def cuda_malloc_warning():
if cuda_malloc_warning: if cuda_malloc_warning:
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server):
def prompt_worker(q, server_instance):
current_time: float = 0.0 current_time: float = 0.0
e = execution.PromptExecutor(server, lru_size=args.cache_lru) e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
last_gc_collect = 0 last_gc_collect = 0
need_gc = False need_gc = False
gc_collect_interval = 10.0 gc_collect_interval = 10.0
@ -167,7 +168,7 @@ def prompt_worker(q, server):
item, item_id = queue_item item, item_id = queue_item
execution_start_time = time.perf_counter() execution_start_time = time.perf_counter()
prompt_id = item[1] prompt_id = item[1]
server.last_prompt_id = prompt_id server_instance.last_prompt_id = prompt_id
e.execute(item[2], prompt_id, item[3], item[4]) e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True need_gc = True
@ -177,8 +178,8 @@ def prompt_worker(q, server):
status_str='success' if e.success else 'error', status_str='success' if e.success else 'error',
completed=e.success, completed=e.success,
messages=e.status_messages)) messages=e.status_messages))
if server.client_id is not None: if server_instance.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
current_time = time.perf_counter() current_time = time.perf_counter()
execution_time = current_time - execution_start_time execution_time = current_time - execution_start_time
@ -205,21 +206,23 @@ def prompt_worker(q, server):
last_gc_collect = current_time last_gc_collect = current_time
need_gc = False need_gc = False
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
addresses = [] addresses = []
for addr in address.split(","): for addr in address.split(","):
addresses.append((addr, port)) addresses.append((addr, port))
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop()) await asyncio.gather(server_instance.start_multi_address(addresses, call_on_start), server_instance.publish_loop())
def hijack_progress(server): def hijack_progress(server_instance):
def hook(value, total, preview_image): def hook(value, total, preview_image):
comfy.model_management.throw_exception_if_processing_interrupted() comfy.model_management.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
server.send_sync("progress", progress, server.client_id) server_instance.send_sync("progress", progress, server_instance.client_id)
if preview_image is not None: if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
comfy.utils.set_progress_bar_global_hook(hook) comfy.utils.set_progress_bar_global_hook(hook)
@ -229,7 +232,11 @@ def cleanup_temp():
shutil.rmtree(temp_dir, ignore_errors=True) shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == "__main__": def start_comfyui(asyncio_loop=None):
"""
Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
Returns the event loop, server instance, and a function to start the server asynchronously.
"""
if args.temp_directory: if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
logging.info(f"Setting temp directory to: {temp_dir}") logging.info(f"Setting temp directory to: {temp_dir}")
@ -243,19 +250,20 @@ if __name__ == "__main__":
except: except:
pass pass
loop = asyncio.new_event_loop() if not asyncio_loop:
asyncio.set_event_loop(loop) asyncio_loop = asyncio.new_event_loop()
server = server.PromptServer(loop) asyncio.set_event_loop(asyncio_loop)
q = execution.PromptQueue(server) prompt_server = server.PromptServer(asyncio_loop)
q = execution.PromptQueue(prompt_server)
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes) nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
cuda_malloc_warning() cuda_malloc_warning()
server.add_routes() prompt_server.add_routes()
hijack_progress(server) hijack_progress(prompt_server)
threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start() threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
if args.quick_test_for_ci: if args.quick_test_for_ci:
exit(0) exit(0)
@ -272,9 +280,19 @@ if __name__ == "__main__":
webbrowser.open(f"{scheme}://{address}:{port}") webbrowser.open(f"{scheme}://{address}:{port}")
call_on_start = startup_server call_on_start = startup_server
async def start_all():
await prompt_server.setup()
await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
# Returning these so that other code can integrate with the ComfyUI loop and server
return asyncio_loop, prompt_server, start_all
if __name__ == "__main__":
# Running directly, just start ComfyUI.
event_loop, _, start_all_func = start_comfyui()
try: try:
loop.run_until_complete(server.setup()) event_loop.run_until_complete(start_all_func())
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt: except KeyboardInterrupt:
logging.info("\nStopped server") logging.info("\nStopped server")