mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Enable External Event Loop Integration for ComfyUI [refactor] (#6114)
* Refactor main.py to support external event loop integration * added optional "asyncio_loop" argument to allow using existing event loop --------- Signed-off-by: bigcat88 <bigcat88@icloud.com>
This commit is contained in:
parent
bc6dac4327
commit
26e0ba8f8c
60
main.py
60
main.py
@ -150,9 +150,10 @@ def cuda_malloc_warning():
|
|||||||
if cuda_malloc_warning:
|
if cuda_malloc_warning:
|
||||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||||
|
|
||||||
def prompt_worker(q, server):
|
|
||||||
|
def prompt_worker(q, server_instance):
|
||||||
current_time: float = 0.0
|
current_time: float = 0.0
|
||||||
e = execution.PromptExecutor(server, lru_size=args.cache_lru)
|
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
|
||||||
last_gc_collect = 0
|
last_gc_collect = 0
|
||||||
need_gc = False
|
need_gc = False
|
||||||
gc_collect_interval = 10.0
|
gc_collect_interval = 10.0
|
||||||
@ -167,7 +168,7 @@ def prompt_worker(q, server):
|
|||||||
item, item_id = queue_item
|
item, item_id = queue_item
|
||||||
execution_start_time = time.perf_counter()
|
execution_start_time = time.perf_counter()
|
||||||
prompt_id = item[1]
|
prompt_id = item[1]
|
||||||
server.last_prompt_id = prompt_id
|
server_instance.last_prompt_id = prompt_id
|
||||||
|
|
||||||
e.execute(item[2], prompt_id, item[3], item[4])
|
e.execute(item[2], prompt_id, item[3], item[4])
|
||||||
need_gc = True
|
need_gc = True
|
||||||
@ -177,8 +178,8 @@ def prompt_worker(q, server):
|
|||||||
status_str='success' if e.success else 'error',
|
status_str='success' if e.success else 'error',
|
||||||
completed=e.success,
|
completed=e.success,
|
||||||
messages=e.status_messages))
|
messages=e.status_messages))
|
||||||
if server.client_id is not None:
|
if server_instance.client_id is not None:
|
||||||
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
|
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||||
|
|
||||||
current_time = time.perf_counter()
|
current_time = time.perf_counter()
|
||||||
execution_time = current_time - execution_start_time
|
execution_time = current_time - execution_start_time
|
||||||
@ -205,21 +206,23 @@ def prompt_worker(q, server):
|
|||||||
last_gc_collect = current_time
|
last_gc_collect = current_time
|
||||||
need_gc = False
|
need_gc = False
|
||||||
|
|
||||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
|
||||||
|
async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
|
||||||
addresses = []
|
addresses = []
|
||||||
for addr in address.split(","):
|
for addr in address.split(","):
|
||||||
addresses.append((addr, port))
|
addresses.append((addr, port))
|
||||||
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
|
await asyncio.gather(server_instance.start_multi_address(addresses, call_on_start), server_instance.publish_loop())
|
||||||
|
|
||||||
|
|
||||||
def hijack_progress(server):
|
def hijack_progress(server_instance):
|
||||||
def hook(value, total, preview_image):
|
def hook(value, total, preview_image):
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
|
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
|
||||||
|
|
||||||
server.send_sync("progress", progress, server.client_id)
|
server_instance.send_sync("progress", progress, server_instance.client_id)
|
||||||
if preview_image is not None:
|
if preview_image is not None:
|
||||||
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
|
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
|
||||||
|
|
||||||
comfy.utils.set_progress_bar_global_hook(hook)
|
comfy.utils.set_progress_bar_global_hook(hook)
|
||||||
|
|
||||||
|
|
||||||
@ -229,7 +232,11 @@ def cleanup_temp():
|
|||||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def start_comfyui(asyncio_loop=None):
|
||||||
|
"""
|
||||||
|
Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
|
||||||
|
Returns the event loop, server instance, and a function to start the server asynchronously.
|
||||||
|
"""
|
||||||
if args.temp_directory:
|
if args.temp_directory:
|
||||||
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
||||||
logging.info(f"Setting temp directory to: {temp_dir}")
|
logging.info(f"Setting temp directory to: {temp_dir}")
|
||||||
@ -243,19 +250,20 @@ if __name__ == "__main__":
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
if not asyncio_loop:
|
||||||
asyncio.set_event_loop(loop)
|
asyncio_loop = asyncio.new_event_loop()
|
||||||
server = server.PromptServer(loop)
|
asyncio.set_event_loop(asyncio_loop)
|
||||||
q = execution.PromptQueue(server)
|
prompt_server = server.PromptServer(asyncio_loop)
|
||||||
|
q = execution.PromptQueue(prompt_server)
|
||||||
|
|
||||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
||||||
|
|
||||||
cuda_malloc_warning()
|
cuda_malloc_warning()
|
||||||
|
|
||||||
server.add_routes()
|
prompt_server.add_routes()
|
||||||
hijack_progress(server)
|
hijack_progress(prompt_server)
|
||||||
|
|
||||||
threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
|
threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
|
||||||
|
|
||||||
if args.quick_test_for_ci:
|
if args.quick_test_for_ci:
|
||||||
exit(0)
|
exit(0)
|
||||||
@ -272,9 +280,19 @@ if __name__ == "__main__":
|
|||||||
webbrowser.open(f"{scheme}://{address}:{port}")
|
webbrowser.open(f"{scheme}://{address}:{port}")
|
||||||
call_on_start = startup_server
|
call_on_start = startup_server
|
||||||
|
|
||||||
|
async def start_all():
|
||||||
|
await prompt_server.setup()
|
||||||
|
await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
|
||||||
|
|
||||||
|
# Returning these so that other code can integrate with the ComfyUI loop and server
|
||||||
|
return asyncio_loop, prompt_server, start_all
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Running directly, just start ComfyUI.
|
||||||
|
event_loop, _, start_all_func = start_comfyui()
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(server.setup())
|
event_loop.run_until_complete(start_all_func())
|
||||||
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logging.info("\nStopped server")
|
logging.info("\nStopped server")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user