changes for training workflow

This commit is contained in:
drunkplato 2024-11-25 14:53:53 +00:00 committed by Ubuntu
parent c1922cb347
commit a22b1a7895
4 changed files with 63 additions and 75 deletions

5
.gitignore vendored
View File

@ -30,4 +30,7 @@ custom_nodes
models/
flux_loras/
custom_nodes/
pysssss-workflows/
pysssss-workflows/
comfy-venv
comfy_venv_3.11

View File

@ -486,7 +486,7 @@ class PromptExecutor:
self.server.client_id = extra_data["client_id"]
if "ws_id" in extra_data:
self.memedeck_worker.ws_id = extra_data["ws_id"]
self.memedeck_worker.websocket_node_id = extra_data["websocket_node_id"]
self.memedeck_worker.end_node_id = extra_data["end_node_id"]
else:
self.server.client_id = None
self.memedeck_worker.ws_id = None

View File

@ -59,6 +59,12 @@ class MemedeckWorker:
self.queue_name = os.getenv('QUEUE_NAME') or 'generic-queue'
self.api_url = os.getenv('API_ADDRESS') or 'http://0.0.0.0:8079/v2'
self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383'
self.training_only = os.getenv('TRAINING_ONLY') or False
if self.training_only:
self.queue_name = 'training-queue'
print(f"[memedeck]: training only mode enabled")
# Internal job queue
self.internal_job_queue = asyncio.Queue()
@ -98,8 +104,9 @@ class MemedeckWorker:
self.connection = connection
# Open the first channel
self.connection.channel(on_open_callback=self.on_channel_open)
# Open the second channel
self.connection.channel(on_open_callback=self.on_faceswap_channel_open)
if not self.training_only:
# Open the second channel
self.connection.channel(on_open_callback=self.on_faceswap_channel_open)
def on_channel_open(self, channel):
self.channel = channel
@ -130,31 +137,36 @@ class MemedeckWorker:
# Execute the task
prompt = payload["nodes"]
valid = self.validate_prompt(prompt)
workflow = 'training' if self.training_only else 'faceswap' if routing_key == 'faceswap-queue' else 'generation'
# Prepare task_info
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
routing_key = method.routing_key
task_info = {
"workflow": 'faceswap' if routing_key == 'faceswap-queue' else 'generation',
"workflow": workflow,
"prompt_id": prompt_id,
"prompt": prompt,
"outputs_to_execute": outputs_to_execute,
"client_id": "memedeck-1",
"is_memedeck": True,
"websocket_node_id": None,
"end_node_id": None,
"ws_id": payload["source_ws_id"],
"context": payload["req_ctx"],
"context": payload["req_ctx"] if "req_ctx" in payload else {},
"current_node": None,
"current_progress": 0,
"delivery_tag": method.delivery_tag,
"task_status": "waiting",
}
# Find the websocket_node_id
# Find the end_node_id
for node in prompt:
if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
task_info['websocket_node_id'] = node
task_info['end_node_id'] = node
break
if self.training_only and isinstance(prompt[node], dict) and prompt[node].get("class_type") == "FluxTrainEnd":
task_info['end_node_id'] = node
break
if valid[0]:
@ -163,60 +175,6 @@ class MemedeckWorker:
else:
channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # Unack the message
# def on_channel_open(self, channel):
# self.channel = channel
# # only consume one message at a time
# self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
# self.channel.queue_declare(queue=self.queue_name, durable=True)
# self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received)
# # declare another queue
# self.channel.queue_declare(queue='faceswap-queue', durable=True)
# self.channel.basic_consume(queue='faceswap-queue', on_message_callback=self.on_message_received)
# def on_message_received(self, channel, method, properties, body):
# decoded_string = body.decode('utf-8')
# json_object = json.loads(decoded_string)
# payload = json_object[1]
# # execute the task
# prompt = payload["nodes"]
# valid = self.validate_prompt(prompt)
# # Prepare task_info
# prompt_id = str(uuid.uuid4())
# outputs_to_execute = valid[2]
# # get the routing key from the method
# routing_key = method.routing_key
# task_info = {
# "workflow": 'faceswap' if routing_key == 'faceswap-queue' else 'generation',
# "prompt_id": prompt_id,
# "prompt": prompt,
# "outputs_to_execute": outputs_to_execute,
# "client_id": "memedeck-1",
# "is_memedeck": True,
# "websocket_node_id": None,
# "ws_id": payload["source_ws_id"],
# "context": payload["req_ctx"],
# "current_node": None,
# "current_progress": 0,
# "delivery_tag": method.delivery_tag,
# "task_status": "waiting",
# }
# # Find the websocket_node_id
# for node in prompt:
# if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
# task_info['websocket_node_id'] = node
# break
# if valid[0]:
# # Enqueue the task into the internal job queue
# self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info))
# # self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}")
# else:
# channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message
# --------------------------------------------------
# Internal job queue
# --------------------------------------------------
@ -234,14 +192,15 @@ class MemedeckWorker:
self.prompt_queue.put((0, prompt_id, prompt, {
"client_id": task_info["client_id"],
'is_memedeck': task_info['is_memedeck'],
'websocket_node_id': task_info['websocket_node_id'],
'end_node_id': task_info['end_node_id'],
'ws_id': task_info['ws_id'],
'context': task_info['context']
}, task_info['outputs_to_execute']))
if 'faceswap_strength' not in task_info['context']['prompt_config']:
# pretty print the prompt config
self.logger.info(f"[memedeck]: prompt: {task_info['context']['prompt_config']['character']['id']} {task_info['context']['prompt_config']['positive_prompt']}")
if task_info['workflow'] != 'training':
if 'faceswap_strength' not in task_info['context']['prompt_config']:
# pretty print the prompt config
self.logger.info(f"[memedeck]: prompt: {task_info['context']['prompt_config']['character']['id']} {task_info['context']['prompt_config']['positive_prompt']}")
# Wait until the current task is completed
await self.wait_for_task_completion(ws_id)
# Task is done
@ -295,7 +254,12 @@ class MemedeckWorker:
if not task:
self.logger.warning(f"Received event {event} for unknown sid: {sid}")
return
if task['workflow'] == 'training':
return self.handle_training_send(event, task, sid)
# this logic is for generation and faceswap (todo move to separate function)
if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
await self.send_preview(
data,
@ -305,7 +269,7 @@ class MemedeckWorker:
workflow=task['workflow']
)
else:
# Send JSON data / text data
# Send JSON data / text data
if event == "executing":
task['current_node'] = data['node']
if task['workflow'] == 'faceswap' and task["task_status"] == "waiting":
@ -319,17 +283,17 @@ class MemedeckWorker:
# self.logger.info(f"[memedeck]: faceswap executing: {data}")
task["task_status"] = "executing"
elif event == "progress":
if task['current_node'] == task['websocket_node_id']:
if task['current_node'] == task['end_node_id']:
# If the current node is the end node (end_node_id), then set the progress to 100
task['current_progress'] = 100
else:
# If the current node is not the end node (end_node_id), then set the progress based on the node's progress
task['current_progress'] = (data['value'] / data['max']) * 100
if task['current_progress'] == 100 and task['current_node'] != task['websocket_node_id']:
if task['current_progress'] == 100 and task['current_node'] != task['end_node_id']:
# In case the progress is 100 but the node is not the end node (end_node_id), set progress to 95
task['current_progress'] = 95 # Allows the full resolution image to be sent on the 100 progress event
if data['value'] == 1:
if data['value'] == 1 and task['workflow'] != 'training':
# If the value is 1, send started to API
start_data = {
"ws_id": task['ws_id'],
@ -341,9 +305,22 @@ class MemedeckWorker:
elif event == "status":
self.logger.info(f"[memedeck]: sending status event: {data}")
elif event == "executed":
# self.logger.info(f"[memedeck]: sending executed event: {data}")
if task['workflow'] == 'training' and task['current_node'] == task['end_node_id']:
self.logger.info(f"[memedeck]: training completed for {sid}")
del self.tasks_by_ws_id[sid]
# Update the task in tasks_by_ws_id
self.tasks_by_ws_id[sid] = task
def handle_training_send(self, event, task, sid):
    """Handle a worker event for a task running the training workflow.

    Only the "executed" event is acted on: when the task's current node is
    its end node, completion is logged and the task is removed from
    ``self.tasks_by_ws_id``. Every other event is ignored.

    Args:
        event: Event name (e.g. "executed").
        task: Task-info dict; ``current_node`` and ``end_node_id`` are read.
        sid: Key of the task in ``self.tasks_by_ws_id``.

    Returns:
        None.
    """
    if event != "executed":
        return

    end_node_id = task['end_node_id']
    # A task is created with both 'current_node' and 'end_node_id' set to
    # None, so a bare equality check would report a task that has no end
    # node as completed on its very first "executed" event.
    if end_node_id is None or task['current_node'] != end_node_id:
        return

    self.logger.info(f"[memedeck]: training completed for {sid}")
    # pop() instead of del: do not raise if the entry was already removed.
    self.tasks_by_ws_id.pop(sid, None)
async def send_preview(self, image_data, sid=None, progress=None, context=None, workflow=None):
if sid is None:
@ -407,12 +384,16 @@ class MemedeckWorker:
if not task:
self.logger.error(f"[memedeck]: No task found for ws_id {ws_id}")
return
if task['websocket_node_id'] is None:
self.logger.error(f"[memedeck]: websocket_node_id is None for {ws_id}")
if task['end_node_id'] is None:
self.logger.error(f"[memedeck]: end_node_id is None for {ws_id}")
return
api_endpoint = '/generation/update' if task['workflow'] != 'training' else '/training/update'
self.logger.info(f"[memedeck]: sending to api: {api_endpoint}")
self.logger.info(f"[memedeck]: data: {data}")
try:
# this request is not sending properly for faceswap
post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data)
post_func = partial(requests.post, f"{self.api_url}{api_endpoint}", json=data)
await self.loop.run_in_executor(None, post_func)
except Exception as e:
self.logger.info(f"[memedeck]: error sending to api: {e}")

View File

@ -30,3 +30,7 @@ pillow
azure-storage-blob
cairosvg
aio_pika
torchao
insightface
onnxruntime-gpu