"""Relay between a comfy prompt queue and the memedeck backend API.

AMQP deliveries are turned into prompt-queue jobs, and execution events
(progress, previews) are posted back to the memedeck API over HTTP.
"""
import asyncio
import json
import logging
import os
import uuid
from enum import Enum
from functools import partial
from io import BytesIO

import pika
import requests
from dotenv import load_dotenv
from PIL import Image, ImageOps

load_dotenv()

# SECURITY(review): the fallback below embeds broker credentials in source.
# It is kept only so existing deployments without AMQP_ADDR keep working;
# rotate the credential and supply AMQP_ADDR through the environment.
amqp_addr = os.getenv('AMQP_ADDR') or 'amqp://api:gacdownatravKekmy9@51.8.120.154:5672/dev'


class QueueProgressKind(Enum):
    """Progress kinds; the values are the strings used in the ``kind`` field."""

    ImageGenerated = "image_generated"
    ImageGenerating = "image_generating"
    SamePrompt = "same_prompt"
    FaceswapGenerated = "faceswap_generated"
    FaceswapGenerating = "faceswap_generating"
    Failed = "failed"


class MemedeckWorker:
    """
    MemedeckWorker is a class that is responsible for relaying messages between comfy and the memedeck backend api
    it is used to send images to the memedeck backend api and to receive prompts from the memedeck backend api

    Two threads are involved: ``start()`` blocks on the pika ioloop, while the
    coroutines run on ``self.loop``. Work crosses between them only through
    ``loop.call_soon_threadsafe`` and ``ioloop.add_callback_threadsafe``.
    """

    # Timeout (seconds) handed to ``requests.post`` so a stalled API call
    # cannot pin an executor thread forever.
    API_TIMEOUT_SECONDS = 30

    class BinaryEventTypes:
        PREVIEW_IMAGE = 1
        UNENCODED_PREVIEW_IMAGE = 2

    class JsonEventTypes(Enum):
        PROGRESS = "progress"
        EXECUTING = "executing"
        EXECUTED = "executed"
        ERROR = "error"
        STATUS = "status"

    def __init__(self, loop):
        """Create the worker and read its configuration from the environment.

        Args:
            loop: asyncio event loop on which the worker's coroutines run.
        """
        MemedeckWorker.instance = self

        logging.getLogger().setLevel(logging.INFO)

        self.loop = loop
        self.messages = asyncio.Queue()

        self.ws_id = None
        self.http_client = None
        self.prompt_queue = None
        self.validate_prompt = None
        self.last_prompt_id = None

        # Populated by start() and the pika open callbacks.
        self.connection = None
        self.channel = None
        self.faceswap_channel = None

        self.amqp_url = amqp_addr
        self.queue_name = os.getenv('QUEUE_NAME') or 'generic-queue'
        self.api_url = os.getenv('API_ADDRESS') or 'http://0.0.0.0:8079/v2'
        # SECURITY(review): hard-coded fallback key, kept for backward
        # compatibility; prefer supplying API_KEY through the environment.
        self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383'

        # Internal job queue
        self.internal_job_queue = asyncio.Queue()

        # Dictionary to keep track of tasks by ws_id
        self.tasks_by_ws_id = {}

        # Configure logging
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)

        # The API key is deliberately not logged: secrets do not belong in logs.
        self.logger.info(f"\n[memedeck]: initialized with API URL: {self.api_url}\n")

    def start(self, prompt_queue, validate_prompt):
        """Connect to AMQP and block on the pika ioloop until interrupted.

        Args:
            prompt_queue: queue object; ``put`` and ``get_tasks_remaining``
                are called on it.
            validate_prompt: callable taking the prompt and returning a
                sequence whose item 0 is the validity flag and item 2 the
                outputs to execute.
        """
        self.prompt_queue = prompt_queue
        self.validate_prompt = validate_prompt

        # Start the process_job_queue task **after** prompt_queue is set.
        # Scheduled thread-safely because this method blocks on the pika
        # ioloop below and therefore cannot be running on self.loop's thread.
        self.loop.call_soon_threadsafe(self.loop.create_task, self.process_job_queue())

        parameters = pika.URLParameters(self.amqp_url)
        logging.getLogger('pika').setLevel(logging.WARNING)  # suppress all logs from pika
        self.connection = pika.SelectConnection(parameters, on_open_callback=self.on_connection_open)

        try:
            self.connection.ioloop.start()
        except KeyboardInterrupt:
            self.stop()

    def stop(self):
        """Close the AMQP connection and stop its ioloop (no-op before start)."""
        if self.connection is None:
            return
        self.connection.close()
        self.connection.ioloop.stop()

    # --------------------------------------------------
    # AMQP
    # --------------------------------------------------
    def on_connection_open(self, connection):
        """Open one channel per consumed queue once the connection is up."""
        self.connection = connection
        # Open the first channel
        self.connection.channel(on_open_callback=self.on_channel_open)
        # Open the second channel
        self.connection.channel(on_open_callback=self.on_faceswap_channel_open)

    def on_channel_open(self, channel):
        """Configure the generation channel and declare its queue."""
        self.channel = channel
        # Only consume one message at a time
        self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
        # Declare the queue and set the callback
        self.channel.queue_declare(queue=self.queue_name, durable=True, callback=self.on_queue_declared)

    def on_queue_declared(self, frame):
        """Start consuming the generation queue once it is declared."""
        self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received)

    def on_faceswap_channel_open(self, channel):
        """Configure the faceswap channel and declare its queue."""
        self.faceswap_channel = channel
        self.faceswap_channel.basic_qos(prefetch_size=0, prefetch_count=1)
        # Declare the faceswap queue and set the callback
        self.faceswap_channel.queue_declare(queue='faceswap-queue', durable=True, callback=self.on_faceswap_queue_declared)

    def on_faceswap_queue_declared(self, frame):
        """Start consuming the faceswap queue once it is declared."""
        self.faceswap_channel.basic_consume(queue='faceswap-queue', on_message_callback=self.on_faceswap_message_received)

    def on_faceswap_message_received(self, channel, method, properties, body):
        """Handle a faceswap delivery; the shared handler tells the two
        workflows apart by routing key."""
        self.on_message_received(channel, method, properties, body)

    @staticmethod
    def _find_websocket_node_id(prompt):
        """Return the id of the first ``SaveImageWebsocket`` node, or None."""
        return next(
            (
                node_id
                for node_id, node in prompt.items()
                if isinstance(node, dict) and node.get("class_type") == "SaveImageWebsocket"
            ),
            None,
        )

    def on_message_received(self, channel, method, properties, body):
        """Turn an AMQP delivery into a job on the internal queue.

        Runs as a pika consumer callback. Malformed or invalid messages are
        rejected without requeue, so one bad delivery neither blocks the
        queue (prefetch is 1) nor lets an exception escape into the ioloop.
        """
        try:
            payload = json.loads(body.decode('utf-8'))
            prompt = payload["nodes"]
            ws_id = payload["source_ws_id"]
            context = payload["req_ctx"]
            websocket_node_id = self._find_websocket_node_id(prompt)
        except (ValueError, KeyError, TypeError, AttributeError) as exc:
            # ValueError covers both UnicodeDecodeError and JSONDecodeError.
            self.logger.error(f"[memedeck]: rejecting malformed message: {exc!r}")
            channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
            return

        # Execute the task
        valid = self.validate_prompt(prompt)
        if not valid[0]:
            # Reject the message without requeueing it
            channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
            return

        # Prepare task_info
        prompt_id = str(uuid.uuid4())
        task_info = {
            "workflow": 'faceswap' if method.routing_key == 'faceswap-queue' else 'generation',
            "prompt_id": prompt_id,
            "prompt": prompt,
            "outputs_to_execute": valid[2],
            "client_id": "memedeck-1",
            "is_memedeck": True,
            "websocket_node_id": websocket_node_id,
            "ws_id": ws_id,
            "context": context,
            "current_node": None,
            "current_progress": 0,
            "delivery_tag": method.delivery_tag,
            "task_status": "waiting",
        }

        # Enqueue the task into the internal job queue (owned by self.loop)
        self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info))

    # --------------------------------------------------
    # Internal job queue
    # --------------------------------------------------
    async def process_job_queue(self):
        """Pull jobs off the internal queue forever, one coroutine per job."""
        while True:
            prompt_id, prompt, task_info = await self.internal_job_queue.get()
            # Start a new coroutine for each task
            self.loop.create_task(self.process_task(prompt_id, prompt, task_info))

    async def process_task(self, prompt_id, prompt, task_info):
        """Submit one job to the prompt queue and wait until it is finished."""
        ws_id = task_info['ws_id']
        try:
            # Add the task to tasks_by_ws_id
            self.tasks_by_ws_id[ws_id] = task_info
            # Put the prompt into the prompt_queue
            self.prompt_queue.put((0, prompt_id, prompt, {
                "client_id": task_info["client_id"],
                'is_memedeck': task_info['is_memedeck'],
                'websocket_node_id': task_info['websocket_node_id'],
                'ws_id': task_info['ws_id'],
                'context': task_info['context']
            }, task_info['outputs_to_execute']))

            # The log line is informational only; read the config defensively
            # so a missing key cannot abort a job that is already queued.
            context = task_info['context']
            prompt_config = context.get('prompt_config') if isinstance(context, dict) else None
            if not isinstance(prompt_config, dict):
                prompt_config = {}
            if 'faceswap_strength' not in prompt_config:
                character = prompt_config.get('character')
                character_id = character.get('id') if isinstance(character, dict) else None
                self.logger.info(f"[memedeck]: prompt: {character_id} {prompt_config.get('positive_prompt')}")

            # Wait until the current task is completed
            await self.wait_for_task_completion(ws_id)
        finally:
            # Task is done (exactly one task_done() per get(), even on error)
            self.internal_job_queue.task_done()

    def _ack_message(self, channel, delivery_tag):
        """Ack a delivery; must only run on the pika ioloop thread."""
        if channel is not None and channel.is_open:
            channel.basic_ack(delivery_tag=delivery_tag)
        else:
            self.logger.warning(f"[memedeck]: channel closed, could not ack delivery {delivery_tag}")

    async def wait_for_task_completion(self, ws_id):
        """
        Wait until the task with the given ws_id is completed.

        Completion is signalled by the entry being removed from
        ``tasks_by_ws_id`` (``send_preview`` does that); the AMQP message is
        acknowledged afterwards.

        NOTE(review): nothing visible in this file removes the entry when a
        job fails, so a failed job appears to wait here forever — confirm.
        """
        task = self.tasks_by_ws_id[ws_id]
        delivery_tag = task['delivery_tag']

        while ws_id in self.tasks_by_ws_id:
            await asyncio.sleep(0.25)

        # Acknowledge the message when the task is completed. pika channels
        # are not thread-safe, so the ack is marshalled onto the ioloop thread.
        channel = self.faceswap_channel if task['workflow'] == 'faceswap' else self.channel
        self.connection.ioloop.add_callback_threadsafe(
            partial(self._ack_message, channel, delivery_tag)
        )

    # --------------------------------------------------
    # Callbacks for the prompt queue
    # --------------------------------------------------
    def queue_updated(self):
        """Hook called when the prompt queue changes; currently sends nothing."""
        self.get_queue_info()
        # self.send_sync("status", { "status": self.get_queue_info() })

    def get_queue_info(self):
        """Return ``{'exec_info': {'queue_remaining': <tasks remaining>}}``."""
        return {
            'exec_info': {
                'queue_remaining': self.prompt_queue.get_tasks_remaining(),
            },
        }

    def send_sync(self, event, data, sid=None):
        """Thread-safe entry point: queue an event for ``publish_loop``."""
        self.loop.call_soon_threadsafe(
            self.messages.put_nowait, (event, data, sid))

    async def publish_loop(self):
        """Drain queued events forever, handing them to ``send`` one at a time."""
        while True:
            msg = await self.messages.get()
            try:
                await self.send(*msg)
            except Exception:
                # Top-level boundary: one bad event must not stop the relay.
                self.logger.exception("[memedeck]: failed to process event")

    async def send(self, event, data, sid=None):
        """Update task state for one event and forward it to the API if needed.

        Args:
            event: ``BinaryEventTypes.UNENCODED_PREVIEW_IMAGE`` for previews,
                otherwise an event name such as "executing", "progress" or
                "status".
            data: event payload; for previews the tuple ``send_preview`` takes.
            sid: ws_id of the task the event belongs to.
        """
        if sid is None:
            self.logger.warning("Received event without sid")
            return

        # Retrieve the task based on sid
        task = self.tasks_by_ws_id.get(sid)
        if not task:
            self.logger.warning(f"Received event {event} for unknown sid: {sid}")
            return

        if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
            await self.send_preview(
                data,
                sid=sid,
                progress=task['current_progress'],
                context=task['context'],
                workflow=task['workflow']
            )
            return

        # Send JSON data / text data
        if event == "executing":
            task['current_node'] = data['node']
            if task['workflow'] == 'faceswap' and task["task_status"] == "waiting":
                start_data = {
                    "ws_id": task['ws_id'],
                    "status": "started",
                    "info": None,
                }
                await self.send_to_api(start_data)
                task["task_status"] = "executing"
        elif event == "progress":
            if task['current_node'] == task['websocket_node_id']:
                # If the node is the websocket node, then set the progress to 100
                task['current_progress'] = 100
            else:
                # Otherwise derive progress from the node's own progress;
                # a zero 'max' is treated as no progress instead of raising.
                max_value = data['max']
                task['current_progress'] = (data['value'] / max_value) * 100 if max_value else 0
                if task['current_progress'] == 100:
                    # 100 is reserved for the websocket node so the full
                    # resolution image is sent on the 100 progress event.
                    task['current_progress'] = 95

            if data['value'] == 1:
                # If the value is 1, send started to API
                start_data = {
                    "ws_id": task['ws_id'],
                    "status": "started",
                    "info": None,
                }
                task["task_status"] = "executing"
                await self.send_to_api(start_data)
        elif event == "status":
            self.logger.info(f"[memedeck]: sending status event: {data}")

        # Update the task in tasks_by_ws_id
        self.tasks_by_ws_id[sid] = task

    async def send_preview(self, image_data, sid=None, progress=None, context=None, workflow=None):
        """Encode a preview image and post it to the API.

        Args:
            image_data: ``(image_type, image, max_size)``; ``image`` is saved
                with ``image.save`` (presumably a PIL image — confirm) and,
                when ``max_size`` is not None, first shrunk to fit
                ``max_size`` x ``max_size``.
            sid: ws_id of the owning task.
            progress: progress value; defaults to the task's current progress.
            context: request context echoed back to the API.
            workflow: 'faceswap' or 'generation'.

        The task is removed from ``tasks_by_ws_id`` once progress is 100 or
        the workflow is faceswap; that removal is what
        ``wait_for_task_completion`` waits for.
        """
        if sid is None:
            self.logger.warning("Received preview without sid")
            return

        task = self.tasks_by_ws_id.get(sid)
        if not task:
            self.logger.warning(f"Received preview for unknown sid: {sid}")
            return

        if progress is None:
            progress = task['current_progress']

        # if progress is odd, then don't send the preview
        if int(progress) % 2 == 1:
            return

        image_type, image, max_size = image_data[0], image_data[1], image_data[2]
        if max_size is not None:
            if hasattr(Image, 'Resampling'):
                resampling = Image.Resampling.BILINEAR
            else:
                resampling = Image.ANTIALIAS

            image = ImageOps.contain(image, (max_size, max_size), resampling)

        buffer = BytesIO()
        # Progress 95 is the value assigned when a non-websocket node reports
        # 100%; that frame is saved at full quality.
        image.save(buffer, format=image_type, quality=100 if progress == 95 else 75, compress_level=1)
        preview_bytes = buffer.getvalue()

        kind = QueueProgressKind.ImageGenerating if progress < 100 else QueueProgressKind.ImageGenerated

        ai_queue_progress = {
            "ws_id": sid,
            "kind": kind.value,
            "data": list(preview_bytes),
            "progress": int(progress),
            "context": context
        }

        # set the kind to faceswap_generated if workflow is faceswap
        if workflow == 'faceswap':
            ai_queue_progress['kind'] = QueueProgressKind.FaceswapGenerated.value
            del ai_queue_progress['progress']

        await self.send_to_api(ai_queue_progress)

        if progress == 100 or workflow == 'faceswap':
            # Remove the task from tasks_by_ws_id; pop() tolerates the entry
            # having disappeared while the request above was in flight.
            self.tasks_by_ws_id.pop(sid, None)

    async def send_to_api(self, data):
        """POST ``data`` as JSON to ``{api_url}/generation/update``.

        Best effort: failures are logged, never raised. The blocking
        ``requests`` call runs in the loop's default executor.
        """
        ws_id = data.get('ws_id')
        if not ws_id:
            self.logger.error("[memedeck]: Missing ws_id in data")
            return
        task = self.tasks_by_ws_id.get(ws_id)
        if not task:
            self.logger.error(f"[memedeck]: No task found for ws_id {ws_id}")
            return
        # NOTE(review): prompts without a SaveImageWebsocket node never reach
        # the API because of this guard; the original comment below says the
        # request is "not sending properly for faceswap" — possibly related,
        # confirm.
        if task['websocket_node_id'] is None:
            self.logger.error(f"[memedeck]: websocket_node_id is None for {ws_id}")
            return
        try:
            # this request is not sending properly for faceswap
            post_func = partial(
                requests.post,
                f"{self.api_url}/generation/update",
                json=data,
                timeout=self.API_TIMEOUT_SECONDS,
            )
            response = await self.loop.run_in_executor(None, post_func)
        except Exception as e:
            self.logger.error(f"[memedeck]: error sending to api: {e}")
            return
        if not response.ok:
            self.logger.error(f"[memedeck]: api responded with status {response.status_code}")


# --------------------------------------------------------------------------
# MemedeckAzureStorage
# (disabled: everything below is commented out and kept for reference)
# --------------------------------------------------------------------------
# from azure.storage.blob.aio import BlobClient, BlobServiceClient
# from azure.storage.blob import ContentSettings
# from typing import Optional, Tuple
# import cairosvg

# WATERMARK = ''
# WATERMARK_SIZE = 40

# class MemedeckAzureStorage:
#     def __init__(self, connection_string):
#         # get environment variables
#         self.storage_account = os.getenv('STORAGE_ACCOUNT')
#         self.storage_access_key = os.getenv('STORAGE_ACCESS_KEY')
#         self.storage_container = os.getenv('STORAGE_CONTAINER')
#         self.logger = logging.getLogger(__name__)
#
#         self.blob_service_client = BlobServiceClient.from_connection_string(conn_str=connection_string)
#
#     async def upload_image(
#         self,
#         by: str,
#         image_id: str,
#         source_url: Optional[str],
#         bytes_data: Optional[bytes],
#         filetype: Optional[str],
#     ) -> Tuple[str, Tuple[int, int]]:
#         """
#         Uploads an image to Azure Blob Storage.
#
#         Args:
#             by (str): Identifier for the uploader.
#             image_id (str): Unique identifier for the image.
#             source_url (Optional[str]): URL to fetch the image from.
#             bytes_data (Optional[bytes]): Image data in bytes.
#             filetype (Optional[str]): Desired file type (e.g., 'jpeg', 'png').
#
#         Returns:
#             Tuple[str, Tuple[int, int]]: URL of the uploaded image and its dimensions.
#         """
#         # Retrieve image bytes either from the provided bytes_data or by fetching from source_url
#         if source_url is None:
#             if bytes_data is None:
#                 raise ValueError("Could not get image bytes")
#             image_bytes = bytes_data
#         else:
#             self.logger.info(f"Requesting image from URL: {source_url}")
#             async with aiohttp.ClientSession() as session:
#                 try:
#                     async with session.get(source_url) as response:
#                         if response.status != 200:
#                             raise Exception(f"Failed to fetch image, status code {response.status}")
#                         image_bytes = await response.read()
#                 except Exception as e:
#                     raise Exception(f"Error fetching image from URL: {e}")
#
#         # Open image using Pillow to get dimensions and format
#         try:
#             img = Image.open(BytesIO(image_bytes))
#             width, height = img.size
#             inferred_filetype = img.format.lower()
#         except Exception as e:
#             raise Exception(f"Failed to decode image: {e}")
#
#         # Determine the final file type
#         final_filetype = filetype.lower() if filetype else inferred_filetype
#
#         # Construct the blob name
#         blob_name = f"{by}/{image_id.replace('image:', '')}.{final_filetype}"
#
#         # Upload the image to Azure Blob Storage
#         try:
#             image_url = await self.save_image(blob_name, img.format, image_bytes)
#             return image_url, (width, height)
#         except Exception as e:
#             self.logger.error(f"Trouble saving image: {e}")
#             raise Exception(f"Trouble saving image: {e}")
#
#     async def save_image(
#         self,
#         blob_name: str,
#         content_type: str,
#         bytes_data: bytes
#     ) -> str:
#         """
#         Saves image bytes to Azure Blob Storage.
#
#         Args:
#             blob_name (str): Name of the blob in Azure Storage.
#             content_type (str): MIME type of the content.
#             bytes_data (bytes): Image data in bytes.
#
#         Returns:
#             str: URL of the uploaded blob.
#         """
#         # Retrieve environment variables
#         account = os.getenv("STORAGE_ACCOUNT")
#         access_key = os.getenv("STORAGE_ACCESS_KEY")
#         container = os.getenv("STORAGE_CONTAINER")
#
#         if not all([account, access_key, container]):
#             raise EnvironmentError("Missing STORAGE_ACCOUNT, STORAGE_ACCESS_KEY, or STORAGE_CONTAINER environment variables")
#
#         # Initialize BlobServiceClient
#         blob_service_client = BlobServiceClient(
#             account_url=f"https://{account}.blob.core.windows.net",
#             credential=access_key
#         )
#         blob_client = blob_service_client.get_blob_client(container=container, blob=blob_name)
#
#         # Upload the blob
#         try:
#             await blob_client.upload_blob(
#                 bytes_data,
#                 overwrite=True,
#                 content_settings=ContentSettings(content_type=content_type)
#             )
#         except Exception as e:
#             raise Exception(f"Failed to upload blob: {e}")
#
#         self.logger.debug(f"Blob uploaded: name={blob_name}, content_type={content_type}")
#
#         # Construct and return the blob URL
#         blob_url = f"https://media.memedeck.xyz//{container}/{blob_name}"
#         return blob_url
#
#     async def add_watermark(
#         self,
#         base_blob_name: str,
#         base_image: bytes
#     ) -> str:
#         """
#         Adds a watermark to the provided image and uploads the watermarked image.
#
#         Args:
#             base_blob_name (str): Original blob name of the image.
#             base_image (bytes): Image data in bytes.
#
#         Returns:
#             str: URL of the watermarked image.
#         """
#         # Load the input image
#         try:
#             img = Image.open(BytesIO(base_image)).convert("RGBA")
#         except Exception as e:
#             raise Exception(f"Failed to load image: {e}")
#
#         # Calculate position for the watermark (bottom right corner with padding)
#         padding = 12
#         x = img.width - WATERMARK_SIZE - padding
#         y = img.height - WATERMARK_SIZE - padding
#
#         # Analyze background brightness where the watermark will be placed
#         background_brightness = self.analyze_background_brightness(img, x, y, WATERMARK_SIZE)
#         self.logger.info(f"Background brightness: {background_brightness}")
#
#         # Render SVG watermark to PNG bytes using cairosvg
#         try:
#             watermark_png_bytes = cairosvg.svg2png(bytestring=WATERMARK.encode('utf-8'), output_width=WATERMARK_SIZE, output_height=WATERMARK_SIZE)
#             watermark = Image.open(BytesIO(watermark_png_bytes)).convert("RGBA")
#         except Exception as e:
#             raise Exception(f"Failed to render watermark SVG: {e}")
#
#         # Determine watermark color based on background brightness
#         if background_brightness > 128:
#             # Dark watermark for light backgrounds
#             watermark_color = (0, 0, 0, int(255 * 0.65))  # Black with 65% opacity
#         else:
#             # Light watermark for dark backgrounds
#             watermark_color = (255, 255, 255, int(255 * 0.65))  # White with 65% opacity
#
#         # Apply the watermark color by blending
#         solid_color = Image.new("RGBA", watermark.size, watermark_color)
#         watermark = Image.alpha_composite(watermark, solid_color)
#
#         # Overlay the watermark onto the original image
#         img.paste(watermark, (x, y), watermark)
#
#         # Save the watermarked image to bytes
#         buffer = BytesIO()
#         img = img.convert("RGB")  # Convert back to RGB for JPEG format
#         img.save(buffer, format="JPEG")
#         buffer.seek(0)
#         jpeg_bytes = buffer.read()
#
#         # Modify the blob name to include '_watermarked'
#         try:
#             if "memes/" in base_blob_name:
#                 base_blob_name_right = base_blob_name.split("memes/", 1)[1]
#             else:
#                 base_blob_name_right = base_blob_name
#             base_blob_name_split = base_blob_name_right.rsplit(".", 1)
#             base_blob_name_without_extension = base_blob_name_split[0]
#             extension = base_blob_name_split[1]
#         except Exception as e:
#             raise Exception(f"Failed to process blob name: {e}")
#
#         watermarked_blob_name = f"{base_blob_name_without_extension}_watermarked.{extension}"
#
#         # Upload the watermarked image
#         try:
#             watermarked_blob_url = await self.save_image(
#                 watermarked_blob_name,
#                 "image/jpeg",
#                 jpeg_bytes
#             )
#             return watermarked_blob_url
#         except Exception as e:
#             raise Exception(f"Failed to upload watermarked image: {e}")
#
#     def analyze_background_brightness(
#         self,
#         img: Image.Image,
#         x: int,
#         y: int,
#         size: int
#     ) -> int:
#         """
#         Analyzes the brightness of a specific region in the image.
#
#         Args:
#             img (Image.Image): The image to analyze.
#             x (int): X-coordinate of the top-left corner of the region.
#             y (int): Y-coordinate of the top-left corner of the region.
#             size (int): Size of the square region to analyze.
#
#         Returns:
#             int: Average brightness (0-255) of the region.
#         """
#         # Crop the specified region
#         sub_image = img.crop((x, y, x + size, y + size)).convert("RGB")
#
#         # Calculate average brightness using the luminance formula
#         total_brightness = 0
#         pixel_count = 0
#         for pixel in sub_image.getdata():
#             r, g, b = pixel
#             brightness = (r * 299 + g * 587 + b * 114) // 1000
#             total_brightness += brightness
#             pixel_count += 1
#
#         if pixel_count == 0:
#             return 0
#
#         average_brightness = total_brightness // pixel_count
#         return average_brightness