split the queue to consume with 2 channels

This commit is contained in:
drunkplato 2024-10-18 11:31:53 +00:00 committed by Ubuntu
parent 439a914562
commit 742ca174e2
5 changed files with 116 additions and 311 deletions

View File

@ -53,7 +53,7 @@ class IsChangedCache:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e:
logging.warning("WARNING: {}".format(e))
# logging.warning("WARNING: {}".format(e))
node["is_changed"] = float("NaN")
finally:
self.is_changed[node_id] = node["is_changed"]

View File

@ -58,8 +58,8 @@ def apply_custom_paths():
# ---------------------------------------------------------------------------------------
from memedeck import MemedeckWorker
# import sys
# sys.stdout = open(os.devnull, 'w') # disable all print statements
import sys
sys.stdout = open(os.devnull, 'w') # disable all print statements
# ---------------------------------------------------------------------------------------
def execute_prestartup_script():

View File

@ -1,266 +0,0 @@
import asyncio
import base64
from io import BytesIO
import os
import logging
import signal
import struct
from typing import Optional
import uuid
from PIL import Image, ImageOps
from functools import partial
import pika
import json
import requests
import aiohttp
# load from env file
# load from .env file
from dotenv import load_dotenv
load_dotenv()
amqp_addr = os.getenv('AMQP_ADDR') or 'amqp://api:gacdownatravKekmy9@51.8.120.154:5672/dev'
# define the enum in python
from enum import Enum
class QueueProgressKind(Enum):
# make json serializable
ImageGenerated = "image_generated"
ImageGenerating = "image_generating"
SamePrompt = "same_prompt"
FaceswapGenerated = "faceswap_generated"
FaceswapGenerating = "faceswap_generating"
Failed = "failed"
class MemedeckWorker:
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
class JsonEventTypes(Enum):
PROGRESS = "progress"
EXECUTING = "executing"
EXECUTED = "executed"
ERROR = "error"
STATUS = "status"
"""
MemedeckWorker is a class that is responsible for relaying messages between comfy and the memedeck backend api
it is used to send images to the memedeck backend api and to receive prompts from the memedeck backend api
"""
def __init__(self, loop):
MemedeckWorker.instance = self
# set logging level to info
logging.getLogger().setLevel(logging.INFO)
self.active_tasks_map = {}
self.current_task = None
self.client_id = None
self.ws_id = None
self.websocket_node_id = None
self.current_node = None
self.current_progress = 0
self.current_context = None
self.loop = loop
self.messages = asyncio.Queue()
self.http_client = None
self.prompt_queue = None
self.validate_prompt = None
self.last_prompt_id = None
self.amqp_url = amqp_addr
self.queue_name = os.getenv('QUEUE_NAME') or 'generic-queue'
self.api_url = os.getenv('API_ADDRESS') or 'http://0.0.0.0:8079/v2'
self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383'
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger(__name__)
self.logger.info(f"\n[memedeck]: initialized with API URL: {self.api_url} and API Key: {self.api_key}\n")
def on_connection_open(self, connection):
self.connection = connection
self.connection.channel(on_open_callback=self.on_channel_open)
def on_channel_open(self, channel):
self.channel = channel
# only consume one message at a time
self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
self.channel.queue_declare(queue=self.queue_name, durable=True)
self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received)
def start(self, prompt_queue, validate_prompt):
self.prompt_queue = prompt_queue
self.validate_prompt = validate_prompt
parameters = pika.URLParameters(self.amqp_url)
logging.getLogger('pika').setLevel(logging.WARNING) # supress all logs from pika
self.connection = pika.SelectConnection(parameters, on_open_callback=self.on_connection_open)
try:
self.connection.ioloop.start()
except KeyboardInterrupt:
self.connection.close()
self.connection.ioloop.start()
def on_message_received(self, channel, method, properties, body):
decoded_string = body.decode('utf-8')
json_object = json.loads(decoded_string)
payload = json_object[1]
# execute the task
prompt = payload["nodes"]
valid = self.validate_prompt(prompt)
self.current_node = None
self.current_progress = 0
self.websocket_node_id = None
self.ws_id = payload["source_ws_id"]
self.current_context = payload["req_ctx"]
for node in prompt: # search through prompt nodes for websocket_node_id
if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
self.websocket_node_id = node
break
if valid[0]:
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
self.active_tasks_map[payload["source_ws_id"]] = {
"prompt_id": prompt_id,
"prompt": prompt,
"outputs_to_execute": outputs_to_execute,
"client_id": "memedeck-1",
"is_memedeck": True,
"websocket_node_id": self.websocket_node_id,
"ws_id": payload["source_ws_id"],
"context": payload["req_ctx"],
"current_node": None,
"current_progress": 0,
}
self.prompt_queue.put((0, prompt_id, prompt, {
"client_id": "memedeck-1",
'is_memedeck': True,
'websocket_node_id': self.websocket_node_id,
'ws_id': payload["source_ws_id"],
'context': payload["req_ctx"]
}, outputs_to_execute))
self.set_last_prompt_id(prompt_id)
channel.basic_ack(delivery_tag=method.delivery_tag) # ack the task
else:
channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message
# --------------------------------------------------
# callbacks for the prompt queue
# --------------------------------------------------
def queue_updated(self):
# print json of the queue info but only print the first 100 lines
info = self.get_queue_info()
# update_type = info['']
# self.send_sync("status", { "status": self.get_queue_info() })
def get_queue_info(self):
prompt_info = {}
exec_info = {}
exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info
return prompt_info
def send_sync(self, event, data, sid=None):
self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid))
def set_last_prompt_id(self, prompt_id):
self.last_prompt_id = prompt_id
async def publish_loop(self):
while True:
msg = await self.messages.get()
await self.send(*msg)
async def send(self, event, data, sid=None):
current_task = self.active_tasks_map.get(sid)
if current_task is None or current_task['ws_id'] != sid:
return
if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: # preview and unencoded images are sent here
self.logger.info(f"[memedeck]: sending image preview for {sid}")
await self.send_preview(data, sid=current_task['ws_id'], progress=current_task['current_progress'], context=current_task['context'])
else: # send json data / text data
if event == "executing":
current_task['current_node'] = data['node']
elif event == "executed":
self.logger.info(f"---> [memedeck]: executed event for {sid}")
prompt_id = data['prompt_id']
if prompt_id in self.active_tasks_map:
del self.active_tasks_map[prompt_id]
elif event == "progress":
if current_task['current_node'] == current_task['websocket_node_id']: # if the node is the websocket node, then set the progress to 100
current_task['current_progress'] = 100
else: # if the node is not the websocket node, then set the progress to the progress from the node
current_task['current_progress'] = data['value'] / data['max'] * 100
if current_task['current_progress'] == 100 and current_task['current_node'] != current_task['websocket_node_id']:
# in case the progress is 100 but the node is not the websocket node, then set the progress to 95
current_task['current_progress'] = 95 # this allows the full resolution image to be sent on the 100 progress event
if data['value'] == 1: # if the value is 1, then send started to api
start_data = {
"ws_id": current_task['ws_id'],
"status": "started",
"info": None,
}
await self.send_to_api(start_data)
elif event == "status":
self.logger.info(f"[memedeck]: sending status event: {data}")
self.active_tasks_map[sid] = current_task
async def send_preview(self, image_data, sid=None, progress=None, context=None):
# if self.current_progress is odd, then don't send the preview
if progress % 2 == 1:
return
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
if max_size is not None:
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
image = ImageOps.contain(image, (max_size, max_size), resampling)
bytesIO = BytesIO()
image.save(bytesIO, format=image_type, quality=100 if progress == 96 else 75, compress_level=1)
preview_bytes = bytesIO.getvalue()
ai_queue_progress = {
"ws_id": sid,
"kind": "image_generating" if progress < 100 else "image_generated",
"data": list(preview_bytes),
"progress": int(progress),
"context": context
}
await self.send_to_api(ai_queue_progress)
async def send_to_api(self, data):
if self.websocket_node_id is None: # check if the node is still running
logging.error(f"[memedeck]: websocket_node_id is None for {data['ws_id']}")
return
try:
post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data)
await self.loop.run_in_executor(None, post_func)
except Exception as e:
self.logger.error(f"[memedeck]: error sending to api: {e}")

View File

@ -1,12 +1,7 @@
import asyncio
import base64
from io import BytesIO
import os
import logging
import signal
import struct
import time
from typing import Optional
import uuid
from PIL import Image, ImageOps
from functools import partial
@ -15,7 +10,6 @@ import pika
import json
import requests
import aiohttp
from dotenv import load_dotenv
load_dotenv()
@ -77,21 +71,6 @@ class MemedeckWorker:
self.logger = logging.getLogger(__name__)
self.logger.info(f"\n[memedeck]: initialized with API URL: {self.api_url} and API Key: {self.api_key}\n")
def on_connection_open(self, connection):
self.connection = connection
self.connection.channel(on_open_callback=self.on_channel_open)
def on_channel_open(self, channel):
self.channel = channel
# only consume one message at a time
self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
self.channel.queue_declare(queue=self.queue_name, durable=True)
self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received, auto_ack=False)
# declare another queue
self.channel.queue_declare(queue='faceswap-queue', durable=True)
self.channel.basic_consume(queue='faceswap-queue', on_message_callback=self.on_message_received, auto_ack=False)
def start(self, prompt_queue, validate_prompt):
self.prompt_queue = prompt_queue
self.validate_prompt = validate_prompt
@ -106,24 +85,56 @@ class MemedeckWorker:
try:
self.connection.ioloop.start()
except KeyboardInterrupt:
self.connection.close()
self.connection.ioloop.start()
self.stop()
def stop(self):
self.connection.close()
self.connection.ioloop.stop()
# --------------------------------------------------
# AMQP
# --------------------------------------------------
def on_connection_open(self, connection):
self.connection = connection
# Open the first channel
self.connection.channel(on_open_callback=self.on_channel_open)
# Open the second channel
self.connection.channel(on_open_callback=self.on_faceswap_channel_open)
def on_channel_open(self, channel):
self.channel = channel
# Only consume one message at a time
self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
# Declare the queue and set the callback
self.channel.queue_declare(queue=self.queue_name, durable=True, callback=self.on_queue_declared)
def on_queue_declared(self, frame):
self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received)
def on_faceswap_channel_open(self, channel):
self.faceswap_channel = channel
self.faceswap_channel.basic_qos(prefetch_size=0, prefetch_count=1)
# Declare the faceswap queue and set the callback
self.faceswap_channel.queue_declare(queue='faceswap-queue', durable=True, callback=self.on_faceswap_queue_declared)
def on_faceswap_queue_declared(self, frame):
self.faceswap_channel.basic_consume(queue='faceswap-queue', on_message_callback=self.on_faceswap_message_received)
def on_faceswap_message_received(self, channel, method, properties, body):
self.on_message_received(channel, method, properties, body)
def on_message_received(self, channel, method, properties, body):
def on_message_received(self, channel, method, properties, body):
decoded_string = body.decode('utf-8')
json_object = json.loads(decoded_string)
payload = json_object[1]
# execute the task
payload = json.loads(decoded_string)
# Execute the task
prompt = payload["nodes"]
valid = self.validate_prompt(prompt)
valid = self.validate_prompt(prompt)
# Prepare task_info
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
# get the routing key from the method
routing_key = method.routing_key
self.logger.info(f"[memedeck]: routing_key: {routing_key}")
task_info = {
"workflow": 'faceswap' if routing_key == 'faceswap-queue' else 'generation',
"prompt_id": prompt_id,
@ -149,10 +160,66 @@ class MemedeckWorker:
if valid[0]:
# Enqueue the task into the internal job queue
self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info))
# self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}")
else:
channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message
channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # Unack the message
# def on_channel_open(self, channel):
# self.channel = channel
# # only consume one message at a time
# self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
# self.channel.queue_declare(queue=self.queue_name, durable=True)
# self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received)
# # declare another queue
# self.channel.queue_declare(queue='faceswap-queue', durable=True)
# self.channel.basic_consume(queue='faceswap-queue', on_message_callback=self.on_message_received)
# def on_message_received(self, channel, method, properties, body):
# decoded_string = body.decode('utf-8')
# json_object = json.loads(decoded_string)
# payload = json_object[1]
# # execute the task
# prompt = payload["nodes"]
# valid = self.validate_prompt(prompt)
# # Prepare task_info
# prompt_id = str(uuid.uuid4())
# outputs_to_execute = valid[2]
# # get the routing key from the method
# routing_key = method.routing_key
# task_info = {
# "workflow": 'faceswap' if routing_key == 'faceswap-queue' else 'generation',
# "prompt_id": prompt_id,
# "prompt": prompt,
# "outputs_to_execute": outputs_to_execute,
# "client_id": "memedeck-1",
# "is_memedeck": True,
# "websocket_node_id": None,
# "ws_id": payload["source_ws_id"],
# "context": payload["req_ctx"],
# "current_node": None,
# "current_progress": 0,
# "delivery_tag": method.delivery_tag,
# "task_status": "waiting",
# }
# # Find the websocket_node_id
# for node in prompt:
# if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
# task_info['websocket_node_id'] = node
# break
# if valid[0]:
# # Enqueue the task into the internal job queue
# self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info))
# # self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}")
# else:
# channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message
# --------------------------------------------------
# Internal job queue
# --------------------------------------------------
async def process_job_queue(self):
while True:
prompt_id, prompt, task_info = await self.internal_job_queue.get()
@ -172,13 +239,9 @@ class MemedeckWorker:
'context': task_info['context']
}, task_info['outputs_to_execute']))
# Acknowledge the message
self.channel.basic_ack(delivery_tag=task_info["delivery_tag"]) # ack the task
if 'faceswap_strength' not in task_info['context']['prompt_config']:
# pretty print the prompt config
self.logger.info(f"[memedeck]: prompt: {task_info['context']['prompt_config']['character']['id']} {task_info['context']['prompt_config']['positive_prompt']}")
# Wait until the current task is completed
await self.wait_for_task_completion(ws_id)
# Task is done
@ -188,11 +251,19 @@ class MemedeckWorker:
"""
Wait until the task with the given ws_id is completed.
"""
task = self.tasks_by_ws_id[ws_id]
delivery_tag = task['delivery_tag']
while ws_id in self.tasks_by_ws_id:
await asyncio.sleep(0.5)
await asyncio.sleep(0.25)
# Acknowledge the message when the task is completed
if task['workflow'] == 'faceswap':
self.faceswap_channel.basic_ack(delivery_tag=delivery_tag)
else:
self.channel.basic_ack(delivery_tag=delivery_tag)
# --------------------------------------------------
# callbacks for the prompt queue
# allbacks for the prompt queue
# --------------------------------------------------
def queue_updated(self):
info = self.get_queue_info()
@ -320,8 +391,6 @@ class MemedeckWorker:
if workflow == 'faceswap':
ai_queue_progress['kind'] = "faceswap_generated"
del ai_queue_progress['progress']
# dont print the data field without deleting it
self.logger.info(f"[memedeck]: sending faceswap result")
await self.send_to_api(ai_queue_progress)
@ -597,3 +666,4 @@ class MemedeckWorker:
# average_brightness = total_brightness // pixel_count
# return average_brightness

View File

@ -29,3 +29,4 @@ python-dotenv
pillow
azure-storage-blob
cairosvg
aio_pika