mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-01 21:07:20 +08:00
:qadded faceswap support
This commit is contained in:
parent
f93fe6b3fc
commit
439a914562
3
.gitignore
vendored
3
.gitignore
vendored
@ -29,4 +29,5 @@ models
|
|||||||
custom_nodes
|
custom_nodes
|
||||||
models/
|
models/
|
||||||
flux_loras/
|
flux_loras/
|
||||||
custom_nodes/
|
custom_nodes/
|
||||||
|
pysssss-workflows/
|
@ -534,7 +534,7 @@ class PromptExecutor:
|
|||||||
execution_list.complete_node_execution()
|
execution_list.complete_node_execution()
|
||||||
else:
|
else:
|
||||||
# Only execute when the while-loop ends without break
|
# Only execute when the while-loop ends without break
|
||||||
self.add_message("execution_success", { "prompt_id": prompt_id, 'ws_id': self.memedeck_worker.ws_id }, broadcast=False)
|
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||||
|
|
||||||
ui_outputs = {}
|
ui_outputs = {}
|
||||||
meta_outputs = {}
|
meta_outputs = {}
|
||||||
|
4
main.py
4
main.py
@ -58,8 +58,8 @@ def apply_custom_paths():
|
|||||||
# ---------------------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------------------
|
||||||
from memedeck import MemedeckWorker
|
from memedeck import MemedeckWorker
|
||||||
|
|
||||||
import sys
|
# import sys
|
||||||
sys.stdout = open(os.devnull, 'w') # disable all print statements
|
# sys.stdout = open(os.devnull, 'w') # disable all print statements
|
||||||
# ---------------------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------------------
|
||||||
|
|
||||||
def execute_prestartup_script():
|
def execute_prestartup_script():
|
||||||
|
55
memedeck.py
55
memedeck.py
@ -55,6 +55,7 @@ class MemedeckWorker:
|
|||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.messages = asyncio.Queue()
|
self.messages = asyncio.Queue()
|
||||||
|
|
||||||
|
self.ws_id = None
|
||||||
self.http_client = None
|
self.http_client = None
|
||||||
self.prompt_queue = None
|
self.prompt_queue = None
|
||||||
self.validate_prompt = None
|
self.validate_prompt = None
|
||||||
@ -87,6 +88,9 @@ class MemedeckWorker:
|
|||||||
self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
|
self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
|
||||||
self.channel.queue_declare(queue=self.queue_name, durable=True)
|
self.channel.queue_declare(queue=self.queue_name, durable=True)
|
||||||
self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received, auto_ack=False)
|
self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received, auto_ack=False)
|
||||||
|
# declare another queue
|
||||||
|
self.channel.queue_declare(queue='faceswap-queue', durable=True)
|
||||||
|
self.channel.basic_consume(queue='faceswap-queue', on_message_callback=self.on_message_received, auto_ack=False)
|
||||||
|
|
||||||
def start(self, prompt_queue, validate_prompt):
|
def start(self, prompt_queue, validate_prompt):
|
||||||
self.prompt_queue = prompt_queue
|
self.prompt_queue = prompt_queue
|
||||||
@ -117,7 +121,11 @@ class MemedeckWorker:
|
|||||||
# Prepare task_info
|
# Prepare task_info
|
||||||
prompt_id = str(uuid.uuid4())
|
prompt_id = str(uuid.uuid4())
|
||||||
outputs_to_execute = valid[2]
|
outputs_to_execute = valid[2]
|
||||||
|
# get the routing key from the method
|
||||||
|
routing_key = method.routing_key
|
||||||
|
self.logger.info(f"[memedeck]: routing_key: {routing_key}")
|
||||||
task_info = {
|
task_info = {
|
||||||
|
"workflow": 'faceswap' if routing_key == 'faceswap-queue' else 'generation',
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"outputs_to_execute": outputs_to_execute,
|
"outputs_to_execute": outputs_to_execute,
|
||||||
@ -141,7 +149,7 @@ class MemedeckWorker:
|
|||||||
if valid[0]:
|
if valid[0]:
|
||||||
# Enqueue the task into the internal job queue
|
# Enqueue the task into the internal job queue
|
||||||
self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info))
|
self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info))
|
||||||
self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}")
|
# self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}")
|
||||||
else:
|
else:
|
||||||
channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message
|
channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message
|
||||||
|
|
||||||
@ -163,11 +171,14 @@ class MemedeckWorker:
|
|||||||
'ws_id': task_info['ws_id'],
|
'ws_id': task_info['ws_id'],
|
||||||
'context': task_info['context']
|
'context': task_info['context']
|
||||||
}, task_info['outputs_to_execute']))
|
}, task_info['outputs_to_execute']))
|
||||||
|
|
||||||
# Acknowledge the message
|
# Acknowledge the message
|
||||||
self.channel.basic_ack(delivery_tag=task_info["delivery_tag"]) # ack the task
|
self.channel.basic_ack(delivery_tag=task_info["delivery_tag"]) # ack the task
|
||||||
self.logger.info(f"[memedeck]: Acked task {prompt_id} {ws_id}")
|
|
||||||
|
if 'faceswap_strength' not in task_info['context']['prompt_config']:
|
||||||
|
# pretty print the prompt config
|
||||||
|
self.logger.info(f"[memedeck]: prompt: {task_info['context']['prompt_config']['character']['id']} {task_info['context']['prompt_config']['positive_prompt']}")
|
||||||
|
|
||||||
self.logger.info(f"[memedeck]: Started processing prompt {prompt_id}")
|
|
||||||
# Wait until the current task is completed
|
# Wait until the current task is completed
|
||||||
await self.wait_for_task_completion(ws_id)
|
await self.wait_for_task_completion(ws_id)
|
||||||
# Task is done
|
# Task is done
|
||||||
@ -219,12 +230,22 @@ class MemedeckWorker:
|
|||||||
data,
|
data,
|
||||||
sid=sid,
|
sid=sid,
|
||||||
progress=task['current_progress'],
|
progress=task['current_progress'],
|
||||||
context=task['context']
|
context=task['context'],
|
||||||
|
workflow=task['workflow']
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Send JSON data / text data
|
# Send JSON data / text data
|
||||||
if event == "executing":
|
if event == "executing":
|
||||||
task['current_node'] = data['node']
|
task['current_node'] = data['node']
|
||||||
|
if task['workflow'] == 'faceswap' and task["task_status"] == "waiting":
|
||||||
|
start_data = {
|
||||||
|
"ws_id": task['ws_id'],
|
||||||
|
"status": "started",
|
||||||
|
"info": None,
|
||||||
|
}
|
||||||
|
await self.send_to_api(start_data)
|
||||||
|
|
||||||
|
# self.logger.info(f"[memedeck]: faceswap executing: {data}")
|
||||||
task["task_status"] = "executing"
|
task["task_status"] = "executing"
|
||||||
elif event == "progress":
|
elif event == "progress":
|
||||||
if task['current_node'] == task['websocket_node_id']:
|
if task['current_node'] == task['websocket_node_id']:
|
||||||
@ -253,7 +274,7 @@ class MemedeckWorker:
|
|||||||
# Update the task in tasks_by_ws_id
|
# Update the task in tasks_by_ws_id
|
||||||
self.tasks_by_ws_id[sid] = task
|
self.tasks_by_ws_id[sid] = task
|
||||||
|
|
||||||
async def send_preview(self, image_data, sid=None, progress=None, context=None):
|
async def send_preview(self, image_data, sid=None, progress=None, context=None, workflow=None):
|
||||||
if sid is None:
|
if sid is None:
|
||||||
self.logger.warning("Received preview without sid")
|
self.logger.warning("Received preview without sid")
|
||||||
return
|
return
|
||||||
@ -284,20 +305,29 @@ class MemedeckWorker:
|
|||||||
bytesIO = BytesIO()
|
bytesIO = BytesIO()
|
||||||
image.save(bytesIO, format=image_type, quality=100 if progress == 95 else 75, compress_level=1)
|
image.save(bytesIO, format=image_type, quality=100 if progress == 95 else 75, compress_level=1)
|
||||||
preview_bytes = bytesIO.getvalue()
|
preview_bytes = bytesIO.getvalue()
|
||||||
|
|
||||||
|
kind = "image_generating" if progress < 100 else "image_generated"
|
||||||
|
|
||||||
ai_queue_progress = {
|
ai_queue_progress = {
|
||||||
"ws_id": sid,
|
"ws_id": sid,
|
||||||
"kind": "image_generating" if progress < 100 else "image_generated",
|
"kind": kind,
|
||||||
"data": list(preview_bytes),
|
"data": list(preview_bytes),
|
||||||
"progress": int(progress),
|
"progress": int(progress),
|
||||||
"context": context
|
"context": context
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# set the kind to faceswap_generated if workflow is faceswap
|
||||||
|
if workflow == 'faceswap':
|
||||||
|
ai_queue_progress['kind'] = "faceswap_generated"
|
||||||
|
del ai_queue_progress['progress']
|
||||||
|
# dont print the data field without deleting it
|
||||||
|
self.logger.info(f"[memedeck]: sending faceswap result")
|
||||||
|
|
||||||
await self.send_to_api(ai_queue_progress)
|
await self.send_to_api(ai_queue_progress)
|
||||||
|
|
||||||
if progress == 100:
|
if progress == 100 or workflow == 'faceswap':
|
||||||
del self.tasks_by_ws_id[sid] # Remove the task from tasks_by_ws_id
|
del self.tasks_by_ws_id[sid] # Remove the task from tasks_by_ws_id
|
||||||
self.logger.info(f"[memedeck]: Task {sid} completed")
|
# self.logger.info(f"[memedeck]: Task {sid} completed")
|
||||||
|
|
||||||
async def send_to_api(self, data):
|
async def send_to_api(self, data):
|
||||||
ws_id = data.get('ws_id')
|
ws_id = data.get('ws_id')
|
||||||
@ -312,10 +342,11 @@ class MemedeckWorker:
|
|||||||
self.logger.error(f"[memedeck]: websocket_node_id is None for {ws_id}")
|
self.logger.error(f"[memedeck]: websocket_node_id is None for {ws_id}")
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
|
# this request is not sending properly for faceswap
|
||||||
post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data)
|
post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data)
|
||||||
await self.loop.run_in_executor(None, post_func)
|
await self.loop.run_in_executor(None, post_func)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"[memedeck]: error sending to api: {e}")
|
self.logger.info(f"[memedeck]: error sending to api: {e}")
|
||||||
|
|
||||||
# --------------------------------------------------------------------------
|
# --------------------------------------------------------------------------
|
||||||
# MemedeckAzureStorage
|
# MemedeckAzureStorage
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user