mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
changes for training workflow
This commit is contained in:
parent
c1922cb347
commit
a22b1a7895
5
.gitignore
vendored
5
.gitignore
vendored
@ -30,4 +30,7 @@ custom_nodes
|
||||
models/
|
||||
flux_loras/
|
||||
custom_nodes/
|
||||
pysssss-workflows/
|
||||
pysssss-workflows/
|
||||
|
||||
comfy-venv
|
||||
comfy_venv_3.11
|
@ -486,7 +486,7 @@ class PromptExecutor:
|
||||
self.server.client_id = extra_data["client_id"]
|
||||
if "ws_id" in extra_data:
|
||||
self.memedeck_worker.ws_id = extra_data["ws_id"]
|
||||
self.memedeck_worker.websocket_node_id = extra_data["websocket_node_id"]
|
||||
self.memedeck_worker.end_node_id = extra_data["end_node_id"]
|
||||
else:
|
||||
self.server.client_id = None
|
||||
self.memedeck_worker.ws_id = None
|
||||
|
127
memedeck.py
127
memedeck.py
@ -59,6 +59,12 @@ class MemedeckWorker:
|
||||
self.queue_name = os.getenv('QUEUE_NAME') or 'generic-queue'
|
||||
self.api_url = os.getenv('API_ADDRESS') or 'http://0.0.0.0:8079/v2'
|
||||
self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383'
|
||||
|
||||
self.training_only = os.getenv('TRAINING_ONLY') or False
|
||||
|
||||
if self.training_only:
|
||||
self.queue_name = 'training-queue'
|
||||
print(f"[memedeck]: training only mode enabled")
|
||||
|
||||
# Internal job queue
|
||||
self.internal_job_queue = asyncio.Queue()
|
||||
@ -98,8 +104,9 @@ class MemedeckWorker:
|
||||
self.connection = connection
|
||||
# Open the first channel
|
||||
self.connection.channel(on_open_callback=self.on_channel_open)
|
||||
# Open the second channel
|
||||
self.connection.channel(on_open_callback=self.on_faceswap_channel_open)
|
||||
if not self.training_only:
|
||||
# Open the second channel
|
||||
self.connection.channel(on_open_callback=self.on_faceswap_channel_open)
|
||||
|
||||
def on_channel_open(self, channel):
|
||||
self.channel = channel
|
||||
@ -130,31 +137,36 @@ class MemedeckWorker:
|
||||
# Execute the task
|
||||
prompt = payload["nodes"]
|
||||
valid = self.validate_prompt(prompt)
|
||||
|
||||
workflow = 'training' if self.training_only else 'faceswap' if routing_key == 'faceswap-queue' else 'generation'
|
||||
|
||||
# Prepare task_info
|
||||
prompt_id = str(uuid.uuid4())
|
||||
outputs_to_execute = valid[2]
|
||||
routing_key = method.routing_key
|
||||
task_info = {
|
||||
"workflow": 'faceswap' if routing_key == 'faceswap-queue' else 'generation',
|
||||
"workflow": workflow,
|
||||
"prompt_id": prompt_id,
|
||||
"prompt": prompt,
|
||||
"outputs_to_execute": outputs_to_execute,
|
||||
"client_id": "memedeck-1",
|
||||
"is_memedeck": True,
|
||||
"websocket_node_id": None,
|
||||
"end_node_id": None,
|
||||
"ws_id": payload["source_ws_id"],
|
||||
"context": payload["req_ctx"],
|
||||
"context": payload["req_ctx"] if "req_ctx" in payload else {},
|
||||
"current_node": None,
|
||||
"current_progress": 0,
|
||||
"delivery_tag": method.delivery_tag,
|
||||
"task_status": "waiting",
|
||||
}
|
||||
|
||||
# Find the websocket_node_id
|
||||
# Find the end_node_id
|
||||
for node in prompt:
|
||||
if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
|
||||
task_info['websocket_node_id'] = node
|
||||
task_info['end_node_id'] = node
|
||||
break
|
||||
if self.training_only and isinstance(prompt[node], dict) and prompt[node].get("class_type") == "FluxTrainEnd":
|
||||
task_info['end_node_id'] = node
|
||||
break
|
||||
|
||||
if valid[0]:
|
||||
@ -163,60 +175,6 @@ class MemedeckWorker:
|
||||
else:
|
||||
channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # Unack the message
|
||||
|
||||
# def on_channel_open(self, channel):
|
||||
# self.channel = channel
|
||||
|
||||
# # only consume one message at a time
|
||||
# self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
|
||||
# self.channel.queue_declare(queue=self.queue_name, durable=True)
|
||||
# self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received)
|
||||
# # declare another queue
|
||||
# self.channel.queue_declare(queue='faceswap-queue', durable=True)
|
||||
# self.channel.basic_consume(queue='faceswap-queue', on_message_callback=self.on_message_received)
|
||||
|
||||
# def on_message_received(self, channel, method, properties, body):
|
||||
# decoded_string = body.decode('utf-8')
|
||||
# json_object = json.loads(decoded_string)
|
||||
# payload = json_object[1]
|
||||
|
||||
# # execute the task
|
||||
# prompt = payload["nodes"]
|
||||
# valid = self.validate_prompt(prompt)
|
||||
|
||||
# # Prepare task_info
|
||||
# prompt_id = str(uuid.uuid4())
|
||||
# outputs_to_execute = valid[2]
|
||||
# # get the routing key from the method
|
||||
# routing_key = method.routing_key
|
||||
# task_info = {
|
||||
# "workflow": 'faceswap' if routing_key == 'faceswap-queue' else 'generation',
|
||||
# "prompt_id": prompt_id,
|
||||
# "prompt": prompt,
|
||||
# "outputs_to_execute": outputs_to_execute,
|
||||
# "client_id": "memedeck-1",
|
||||
# "is_memedeck": True,
|
||||
# "websocket_node_id": None,
|
||||
# "ws_id": payload["source_ws_id"],
|
||||
# "context": payload["req_ctx"],
|
||||
# "current_node": None,
|
||||
# "current_progress": 0,
|
||||
# "delivery_tag": method.delivery_tag,
|
||||
# "task_status": "waiting",
|
||||
# }
|
||||
|
||||
# # Find the websocket_node_id
|
||||
# for node in prompt:
|
||||
# if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
|
||||
# task_info['websocket_node_id'] = node
|
||||
# break
|
||||
|
||||
# if valid[0]:
|
||||
# # Enqueue the task into the internal job queue
|
||||
# self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info))
|
||||
# # self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}")
|
||||
# else:
|
||||
# channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message
|
||||
|
||||
# --------------------------------------------------
|
||||
# Internal job queue
|
||||
# --------------------------------------------------
|
||||
@ -234,14 +192,15 @@ class MemedeckWorker:
|
||||
self.prompt_queue.put((0, prompt_id, prompt, {
|
||||
"client_id": task_info["client_id"],
|
||||
'is_memedeck': task_info['is_memedeck'],
|
||||
'websocket_node_id': task_info['websocket_node_id'],
|
||||
'end_node_id': task_info['end_node_id'],
|
||||
'ws_id': task_info['ws_id'],
|
||||
'context': task_info['context']
|
||||
}, task_info['outputs_to_execute']))
|
||||
|
||||
if 'faceswap_strength' not in task_info['context']['prompt_config']:
|
||||
# pretty print the prompt config
|
||||
self.logger.info(f"[memedeck]: prompt: {task_info['context']['prompt_config']['character']['id']} {task_info['context']['prompt_config']['positive_prompt']}")
|
||||
if task_info['workflow'] != 'training':
|
||||
if 'faceswap_strength' not in task_info['context']['prompt_config']:
|
||||
# pretty print the prompt config
|
||||
self.logger.info(f"[memedeck]: prompt: {task_info['context']['prompt_config']['character']['id']} {task_info['context']['prompt_config']['positive_prompt']}")
|
||||
# Wait until the current task is completed
|
||||
await self.wait_for_task_completion(ws_id)
|
||||
# Task is done
|
||||
@ -295,7 +254,12 @@ class MemedeckWorker:
|
||||
if not task:
|
||||
self.logger.warning(f"Received event {event} for unknown sid: {sid}")
|
||||
return
|
||||
|
||||
|
||||
if task['workflow'] == 'training':
|
||||
return self.handle_training_send(event, task, sid)
|
||||
|
||||
|
||||
# this logic is for generation and faceswap (todo move to separate function)
|
||||
if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||
await self.send_preview(
|
||||
data,
|
||||
@ -305,7 +269,7 @@ class MemedeckWorker:
|
||||
workflow=task['workflow']
|
||||
)
|
||||
else:
|
||||
# Send JSON data / text data
|
||||
# Send JSON data / text data
|
||||
if event == "executing":
|
||||
task['current_node'] = data['node']
|
||||
if task['workflow'] == 'faceswap' and task["task_status"] == "waiting":
|
||||
@ -319,17 +283,17 @@ class MemedeckWorker:
|
||||
# self.logger.info(f"[memedeck]: faceswap executing: {data}")
|
||||
task["task_status"] = "executing"
|
||||
elif event == "progress":
|
||||
if task['current_node'] == task['websocket_node_id']:
|
||||
if task['current_node'] == task['end_node_id']:
|
||||
# If the node is the websocket node, then set the progress to 100
|
||||
task['current_progress'] = 100
|
||||
else:
|
||||
# If the node is not the websocket node, then set the progress based on the node's progress
|
||||
task['current_progress'] = (data['value'] / data['max']) * 100
|
||||
if task['current_progress'] == 100 and task['current_node'] != task['websocket_node_id']:
|
||||
if task['current_progress'] == 100 and task['current_node'] != task['end_node_id']:
|
||||
# In case the progress is 100 but the node is not the websocket node, set progress to 95
|
||||
task['current_progress'] = 95 # Allows the full resolution image to be sent on the 100 progress event
|
||||
|
||||
if data['value'] == 1:
|
||||
if data['value'] == 1 and task['workflow'] != 'training':
|
||||
# If the value is 1, send started to API
|
||||
start_data = {
|
||||
"ws_id": task['ws_id'],
|
||||
@ -341,9 +305,22 @@ class MemedeckWorker:
|
||||
|
||||
elif event == "status":
|
||||
self.logger.info(f"[memedeck]: sending status event: {data}")
|
||||
|
||||
elif event == "executed":
|
||||
# self.logger.info(f"[memedeck]: sending executed event: {data}")
|
||||
if task['workflow'] == 'training' and task['current_node'] == task['end_node_id']:
|
||||
self.logger.info(f"[memedeck]: training completed for {sid}")
|
||||
del self.tasks_by_ws_id[sid]
|
||||
|
||||
# Update the task in tasks_by_ws_id
|
||||
self.tasks_by_ws_id[sid] = task
|
||||
|
||||
def handle_training_send(self, event, task, sid):
|
||||
|
||||
if event == "executed":
|
||||
if task['current_node'] == task['end_node_id']:
|
||||
self.logger.info(f"[memedeck]: training completed for {sid}")
|
||||
del self.tasks_by_ws_id[sid]
|
||||
|
||||
async def send_preview(self, image_data, sid=None, progress=None, context=None, workflow=None):
|
||||
if sid is None:
|
||||
@ -407,12 +384,16 @@ class MemedeckWorker:
|
||||
if not task:
|
||||
self.logger.error(f"[memedeck]: No task found for ws_id {ws_id}")
|
||||
return
|
||||
if task['websocket_node_id'] is None:
|
||||
self.logger.error(f"[memedeck]: websocket_node_id is None for {ws_id}")
|
||||
if task['end_node_id'] is None:
|
||||
self.logger.error(f"[memedeck]: end_node_id is None for {ws_id}")
|
||||
return
|
||||
|
||||
api_endpoint = '/generation/update' if task['workflow'] != 'training' else '/training/update'
|
||||
self.logger.info(f"[memedeck]: sending to api: {api_endpoint}")
|
||||
self.logger.info(f"[memedeck]: data: {data}")
|
||||
try:
|
||||
# this request is not sending properly for faceswap
|
||||
post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data)
|
||||
post_func = partial(requests.post, f"{self.api_url}{api_endpoint}", json=data)
|
||||
await self.loop.run_in_executor(None, post_func)
|
||||
except Exception as e:
|
||||
self.logger.info(f"[memedeck]: error sending to api: {e}")
|
||||
|
@ -30,3 +30,7 @@ pillow
|
||||
azure-storage-blob
|
||||
cairosvg
|
||||
aio_pika
|
||||
|
||||
torchao
|
||||
insightface
|
||||
onnxruntime-gpu
|
||||
|
Loading…
Reference in New Issue
Block a user