From 742ca174e2ad6b86878dbb464ef35d3023727faa Mon Sep 17 00:00:00 2001 From: drunkplato <6413077+drunkplato@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:31:53 +0000 Subject: [PATCH] split the queue to consume with 2 channels --- execution.py | 2 +- main.py | 4 +- memedeck-v1.py | 266 ----------------------------------------------- memedeck.py | 154 +++++++++++++++++++-------- requirements.txt | 1 + 5 files changed, 116 insertions(+), 311 deletions(-) delete mode 100644 memedeck-v1.py diff --git a/execution.py b/execution.py index ee99345e..82a66bde 100644 --- a/execution.py +++ b/execution.py @@ -53,7 +53,7 @@ class IsChangedCache: is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED") node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] except Exception as e: - logging.warning("WARNING: {}".format(e)) + # logging.warning("WARNING: {}".format(e)) node["is_changed"] = float("NaN") finally: self.is_changed[node_id] = node["is_changed"] diff --git a/main.py b/main.py index 3a398564..24f3d7a4 100644 --- a/main.py +++ b/main.py @@ -58,8 +58,8 @@ def apply_custom_paths(): # --------------------------------------------------------------------------------------- from memedeck import MemedeckWorker -# import sys -# sys.stdout = open(os.devnull, 'w') # disable all print statements +import sys +sys.stdout = open(os.devnull, 'w') # disable all print statements # --------------------------------------------------------------------------------------- def execute_prestartup_script(): diff --git a/memedeck-v1.py b/memedeck-v1.py deleted file mode 100644 index abfe7592..00000000 --- a/memedeck-v1.py +++ /dev/null @@ -1,266 +0,0 @@ -import asyncio -import base64 -from io import BytesIO -import os -import logging -import signal -import struct -from typing import Optional -import uuid -from PIL import Image, ImageOps -from functools import partial - -import pika -import json - -import requests -import aiohttp - -# load from env file -# load from .env file -from dotenv import load_dotenv -load_dotenv() - -amqp_addr = os.getenv('AMQP_ADDR') or 'amqp://api:gacdownatravKekmy9@51.8.120.154:5672/dev' - -# define the enum in python -from enum import Enum - -class QueueProgressKind(Enum): - # make json serializable - ImageGenerated = "image_generated" - ImageGenerating = "image_generating" - SamePrompt = "same_prompt" - FaceswapGenerated = "faceswap_generated" - FaceswapGenerating = "faceswap_generating" - Failed = "failed" - -class MemedeckWorker: - class BinaryEventTypes: - PREVIEW_IMAGE = 1 - UNENCODED_PREVIEW_IMAGE = 2 - - class JsonEventTypes(Enum): - PROGRESS = "progress" - EXECUTING = "executing" - EXECUTED = "executed" - ERROR = "error" - STATUS = "status" - - """ - MemedeckWorker is a class that is responsible for relaying messages between comfy and the memedeck backend api - it is used to send images to the memedeck backend api and to receive prompts from the memedeck backend api - """ - def __init__(self, loop): - MemedeckWorker.instance = self - # set logging level to info - logging.getLogger().setLevel(logging.INFO) - self.active_tasks_map = {} - self.current_task = None - - self.client_id = None - self.ws_id = None - self.websocket_node_id = None - self.current_node = None - self.current_progress = 0 - self.current_context = None - - self.loop = loop - self.messages = asyncio.Queue() - - self.http_client = None - self.prompt_queue = None - self.validate_prompt = None - self.last_prompt_id = None - - self.amqp_url = amqp_addr - self.queue_name = os.getenv('QUEUE_NAME') or 'generic-queue' - self.api_url = os.getenv('API_ADDRESS') or 'http://0.0.0.0:8079/v2' - self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383' - - # Configure logging - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - self.logger = logging.getLogger(__name__) - self.logger.info(f"\n[memedeck]: initialized with API URL: {self.api_url} and API Key: {self.api_key}\n") - - def on_connection_open(self, connection): - self.connection = connection - self.connection.channel(on_open_callback=self.on_channel_open) - - def on_channel_open(self, channel): - self.channel = channel - - # only consume one message at a time - self.channel.basic_qos(prefetch_size=0, prefetch_count=1) - self.channel.queue_declare(queue=self.queue_name, durable=True) - self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received) - - def start(self, prompt_queue, validate_prompt): - self.prompt_queue = prompt_queue - self.validate_prompt = validate_prompt - - parameters = pika.URLParameters(self.amqp_url) - logging.getLogger('pika').setLevel(logging.WARNING) # supress all logs from pika - self.connection = pika.SelectConnection(parameters, on_open_callback=self.on_connection_open) - - try: - self.connection.ioloop.start() - except KeyboardInterrupt: - self.connection.close() - self.connection.ioloop.start() - - def on_message_received(self, channel, method, properties, body): - decoded_string = body.decode('utf-8') - json_object = json.loads(decoded_string) - payload = json_object[1] - - # execute the task - prompt = payload["nodes"] - valid = self.validate_prompt(prompt) - - self.current_node = None - self.current_progress = 0 - self.websocket_node_id = None - self.ws_id = payload["source_ws_id"] - self.current_context = payload["req_ctx"] - - for node in prompt: # search through prompt nodes for websocket_node_id - if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket": - self.websocket_node_id = node - break - - if valid[0]: - prompt_id = str(uuid.uuid4()) - outputs_to_execute = valid[2] - self.active_tasks_map[payload["source_ws_id"]] = { - "prompt_id": prompt_id, - "prompt": prompt, - "outputs_to_execute": outputs_to_execute, - "client_id": "memedeck-1", - "is_memedeck": True, - "websocket_node_id": self.websocket_node_id, - "ws_id": payload["source_ws_id"], - "context": payload["req_ctx"], - "current_node": None, - "current_progress": 0, - } - self.prompt_queue.put((0, prompt_id, prompt, { - "client_id": "memedeck-1", - 'is_memedeck': True, - 'websocket_node_id': self.websocket_node_id, - 'ws_id': payload["source_ws_id"], - 'context': payload["req_ctx"] - }, outputs_to_execute)) - self.set_last_prompt_id(prompt_id) - channel.basic_ack(delivery_tag=method.delivery_tag) # ack the task - else: - channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message - - # -------------------------------------------------- - # callbacks for the prompt queue - # -------------------------------------------------- - def queue_updated(self): - # print json of the queue info but only print the first 100 lines - info = self.get_queue_info() - # update_type = info[''] - # self.send_sync("status", { "status": self.get_queue_info() }) - - def get_queue_info(self): - prompt_info = {} - exec_info = {} - exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining() - prompt_info['exec_info'] = exec_info - return prompt_info - - def send_sync(self, event, data, sid=None): - - self.loop.call_soon_threadsafe( - self.messages.put_nowait, (event, data, sid)) - - def set_last_prompt_id(self, prompt_id): - self.last_prompt_id = prompt_id - - async def publish_loop(self): - while True: - msg = await self.messages.get() - await self.send(*msg) - - async def send(self, event, data, sid=None): - current_task = self.active_tasks_map.get(sid) - if current_task is None or current_task['ws_id'] != sid: - return - - if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: # preview and unencoded images are sent here - self.logger.info(f"[memedeck]: sending image preview for {sid}") - await self.send_preview(data, sid=current_task['ws_id'], progress=current_task['current_progress'], context=current_task['context']) - else: # send json data / text data - if event == "executing": - current_task['current_node'] = data['node'] - elif event == "executed": - self.logger.info(f"---> [memedeck]: executed event for {sid}") - prompt_id = data['prompt_id'] - if prompt_id in self.active_tasks_map: - del self.active_tasks_map[prompt_id] - elif event == "progress": - if current_task['current_node'] == current_task['websocket_node_id']: # if the node is the websocket node, then set the progress to 100 - current_task['current_progress'] = 100 - else: # if the node is not the websocket node, then set the progress to the progress from the node - current_task['current_progress'] = data['value'] / data['max'] * 100 - if current_task['current_progress'] == 100 and current_task['current_node'] != current_task['websocket_node_id']: - # in case the progress is 100 but the node is not the websocket node, then set the progress to 95 - current_task['current_progress'] = 95 # this allows the full resolution image to be sent on the 100 progress event - - if data['value'] == 1: # if the value is 1, then send started to api - start_data = { - "ws_id": current_task['ws_id'], - "status": "started", - "info": None, - } - await self.send_to_api(start_data) - - elif event == "status": - self.logger.info(f"[memedeck]: sending status event: {data}") - - self.active_tasks_map[sid] = current_task - - - async def send_preview(self, image_data, sid=None, progress=None, context=None): - # if self.current_progress is odd, then don't send the preview - if progress % 2 == 1: - return - - image_type = image_data[0] - image = image_data[1] - max_size = image_data[2] - if max_size is not None: - if hasattr(Image, 'Resampling'): - resampling = Image.Resampling.BILINEAR - else: - resampling = Image.ANTIALIAS - - image = ImageOps.contain(image, (max_size, max_size), resampling) - - bytesIO = BytesIO() - image.save(bytesIO, format=image_type, quality=100 if progress == 96 else 75, compress_level=1) - preview_bytes = bytesIO.getvalue() - - ai_queue_progress = { - "ws_id": sid, - "kind": "image_generating" if progress < 100 else "image_generated", - "data": list(preview_bytes), - "progress": int(progress), - "context": context - } - - await self.send_to_api(ai_queue_progress) - - async def send_to_api(self, data): - if self.websocket_node_id is None: # check if the node is still running - logging.error(f"[memedeck]: websocket_node_id is None for {data['ws_id']}") - return - - try: - post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data) - await self.loop.run_in_executor(None, post_func) - except Exception as e: - self.logger.error(f"[memedeck]: error sending to api: {e}") diff --git a/memedeck.py b/memedeck.py index d4c8e3c8..1c2cb31f 100644 --- a/memedeck.py +++ b/memedeck.py @@ -1,12 +1,7 @@ import asyncio -import base64 from io import BytesIO import os import logging -import signal -import struct -import time -from typing import Optional import uuid from PIL import Image, ImageOps from functools import partial @@ -15,7 +10,6 @@ import pika import json import requests -import aiohttp from dotenv import load_dotenv load_dotenv() @@ -77,21 +71,6 @@ class MemedeckWorker: self.logger = logging.getLogger(__name__) self.logger.info(f"\n[memedeck]: initialized with API URL: {self.api_url} and API Key: {self.api_key}\n") - def on_connection_open(self, connection): - self.connection = connection - self.connection.channel(on_open_callback=self.on_channel_open) - - def on_channel_open(self, channel): - self.channel = channel - - # only consume one message at a time - self.channel.basic_qos(prefetch_size=0, prefetch_count=1) - self.channel.queue_declare(queue=self.queue_name, durable=True) - self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received, auto_ack=False) - # declare another queue - self.channel.queue_declare(queue='faceswap-queue', durable=True) - self.channel.basic_consume(queue='faceswap-queue', on_message_callback=self.on_message_received, auto_ack=False) - def start(self, prompt_queue, validate_prompt): self.prompt_queue = prompt_queue self.validate_prompt = validate_prompt @@ -106,24 +85,56 @@ class MemedeckWorker: try: self.connection.ioloop.start() except KeyboardInterrupt: - self.connection.close() - self.connection.ioloop.start() + self.stop() + + def stop(self): + self.connection.close() + self.connection.ioloop.stop() + + # -------------------------------------------------- + # AMQP + # -------------------------------------------------- + def on_connection_open(self, connection): + self.connection = connection + # Open the first channel + self.connection.channel(on_open_callback=self.on_channel_open) + # Open the second channel + self.connection.channel(on_open_callback=self.on_faceswap_channel_open) + + def on_channel_open(self, channel): + self.channel = channel + # Only consume one message at a time + self.channel.basic_qos(prefetch_size=0, prefetch_count=1) + # Declare the queue and set the callback + self.channel.queue_declare(queue=self.queue_name, durable=True, callback=self.on_queue_declared) + + def on_queue_declared(self, frame): + self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received) + + def on_faceswap_channel_open(self, channel): + self.faceswap_channel = channel + self.faceswap_channel.basic_qos(prefetch_size=0, prefetch_count=1) + # Declare the faceswap queue and set the callback + self.faceswap_channel.queue_declare(queue='faceswap-queue', durable=True, callback=self.on_faceswap_queue_declared) + + def on_faceswap_queue_declared(self, frame): + self.faceswap_channel.basic_consume(queue='faceswap-queue', on_message_callback=self.on_faceswap_message_received) + + def on_faceswap_message_received(self, channel, method, properties, body): + self.on_message_received(channel, method, properties, body) - def on_message_received(self, channel, method, properties, body): + def on_message_received(self, channel, method, properties, body): decoded_string = body.decode('utf-8') - json_object = json.loads(decoded_string) - payload = json_object[1] - - # execute the task + payload = json.loads(decoded_string) + + # Execute the task prompt = payload["nodes"] - valid = self.validate_prompt(prompt) - + valid = self.validate_prompt(prompt) + # Prepare task_info prompt_id = str(uuid.uuid4()) outputs_to_execute = valid[2] - # get the routing key from the method routing_key = method.routing_key - self.logger.info(f"[memedeck]: routing_key: {routing_key}") task_info = { "workflow": 'faceswap' if routing_key == 'faceswap-queue' else 'generation', "prompt_id": prompt_id, @@ -149,10 +160,66 @@ class MemedeckWorker: if valid[0]: # Enqueue the task into the internal job queue self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info)) - # self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}") else: - channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message + channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # Unack the message + # def on_channel_open(self, channel): + # self.channel = channel + + # # only consume one message at a time + # self.channel.basic_qos(prefetch_size=0, prefetch_count=1) + # self.channel.queue_declare(queue=self.queue_name, durable=True) + # self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received) + # # declare another queue + # self.channel.queue_declare(queue='faceswap-queue', durable=True) + # self.channel.basic_consume(queue='faceswap-queue', on_message_callback=self.on_message_received) + + # def on_message_received(self, channel, method, properties, body): + # decoded_string = body.decode('utf-8') + # json_object = json.loads(decoded_string) + # payload = json_object[1] + + # # execute the task + # prompt = payload["nodes"] + # valid = self.validate_prompt(prompt) + + # # Prepare task_info + # prompt_id = str(uuid.uuid4()) + # outputs_to_execute = valid[2] + # # get the routing key from the method + # routing_key = method.routing_key + # task_info = { + # "workflow": 'faceswap' if routing_key == 'faceswap-queue' else 'generation', + # "prompt_id": prompt_id, + # "prompt": prompt, + # "outputs_to_execute": outputs_to_execute, + # "client_id": "memedeck-1", + # "is_memedeck": True, + # "websocket_node_id": None, + # "ws_id": payload["source_ws_id"], + # "context": payload["req_ctx"], + # "current_node": None, + # "current_progress": 0, + # "delivery_tag": method.delivery_tag, + # "task_status": "waiting", + # } + + # # Find the websocket_node_id + # for node in prompt: + # if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket": + # task_info['websocket_node_id'] = node + # break + + # if valid[0]: + # # Enqueue the task into the internal job queue + # self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info)) + # # self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}") + # else: + # channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message + + # -------------------------------------------------- + # Internal job queue + # -------------------------------------------------- async def process_job_queue(self): while True: prompt_id, prompt, task_info = await self.internal_job_queue.get() @@ -172,13 +239,9 @@ class MemedeckWorker: 'context': task_info['context'] }, task_info['outputs_to_execute'])) - # Acknowledge the message - self.channel.basic_ack(delivery_tag=task_info["delivery_tag"]) # ack the task - if 'faceswap_strength' not in task_info['context']['prompt_config']: # pretty print the prompt config self.logger.info(f"[memedeck]: prompt: {task_info['context']['prompt_config']['character']['id']} {task_info['context']['prompt_config']['positive_prompt']}") - # Wait until the current task is completed await self.wait_for_task_completion(ws_id) # Task is done @@ -188,11 +251,19 @@ class MemedeckWorker: """ Wait until the task with the given ws_id is completed. """ + task = self.tasks_by_ws_id[ws_id] + delivery_tag = task['delivery_tag'] while ws_id in self.tasks_by_ws_id: - await asyncio.sleep(0.5) + await asyncio.sleep(0.25) + + # Acknowledge the message when the task is completed + if task['workflow'] == 'faceswap': + self.faceswap_channel.basic_ack(delivery_tag=delivery_tag) + else: + self.channel.basic_ack(delivery_tag=delivery_tag) # -------------------------------------------------- - # callbacks for the prompt queue + # allbacks for the prompt queue # -------------------------------------------------- def queue_updated(self): info = self.get_queue_info() @@ -320,8 +391,6 @@ class MemedeckWorker: if workflow == 'faceswap': ai_queue_progress['kind'] = "faceswap_generated" del ai_queue_progress['progress'] - # dont print the data field without deleting it - self.logger.info(f"[memedeck]: sending faceswap result") await self.send_to_api(ai_queue_progress) @@ -597,3 +666,4 @@ class MemedeckWorker: # average_brightness = total_brightness // pixel_count # return average_brightness + diff --git a/requirements.txt b/requirements.txt index 23ac0295..26ab9048 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,3 +29,4 @@ python-dotenv pillow azure-storage-blob cairosvg +aio_pika