diff --git a/main.py b/main.py index 151b264c..ccc99fdc 100644 --- a/main.py +++ b/main.py @@ -150,9 +150,10 @@ def cuda_malloc_warning(): if cuda_malloc_warning: logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") -def prompt_worker(q, server): + +def prompt_worker(q, server_instance): current_time: float = 0.0 - e = execution.PromptExecutor(server, lru_size=args.cache_lru) + e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 @@ -167,7 +168,7 @@ def prompt_worker(q, server): item, item_id = queue_item execution_start_time = time.perf_counter() prompt_id = item[1] - server.last_prompt_id = prompt_id + server_instance.last_prompt_id = prompt_id e.execute(item[2], prompt_id, item[3], item[4]) need_gc = True @@ -177,8 +178,8 @@ def prompt_worker(q, server): status_str='success' if e.success else 'error', completed=e.success, messages=e.status_messages)) - if server.client_id is not None: - server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) + if server_instance.client_id is not None: + server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) current_time = time.perf_counter() execution_time = current_time - execution_start_time @@ -205,21 +206,23 @@ def prompt_worker(q, server): last_gc_collect = current_time need_gc = False -async def run(server, address='', port=8188, verbose=True, call_on_start=None): + +async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None): addresses = [] for addr in address.split(","): addresses.append((addr, port)) - await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop()) + await asyncio.gather(server_instance.start_multi_address(addresses, call_on_start), server_instance.publish_loop()) -def hijack_progress(server): +def hijack_progress(server_instance): def hook(value, total, preview_image): comfy.model_management.throw_exception_if_processing_interrupted() - progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} + progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id} - server.send_sync("progress", progress, server.client_id) + server_instance.send_sync("progress", progress, server_instance.client_id) if preview_image is not None: - server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) + server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id) + comfy.utils.set_progress_bar_global_hook(hook) @@ -229,7 +232,11 @@ def cleanup_temp(): shutil.rmtree(temp_dir, ignore_errors=True) -if __name__ == "__main__": +def start_comfyui(asyncio_loop=None): + """ + Starts the ComfyUI server using the provided asyncio event loop or creates a new one. + Returns the event loop, server instance, and a function to start the server asynchronously. + """ if args.temp_directory: temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") logging.info(f"Setting temp directory to: {temp_dir}") @@ -243,19 +250,20 @@ if __name__ == "__main__": except: pass - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - server = server.PromptServer(loop) - q = execution.PromptQueue(server) + if not asyncio_loop: + asyncio_loop = asyncio.new_event_loop() + asyncio.set_event_loop(asyncio_loop) + prompt_server = server.PromptServer(asyncio_loop) + q = execution.PromptQueue(prompt_server) nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes) cuda_malloc_warning() - server.add_routes() - hijack_progress(server) + prompt_server.add_routes() + hijack_progress(prompt_server) - threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start() + threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start() if args.quick_test_for_ci: exit(0) @@ -272,9 +280,19 @@ if __name__ == "__main__": webbrowser.open(f"{scheme}://{address}:{port}") call_on_start = startup_server + async def start_all(): + await prompt_server.setup() + await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start) + + # Returning these so that other code can integrate with the ComfyUI loop and server + return asyncio_loop, prompt_server, start_all + + +if __name__ == "__main__": + # Running directly, just start ComfyUI. + event_loop, _, start_all_func = start_comfyui() try: - loop.run_until_complete(server.setup()) - loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) + event_loop.run_until_complete(start_all_func()) except KeyboardInterrupt: logging.info("\nStopped server")