diff --git a/.gitignore b/.gitignore index 2257a29d..43aed9db 100644 --- a/.gitignore +++ b/.gitignore @@ -30,4 +30,7 @@ custom_nodes models/ flux_loras/ custom_nodes/ -pysssss-workflows/ \ No newline at end of file +pysssss-workflows/ + +comfy-venv +comfy_venv_3.11 \ No newline at end of file diff --git a/execution.py b/execution.py index 82a66bde..a46c35ad 100644 --- a/execution.py +++ b/execution.py @@ -486,7 +486,7 @@ class PromptExecutor: self.server.client_id = extra_data["client_id"] if "ws_id" in extra_data: self.memedeck_worker.ws_id = extra_data["ws_id"] - self.memedeck_worker.websocket_node_id = extra_data["websocket_node_id"] + self.memedeck_worker.end_node_id = extra_data["end_node_id"] else: self.server.client_id = None self.memedeck_worker.ws_id = None diff --git a/memedeck.py b/memedeck.py index 1c2cb31f..fa0d1b84 100644 --- a/memedeck.py +++ b/memedeck.py @@ -59,6 +59,12 @@ class MemedeckWorker: self.queue_name = os.getenv('QUEUE_NAME') or 'generic-queue' self.api_url = os.getenv('API_ADDRESS') or 'http://0.0.0.0:8079/v2' self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383' + + self.training_only = os.getenv('TRAINING_ONLY') or False + + if self.training_only: + self.queue_name = 'training-queue' + print(f"[memedeck]: training only mode enabled") # Internal job queue self.internal_job_queue = asyncio.Queue() @@ -98,8 +104,9 @@ class MemedeckWorker: self.connection = connection # Open the first channel self.connection.channel(on_open_callback=self.on_channel_open) - # Open the second channel - self.connection.channel(on_open_callback=self.on_faceswap_channel_open) + if not self.training_only: + # Open the second channel + self.connection.channel(on_open_callback=self.on_faceswap_channel_open) def on_channel_open(self, channel): self.channel = channel @@ -130,31 +137,36 @@ class MemedeckWorker: # Execute the task prompt = payload["nodes"] valid = self.validate_prompt(prompt) + + workflow = 'training' if self.training_only else 'faceswap' if routing_key == 'faceswap-queue' else 'generation' # Prepare task_info prompt_id = str(uuid.uuid4()) outputs_to_execute = valid[2] routing_key = method.routing_key task_info = { - "workflow": 'faceswap' if routing_key == 'faceswap-queue' else 'generation', + "workflow": workflow, "prompt_id": prompt_id, "prompt": prompt, "outputs_to_execute": outputs_to_execute, "client_id": "memedeck-1", "is_memedeck": True, - "websocket_node_id": None, + "end_node_id": None, "ws_id": payload["source_ws_id"], - "context": payload["req_ctx"], + "context": payload["req_ctx"] if "req_ctx" in payload else {}, "current_node": None, "current_progress": 0, "delivery_tag": method.delivery_tag, "task_status": "waiting", } - # Find the websocket_node_id + # Find the end_node_id for node in prompt: if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket": - task_info['websocket_node_id'] = node + task_info['end_node_id'] = node + break + if self.training_only and isinstance(prompt[node], dict) and prompt[node].get("class_type") == "FluxTrainEnd": + task_info['end_node_id'] = node break if valid[0]: @@ -163,60 +175,6 @@ class MemedeckWorker: else: channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # Unack the message - # def on_channel_open(self, channel): - # self.channel = channel - - # # only consume one message at a time - # self.channel.basic_qos(prefetch_size=0, prefetch_count=1) - # self.channel.queue_declare(queue=self.queue_name, durable=True) - # self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received) - # # declare another queue - # self.channel.queue_declare(queue='faceswap-queue', durable=True) - # self.channel.basic_consume(queue='faceswap-queue', on_message_callback=self.on_message_received) - - # def on_message_received(self, channel, method, properties, body): - # decoded_string = body.decode('utf-8') - # json_object = json.loads(decoded_string) - # payload = json_object[1] - - # # execute the task - # prompt = payload["nodes"] - # valid = self.validate_prompt(prompt) - - # # Prepare task_info - # prompt_id = str(uuid.uuid4()) - # outputs_to_execute = valid[2] - # # get the routing key from the method - # routing_key = method.routing_key - # task_info = { - # "workflow": 'faceswap' if routing_key == 'faceswap-queue' else 'generation', - # "prompt_id": prompt_id, - # "prompt": prompt, - # "outputs_to_execute": outputs_to_execute, - # "client_id": "memedeck-1", - # "is_memedeck": True, - # "websocket_node_id": None, - # "ws_id": payload["source_ws_id"], - # "context": payload["req_ctx"], - # "current_node": None, - # "current_progress": 0, - # "delivery_tag": method.delivery_tag, - # "task_status": "waiting", - # } - - # # Find the websocket_node_id - # for node in prompt: - # if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket": - # task_info['websocket_node_id'] = node - # break - - # if valid[0]: - # # Enqueue the task into the internal job queue - # self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info)) - # # self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}") - # else: - # channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message - # -------------------------------------------------- # Internal job queue # -------------------------------------------------- @@ -234,14 +192,15 @@ class MemedeckWorker: self.prompt_queue.put((0, prompt_id, prompt, { "client_id": task_info["client_id"], 'is_memedeck': task_info['is_memedeck'], - 'websocket_node_id': task_info['websocket_node_id'], + 'end_node_id': task_info['end_node_id'], 'ws_id': task_info['ws_id'], 'context': task_info['context'] }, task_info['outputs_to_execute'])) - if 'faceswap_strength' not in task_info['context']['prompt_config']: - # pretty print the prompt config - self.logger.info(f"[memedeck]: prompt: {task_info['context']['prompt_config']['character']['id']} {task_info['context']['prompt_config']['positive_prompt']}") + if task_info['workflow'] != 'training': + if 'faceswap_strength' not in task_info['context']['prompt_config']: + # pretty print the prompt config + self.logger.info(f"[memedeck]: prompt: {task_info['context']['prompt_config']['character']['id']} {task_info['context']['prompt_config']['positive_prompt']}") # Wait until the current task is completed await self.wait_for_task_completion(ws_id) # Task is done @@ -295,7 +254,12 @@ class MemedeckWorker: if not task: self.logger.warning(f"Received event {event} for unknown sid: {sid}") return - + + if task['workflow'] == 'training': + return self.handle_training_send(event, task, sid) + + + # this logic is for generation and faceswap (todo move to separate function) if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: await self.send_preview( data, @@ -305,7 +269,7 @@ class MemedeckWorker: workflow=task['workflow'] ) else: - # Send JSON data / text data + # Send JSON data / text data if event == "executing": task['current_node'] = data['node'] if task['workflow'] == 'faceswap' and task["task_status"] == "waiting": @@ -319,17 +283,17 @@ class MemedeckWorker: # self.logger.info(f"[memedeck]: faceswap executing: {data}") task["task_status"] = "executing" elif event == "progress": - if task['current_node'] == task['websocket_node_id']: + if task['current_node'] == task['end_node_id']: # If the node is the websocket node, then set the progress to 100 task['current_progress'] = 100 else: # If the node is not the websocket node, then set the progress based on the node's progress task['current_progress'] = (data['value'] / data['max']) * 100 - if task['current_progress'] == 100 and task['current_node'] != task['websocket_node_id']: + if task['current_progress'] == 100 and task['current_node'] != task['end_node_id']: # In case the progress is 100 but the node is not the websocket node, set progress to 95 task['current_progress'] = 95 # Allows the full resolution image to be sent on the 100 progress event - if data['value'] == 1: + if data['value'] == 1 and task['workflow'] != 'training': # If the value is 1, send started to API start_data = { "ws_id": task['ws_id'], @@ -341,9 +305,22 @@ class MemedeckWorker: elif event == "status": self.logger.info(f"[memedeck]: sending status event: {data}") + + elif event == "executed": + # self.logger.info(f"[memedeck]: sending executed event: {data}") + if task['workflow'] == 'training' and task['current_node'] == task['end_node_id']: + self.logger.info(f"[memedeck]: training completed for {sid}") + del self.tasks_by_ws_id[sid] # Update the task in tasks_by_ws_id self.tasks_by_ws_id[sid] = task + + def handle_training_send(self, event, task, sid): + + if event == "executed": + if task['current_node'] == task['end_node_id']: + self.logger.info(f"[memedeck]: training completed for {sid}") + del self.tasks_by_ws_id[sid] async def send_preview(self, image_data, sid=None, progress=None, context=None, workflow=None): if sid is None: @@ -407,12 +384,16 @@ class MemedeckWorker: if not task: self.logger.error(f"[memedeck]: No task found for ws_id {ws_id}") return - if task['websocket_node_id'] is None: - self.logger.error(f"[memedeck]: websocket_node_id is None for {ws_id}") + if task['end_node_id'] is None: + self.logger.error(f"[memedeck]: end_node_id is None for {ws_id}") return + + api_endpoint = '/generation/update' if task['workflow'] != 'training' else '/training/update' + self.logger.info(f"[memedeck]: sending to api: {api_endpoint}") + self.logger.info(f"[memedeck]: data: {data}") try: # this request is not sending properly for faceswap - post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data) + post_func = partial(requests.post, f"{self.api_url}{api_endpoint}", json=data) await self.loop.run_in_executor(None, post_func) except Exception as e: self.logger.info(f"[memedeck]: error sending to api: {e}") diff --git a/requirements.txt b/requirements.txt index 26ab9048..2b637000 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,3 +30,7 @@ pillow azure-storage-blob cairosvg aio_pika + +torchao +insightface +onnxruntime-gpu