:qadded faceswap support

This commit is contained in:
drunkplato 2024-10-16 18:13:23 +00:00 committed by Ubuntu
parent f93fe6b3fc
commit 439a914562
5 changed files with 48 additions and 17 deletions

3
.gitignore vendored
View File

@ -29,4 +29,5 @@ models
custom_nodes
models/
flux_loras/
custom_nodes/
custom_nodes/
pysssss-workflows/

View File

@ -534,7 +534,7 @@ class PromptExecutor:
execution_list.complete_node_execution()
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id, 'ws_id': self.memedeck_worker.ws_id }, broadcast=False)
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
ui_outputs = {}
meta_outputs = {}

View File

@ -58,8 +58,8 @@ def apply_custom_paths():
# ---------------------------------------------------------------------------------------
from memedeck import MemedeckWorker
import sys
sys.stdout = open(os.devnull, 'w') # disable all print statements
# import sys
# sys.stdout = open(os.devnull, 'w') # disable all print statements
# ---------------------------------------------------------------------------------------
def execute_prestartup_script():

View File

@ -55,6 +55,7 @@ class MemedeckWorker:
self.loop = loop
self.messages = asyncio.Queue()
self.ws_id = None
self.http_client = None
self.prompt_queue = None
self.validate_prompt = None
@ -87,6 +88,9 @@ class MemedeckWorker:
self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
self.channel.queue_declare(queue=self.queue_name, durable=True)
self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received, auto_ack=False)
# declare another queue
self.channel.queue_declare(queue='faceswap-queue', durable=True)
self.channel.basic_consume(queue='faceswap-queue', on_message_callback=self.on_message_received, auto_ack=False)
def start(self, prompt_queue, validate_prompt):
self.prompt_queue = prompt_queue
@ -117,7 +121,11 @@ class MemedeckWorker:
# Prepare task_info
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
# get the routing key from the method
routing_key = method.routing_key
self.logger.info(f"[memedeck]: routing_key: {routing_key}")
task_info = {
"workflow": 'faceswap' if routing_key == 'faceswap-queue' else 'generation',
"prompt_id": prompt_id,
"prompt": prompt,
"outputs_to_execute": outputs_to_execute,
@ -141,7 +149,7 @@ class MemedeckWorker:
if valid[0]:
# Enqueue the task into the internal job queue
self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info))
self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}")
# self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}")
else:
channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message
@ -163,11 +171,14 @@ class MemedeckWorker:
'ws_id': task_info['ws_id'],
'context': task_info['context']
}, task_info['outputs_to_execute']))
# Acknowledge the message
self.channel.basic_ack(delivery_tag=task_info["delivery_tag"]) # ack the task
self.logger.info(f"[memedeck]: Acked task {prompt_id} {ws_id}")
self.channel.basic_ack(delivery_tag=task_info["delivery_tag"]) # ack the task
if 'faceswap_strength' not in task_info['context']['prompt_config']:
# pretty print the prompt config
self.logger.info(f"[memedeck]: prompt: {task_info['context']['prompt_config']['character']['id']} {task_info['context']['prompt_config']['positive_prompt']}")
self.logger.info(f"[memedeck]: Started processing prompt {prompt_id}")
# Wait until the current task is completed
await self.wait_for_task_completion(ws_id)
# Task is done
@ -219,12 +230,22 @@ class MemedeckWorker:
data,
sid=sid,
progress=task['current_progress'],
context=task['context']
context=task['context'],
workflow=task['workflow']
)
else:
# Send JSON data / text data
if event == "executing":
task['current_node'] = data['node']
if task['workflow'] == 'faceswap' and task["task_status"] == "waiting":
start_data = {
"ws_id": task['ws_id'],
"status": "started",
"info": None,
}
await self.send_to_api(start_data)
# self.logger.info(f"[memedeck]: faceswap executing: {data}")
task["task_status"] = "executing"
elif event == "progress":
if task['current_node'] == task['websocket_node_id']:
@ -253,7 +274,7 @@ class MemedeckWorker:
# Update the task in tasks_by_ws_id
self.tasks_by_ws_id[sid] = task
async def send_preview(self, image_data, sid=None, progress=None, context=None):
async def send_preview(self, image_data, sid=None, progress=None, context=None, workflow=None):
if sid is None:
self.logger.warning("Received preview without sid")
return
@ -284,20 +305,29 @@ class MemedeckWorker:
bytesIO = BytesIO()
image.save(bytesIO, format=image_type, quality=100 if progress == 95 else 75, compress_level=1)
preview_bytes = bytesIO.getvalue()
kind = "image_generating" if progress < 100 else "image_generated"
ai_queue_progress = {
"ws_id": sid,
"kind": "image_generating" if progress < 100 else "image_generated",
"kind": kind,
"data": list(preview_bytes),
"progress": int(progress),
"context": context
}
# set the kind to faceswap_generated if workflow is faceswap
if workflow == 'faceswap':
ai_queue_progress['kind'] = "faceswap_generated"
del ai_queue_progress['progress']
# dont print the data field without deleting it
self.logger.info(f"[memedeck]: sending faceswap result")
await self.send_to_api(ai_queue_progress)
if progress == 100:
if progress == 100 or workflow == 'faceswap':
del self.tasks_by_ws_id[sid] # Remove the task from tasks_by_ws_id
self.logger.info(f"[memedeck]: Task {sid} completed")
# self.logger.info(f"[memedeck]: Task {sid} completed")
async def send_to_api(self, data):
ws_id = data.get('ws_id')
@ -312,10 +342,11 @@ class MemedeckWorker:
self.logger.error(f"[memedeck]: websocket_node_id is None for {ws_id}")
return
try:
# this request is not sending properly for faceswap
post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data)
await self.loop.run_in_executor(None, post_func)
except Exception as e:
self.logger.error(f"[memedeck]: error sending to api: {e}")
self.logger.info(f"[memedeck]: error sending to api: {e}")
# --------------------------------------------------------------------------
# MemedeckAzureStorage

File diff suppressed because one or more lines are too long