diff --git a/.gitignore b/.gitignore
index 61881b8a..520190e4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,3 +21,12 @@ venv/
 *.log
 web_custom_versions/
 .DS_Store
+
+.env
+
+
+models
+custom_nodes
+models/
+flux_loras/
+custom_nodes/
\ No newline at end of file
diff --git a/MEMEDECK.md b/MEMEDECK.md
new file mode 100644
index 00000000..aa8f98f9
--- /dev/null
+++ b/MEMEDECK.md
@@ -0,0 +1,50 @@
+# Commands to set up the environment
+
+## Enable MIG on an NVIDIA GPU
+
+See https://www.seimaxim.com/kb/gpu/nvidia-a100-mig-cheat-sheat
+
+```zsh
+sudo nvidia-smi -i 0 -mig 1
+sudo nvidia-smi mig -cgi 9,9 -C
+```
+
+Start ComfyUI on a MIG instance:
+
+```zsh
+CUDA_VISIBLE_DEVICES=MIG-0 python main.py --port 5000 --listen 0.0.0.0 --cuda-device 0 --preview-method auto
+```
+
+## Tunnel a remote server port to a local machine port
+
+### On the server
+
+```zsh
+sudo nano /etc/ssh/sshd_config
+
+# uncomment this line in the sshd_config file
+GatewayPorts yes
+```
+Then restart the machine so the change takes effect.
+
+### On your local machine
+
+```zsh
+sudo ssh -i ~/.ssh/memedeck-monolith.pem -N -R 9090:localhost:8079 holium@172.206.15.40
+```
+
+### Install autossh (may not be needed)
+
+```zsh
+brew install autossh
+autossh -M 0 -o "ServerAliveInterval 30" -o "ServerAliveCountMax 3" -i ~/.ssh/memedeck-monolith.pem -N -R 9090:localhost:8079 holium@172.206.15.40
+```
+
+### On the server: open the port
+
+This opens port 9090 on the server so the reverse-forwarded port can be reached remotely.
+
+```zsh
+sudo iptables -A INPUT -p tcp --dport 9090 -j ACCEPT
+```
\ No newline at end of file
diff --git a/comfy/lora.py b/comfy/lora.py
index bc9f3022..f37d2481 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -213,10 +213,9 @@ def load_lora(lora, to_load, log_missing=True):
             patch_dict[to_load[x]] = ("set", (set_weight,))
             loaded_keys.add(set_weight_name)
 
-    if log_missing:
-        for x in lora.keys():
-            if x not in loaded_keys:
-                logging.warning("lora key not loaded: {}".format(x))
+    # for x in lora.keys():
+    #     if x not in loaded_keys:
+    #         logging.warning("lora key not loaded: {}".format(x))
 
     return patch_dict
diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example
deleted file mode 100644
index 29ab2aa7..00000000
--- a/custom_nodes/example_node.py.example
+++ /dev/null
@@ -1,155 +0,0 @@
-class Example:
-    """
-    A example node
-
-    Class methods
-    -------------
-    INPUT_TYPES (dict):
-        Tell the main program input parameters of nodes.
-    IS_CHANGED:
-        optional method to control when the node is re executed.
-
-    Attributes
-    ----------
-    RETURN_TYPES (`tuple`):
-        The type of each element in the output tuple.
-    RETURN_NAMES (`tuple`):
-        Optional: The name of each output in the output tuple.
-    FUNCTION (`str`):
-        The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
-    OUTPUT_NODE ([`bool`]):
-        If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example.
-        The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected.
-        Assumed to be False if not present.
-    CATEGORY (`str`):
-        The category the node should appear in the UI.
-    DEPRECATED (`bool`):
-        Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
-        functional in existing workflows that use them.
-    EXPERIMENTAL (`bool`):
-        Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
Use with caution in production workflows. - execute(s) -> tuple || None: - The entry point method. The name of this method must be the same as the value of property `FUNCTION`. - For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`. - """ - def __init__(self): - pass - - @classmethod - def INPUT_TYPES(s): - """ - Return a dictionary which contains config for all input fields. - Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT". - Input types "INT", "STRING" or "FLOAT" are special values for fields on the node. - The type can be a list for selection. - - Returns: `dict`: - - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required` - - Value input_fields (`dict`): Contains input fields config: - * Key field_name (`string`): Name of a entry-point method's argument - * Value field_config (`tuple`): - + First value is a string indicate the type of field or a list for selection. - + Second value is a config for type "INT", "STRING" or "FLOAT". - """ - return { - "required": { - "image": ("IMAGE",), - "int_field": ("INT", { - "default": 0, - "min": 0, #Minimum value - "max": 4096, #Maximum value - "step": 64, #Slider's step - "display": "number", # Cosmetic only: display as "number" or "slider" - "lazy": True # Will only be evaluated if check_lazy_status requires it - }), - "float_field": ("FLOAT", { - "default": 1.0, - "min": 0.0, - "max": 10.0, - "step": 0.01, - "round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. - "display": "number", - "lazy": True - }), - "print_to_screen": (["enable", "disable"],), - "string_field": ("STRING", { - "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node - "default": "Hello World!", - "lazy": True - }), - }, - } - - RETURN_TYPES = ("IMAGE",) - #RETURN_NAMES = ("image_output_name",) - - FUNCTION = "test" - - #OUTPUT_NODE = False - - CATEGORY = "Example" - - def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen): - """ - Return a list of input names that need to be evaluated. - - This function will be called if there are any lazy inputs which have not yet been - evaluated. As long as you return at least one field which has not yet been evaluated - (and more exist), this function will be called again once the value of the requested - field is available. - - Any evaluated inputs will be passed as arguments to this function. Any unevaluated - inputs will have the value None. - """ - if print_to_screen == "enable": - return ["int_field", "float_field", "string_field"] - else: - return [] - - def test(self, image, string_field, int_field, float_field, print_to_screen): - if print_to_screen == "enable": - print(f"""Your input contains: - string_field aka input text: {string_field} - int_field: {int_field} - float_field: {float_field} - """) - #do some processing on the image, in this example I just invert it - image = 1.0 - image - return (image,) - - """ - The node will always be re executed if any of the inputs change but - this method can be used to force the node to execute again even when the inputs don't change. - You can make this node return a number or a string. This value will be compared to the one returned the last time the node was - executed, if it is different the node will be executed again. 
- This method is used in the core repo for the LoadImage node where they return the image hash as a string, if the image hash - changes between executions the LoadImage node is executed again. - """ - #@classmethod - #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen): - # return "" - -# Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension -# WEB_DIRECTORY = "./somejs" - - -# Add custom API routes, using router -from aiohttp import web -from server import PromptServer - -@PromptServer.instance.routes.get("/hello") -async def get_hello(request): - return web.json_response("hello") - - -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "Example": Example -} - -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "Example": "Example Node" -} diff --git a/custom_nodes/websocket_image_save.py b/custom_nodes/websocket_image_save.py deleted file mode 100644 index 15f87f9f..00000000 --- a/custom_nodes/websocket_image_save.py +++ /dev/null @@ -1,44 +0,0 @@ -from PIL import Image -import numpy as np -import comfy.utils -import time - -#You can use this node to save full size images through the websocket, the -#images will be sent in exactly the same format as the image previews: as -#binary images on the websocket with a 8 byte header indicating the type -#of binary message (first 4 bytes) and the image format (next 4 bytes). - -#Note that no metadata will be put in the images saved with this node. - -class SaveImageWebsocket: - @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ),} - } - - RETURN_TYPES = () - FUNCTION = "save_images" - - OUTPUT_NODE = True - - CATEGORY = "api/image" - - def save_images(self, images): - pbar = comfy.utils.ProgressBar(images.shape[0]) - step = 0 - for image in images: - i = 255. 
* image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pbar.update_absolute(step, images.shape[0], ("PNG", img, None)) - step += 1 - - return {} - - @classmethod - def IS_CHANGED(s, images): - return time.time() - -NODE_CLASS_MAPPINGS = { - "SaveImageWebsocket": SaveImageWebsocket, -} diff --git a/execution.py b/execution.py index 2c979205..afeb855a 100644 --- a/execution.py +++ b/execution.py @@ -245,7 +245,7 @@ def format_value(x): else: return str(x) -def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results): +def execute(server, memedeck_worker, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -253,10 +253,15 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + is_memedeck = extra_data.get("is_memedeck", False) + if caches.outputs.get(unique_id) is not None: if server.client_id is not None: cached_output = caches.ui.get(unique_id) or {} server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) + if memedeck_worker.ws_id is not None: + memedeck_worker.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, memedeck_worker.ws_id) + # elif is_memedeck: return (ExecutionResult.SUCCESS, None, None) input_data_all = None @@ -284,10 +289,13 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp has_subgraph = False else: input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) - if server.client_id is not None: + if server.client_id is not None and not is_memedeck: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) - + elif is_memedeck: + memedeck_worker.last_node_id = display_node_id + memedeck_worker.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, memedeck_worker.ws_id) + obj = caches.objects.get(unique_id) if obj is None: obj = class_def() @@ -319,6 +327,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp "current_outputs": [], } server.send_sync("execution_error", mes, server.client_id) + memedeck_worker.send_sync("execution_error", mes, memedeck_worker.ws_id) return ExecutionBlocker(None) else: return block @@ -335,8 +344,11 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp }, "output": output_ui }) - if server.client_id is not None: + if server.client_id is not None and not is_memedeck: server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + elif is_memedeck: + memedeck_worker.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, memedeck_worker.ws_id) + if has_subgraph: cached_outputs = [] new_node_ids = [] @@ -414,9 +426,12 @@ 
def execute(server, dynprompt, caches, current_item, extra_data, executed, promp return (ExecutionResult.SUCCESS, None, None) class PromptExecutor: - def __init__(self, server, lru_size=None): + def __init__(self, server, memedeck_worker, lru_size=None): self.lru_size = lru_size self.server = server + self.memedeck_worker = memedeck_worker + # set logging level + logging.basicConfig(level=logging.INFO) self.reset() def reset(self): @@ -432,6 +447,9 @@ class PromptExecutor: self.status_messages.append((event, data)) if self.server.client_id is not None or broadcast: self.server.send_sync(event, data, self.server.client_id) + elif self.memedeck_worker.ws_id is not None: + self.memedeck_worker.send_sync(event, data, self.memedeck_worker.ws_id) + def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): node_id = error["node_id"] @@ -466,8 +484,12 @@ class PromptExecutor: if "client_id" in extra_data: self.server.client_id = extra_data["client_id"] + if "ws_id" in extra_data: + self.memedeck_worker.ws_id = extra_data["ws_id"] + self.memedeck_worker.websocket_node_id = extra_data["websocket_node_id"] else: self.server.client_id = None + self.memedeck_worker.ws_id = None self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) @@ -501,7 +523,7 @@ class PromptExecutor: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break - result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) + result, error, ex = execute(self.server, self.memedeck_worker, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) @@ -512,7 +534,7 @@ class PromptExecutor: execution_list.complete_node_execution() else: # Only execute when the while-loop ends without break - self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + self.add_message("execution_success", { "prompt_id": prompt_id, 'ws_id': self.memedeck_worker.ws_id }, broadcast=False) ui_outputs = {} meta_outputs = {} @@ -527,6 +549,7 @@ class PromptExecutor: "meta": meta_outputs, } self.server.last_node_id = None + self.memedeck_worker.last_node_id = None if comfy.model_management.DISABLE_SMART_MEMORY: comfy.model_management.unload_all_models() @@ -869,8 +892,9 @@ def validate_prompt(prompt): MAXIMUM_HISTORY_SIZE = 10000 class PromptQueue: - def __init__(self, server): + def __init__(self, server, memedeck_worker): self.server = server + self.memedeck_worker = memedeck_worker self.mutex = threading.RLock() self.not_empty = threading.Condition(self.mutex) self.task_counter = 0 @@ -884,6 +908,7 @@ class PromptQueue: with self.mutex: heapq.heappush(self.queue, item) self.server.queue_updated() + self.memedeck_worker.queue_updated() self.not_empty.notify() def get(self, timeout=None): @@ -897,6 +922,7 @@ class PromptQueue: self.currently_running[i] = copy.deepcopy(item) self.task_counter += 1 self.server.queue_updated() + self.memedeck_worker.queue_updated() return (item, i) class ExecutionStatus(NamedTuple): @@ -922,6 +948,7 @@ class PromptQueue: } self.history[prompt[1]].update(history_result) self.server.queue_updated() + 
self.memedeck_worker.queue_updated() def get_current_queue(self): with self.mutex: @@ -938,6 +965,7 @@ class PromptQueue: with self.mutex: self.queue = [] self.server.queue_updated() + self.memedeck_worker.queue_updated() def delete_queue_item(self, function): with self.mutex: @@ -949,6 +977,8 @@ class PromptQueue: self.queue.pop(x) heapq.heapify(self.queue) self.server.queue_updated() + self.memedeck_worker.queue_updated() + return True return False diff --git a/main.py b/main.py index f6510c90..24f3d7a4 100644 --- a/main.py +++ b/main.py @@ -53,6 +53,14 @@ def apply_custom_paths(): logging.info(f"Setting user directory to: {user_dir}") folder_paths.set_user_directory(user_dir) +# --------------------------------------------------------------------------------------- +# memedeck imports +# --------------------------------------------------------------------------------------- +from memedeck import MemedeckWorker + +import sys +sys.stdout = open(os.devnull, 'w') # disable all print statements +# --------------------------------------------------------------------------------------- def execute_prestartup_script(): def execute_script(script_path): @@ -153,9 +161,11 @@ def cuda_malloc_warning(): logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") -def prompt_worker(q, server_instance): +def prompt_worker(q, server, memedeck_worker): current_time: float = 0.0 - e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru) + e = execution.PromptExecutor(server, memedeck_worker, lru_size=args.cache_lru) + + # threading.Thread(target=memedeck_worker.start, daemon=True, args=(q, execution.validate_prompt)).start() last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 @@ -209,12 +219,12 @@ def prompt_worker(q, server_instance): need_gc = False -async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None): +async def run(server_instance, memedeck_worker, address='', port=8188, verbose=True, call_on_start=None): addresses = [] for addr in address.split(","): addresses.append((addr, port)) await asyncio.gather( - server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop() + server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop(), memedeck_worker.publish_loop() ) @@ -224,11 +234,34 @@ def hijack_progress(server_instance): progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id} server_instance.send_sync("progress", progress, server_instance.client_id) + if memedeck_worker.ws_id is not None: + memedeck_worker.send_sync("progress", {"value": value, "max": total, "prompt_id": memedeck_worker.last_prompt_id, "node": server_instance.last_node_id}, memedeck_worker.ws_id) if preview_image is not None: server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id) - + if memedeck_worker.ws_id is not None: + memedeck_worker.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, memedeck_worker.ws_id) comfy.utils.set_progress_bar_global_hook(hook) +# async def run(server, memedeck_worker, address='', port=8188, verbose=True, call_on_start=None): +# addresses = [] +# for addr in address.split(","): +# addresses.append((addr, port)) +# # add memedeck worker publish loop +# await asyncio.gather(server.start_multi_address(addresses, call_on_start), 
server.publish_loop(), memedeck_worker.publish_loop()) + + +# def hijack_progress(server, memedeck_worker): +# def hook(value, total, preview_image): +# comfy.model_management.throw_exception_if_processing_interrupted() +# server.send_sync("progress", {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}, server.client_id) +# if memedeck_worker.ws_id is not None: +# memedeck_worker.send_sync("progress", {"value": value, "max": total, "prompt_id": memedeck_worker.last_prompt_id, "node": server.last_node_id}, memedeck_worker.ws_id) +# if preview_image is not None: +# server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) +# if memedeck_worker.ws_id is not None: +# memedeck_worker.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, memedeck_worker.ws_id) +# comfy.utils.set_progress_bar_global_hook(hook) + def cleanup_temp(): temp_dir = folder_paths.get_temp_directory() @@ -258,16 +291,54 @@ def start_comfyui(asyncio_loop=None): asyncio_loop = asyncio.new_event_loop() asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) - q = execution.PromptQueue(prompt_server) + memedeck_worker = MemedeckWorker(asyncio_loop) + q = execution.PromptQueue(prompt_server, memedeck_worker) + + # loop = asyncio.new_event_loop() + # asyncio.set_event_loop(loop) + # server = server.PromptServer(loop) + # q = execution.PromptQueue(server, memedeck_worker) + + extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") + if os.path.isfile(extra_model_paths_config_path): + utils.extra_config.load_extra_path_config(extra_model_paths_config_path) + + if args.extra_model_paths_config: + for config_path in itertools.chain(*args.extra_model_paths_config): + utils.extra_config.load_extra_path_config(config_path) nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes) cuda_malloc_warning() prompt_server.add_routes() - hijack_progress(prompt_server) + hijack_progress(server, memedeck_worker) - threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start() + threading.Thread(target=prompt_worker, daemon=True, args=(q, server, memedeck_worker)).start() + threading.Thread(target=memedeck_worker.start, daemon=True, args=(q, execution.validate_prompt)).start() + # set logging level to info + + if args.output_directory: + output_dir = os.path.abspath(args.output_directory) + logging.info(f"Setting output directory to: {output_dir}") + folder_paths.set_output_directory(output_dir) + + #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. 
nodes + folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints")) + folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) + folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) + folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models")) + folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras")) + + if args.input_directory: + input_dir = os.path.abspath(args.input_directory) + logging.info(f"Setting input directory to: {input_dir}") + folder_paths.set_input_directory(input_dir) + + if args.user_directory: + user_dir = os.path.abspath(args.user_directory) + logging.info(f"Setting user directory to: {user_dir}") + folder_paths.set_user_directory(user_dir) if args.quick_test_for_ci: exit(0) @@ -287,6 +358,7 @@ def start_comfyui(asyncio_loop=None): async def start_all(): await prompt_server.setup() await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start) + await run(server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start) # Returning these so that other code can integrate with the ComfyUI loop and server return asyncio_loop, prompt_server, start_all @@ -298,6 +370,8 @@ if __name__ == "__main__": event_loop, _, start_all_func = start_comfyui() try: event_loop.run_until_complete(start_all_func()) + # event_loop.run_until_complete(run(server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) + except KeyboardInterrupt: logging.info("\nStopped server") diff --git a/memedeck-v1.py b/memedeck-v1.py new file mode 100644 index 00000000..abfe7592 --- /dev/null +++ b/memedeck-v1.py @@ -0,0 +1,266 @@ +import asyncio +import base64 +from io import BytesIO +import os +import logging +import signal +import struct +from typing import Optional +import uuid +from PIL import Image, ImageOps +from functools import partial + +import pika +import json + +import requests +import aiohttp + +# load from env file +# load from .env file +from dotenv import load_dotenv +load_dotenv() + +amqp_addr = os.getenv('AMQP_ADDR') or 'amqp://api:gacdownatravKekmy9@51.8.120.154:5672/dev' + +# define the enum in python +from enum import Enum + +class QueueProgressKind(Enum): + # make json serializable + ImageGenerated = "image_generated" + ImageGenerating = "image_generating" + SamePrompt = "same_prompt" + FaceswapGenerated = "faceswap_generated" + FaceswapGenerating = "faceswap_generating" + Failed = "failed" + +class MemedeckWorker: + class BinaryEventTypes: + PREVIEW_IMAGE = 1 + UNENCODED_PREVIEW_IMAGE = 2 + + class JsonEventTypes(Enum): + PROGRESS = "progress" + EXECUTING = "executing" + EXECUTED = "executed" + ERROR = "error" + STATUS = "status" + + """ + MemedeckWorker is a class that is responsible for relaying messages between comfy and the memedeck backend api + it is used to send images to the memedeck backend api and to receive prompts from the memedeck backend api + """ + def __init__(self, loop): + MemedeckWorker.instance = self + # set logging level to info + logging.getLogger().setLevel(logging.INFO) + self.active_tasks_map = {} + self.current_task = None + + self.client_id = None + self.ws_id = None + self.websocket_node_id = None + 
self.current_node = None + self.current_progress = 0 + self.current_context = None + + self.loop = loop + self.messages = asyncio.Queue() + + self.http_client = None + self.prompt_queue = None + self.validate_prompt = None + self.last_prompt_id = None + + self.amqp_url = amqp_addr + self.queue_name = os.getenv('QUEUE_NAME') or 'generic-queue' + self.api_url = os.getenv('API_ADDRESS') or 'http://0.0.0.0:8079/v2' + self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383' + + # Configure logging + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + self.logger = logging.getLogger(__name__) + self.logger.info(f"\n[memedeck]: initialized with API URL: {self.api_url} and API Key: {self.api_key}\n") + + def on_connection_open(self, connection): + self.connection = connection + self.connection.channel(on_open_callback=self.on_channel_open) + + def on_channel_open(self, channel): + self.channel = channel + + # only consume one message at a time + self.channel.basic_qos(prefetch_size=0, prefetch_count=1) + self.channel.queue_declare(queue=self.queue_name, durable=True) + self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received) + + def start(self, prompt_queue, validate_prompt): + self.prompt_queue = prompt_queue + self.validate_prompt = validate_prompt + + parameters = pika.URLParameters(self.amqp_url) + logging.getLogger('pika').setLevel(logging.WARNING) # supress all logs from pika + self.connection = pika.SelectConnection(parameters, on_open_callback=self.on_connection_open) + + try: + self.connection.ioloop.start() + except KeyboardInterrupt: + self.connection.close() + self.connection.ioloop.start() + + def on_message_received(self, channel, method, properties, body): + decoded_string = body.decode('utf-8') + json_object = json.loads(decoded_string) + payload = json_object[1] + + # execute the task + prompt = payload["nodes"] + valid = self.validate_prompt(prompt) + + self.current_node = None + self.current_progress = 0 + self.websocket_node_id = None + self.ws_id = payload["source_ws_id"] + self.current_context = payload["req_ctx"] + + for node in prompt: # search through prompt nodes for websocket_node_id + if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket": + self.websocket_node_id = node + break + + if valid[0]: + prompt_id = str(uuid.uuid4()) + outputs_to_execute = valid[2] + self.active_tasks_map[payload["source_ws_id"]] = { + "prompt_id": prompt_id, + "prompt": prompt, + "outputs_to_execute": outputs_to_execute, + "client_id": "memedeck-1", + "is_memedeck": True, + "websocket_node_id": self.websocket_node_id, + "ws_id": payload["source_ws_id"], + "context": payload["req_ctx"], + "current_node": None, + "current_progress": 0, + } + self.prompt_queue.put((0, prompt_id, prompt, { + "client_id": "memedeck-1", + 'is_memedeck': True, + 'websocket_node_id': self.websocket_node_id, + 'ws_id': payload["source_ws_id"], + 'context': payload["req_ctx"] + }, outputs_to_execute)) + self.set_last_prompt_id(prompt_id) + channel.basic_ack(delivery_tag=method.delivery_tag) # ack the task + else: + channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message + + # -------------------------------------------------- + # callbacks for the prompt queue + # -------------------------------------------------- + def queue_updated(self): + # print json of the queue info but only print the first 100 lines + info = self.get_queue_info() + # 
update_type = info[''] + # self.send_sync("status", { "status": self.get_queue_info() }) + + def get_queue_info(self): + prompt_info = {} + exec_info = {} + exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining() + prompt_info['exec_info'] = exec_info + return prompt_info + + def send_sync(self, event, data, sid=None): + + self.loop.call_soon_threadsafe( + self.messages.put_nowait, (event, data, sid)) + + def set_last_prompt_id(self, prompt_id): + self.last_prompt_id = prompt_id + + async def publish_loop(self): + while True: + msg = await self.messages.get() + await self.send(*msg) + + async def send(self, event, data, sid=None): + current_task = self.active_tasks_map.get(sid) + if current_task is None or current_task['ws_id'] != sid: + return + + if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: # preview and unencoded images are sent here + self.logger.info(f"[memedeck]: sending image preview for {sid}") + await self.send_preview(data, sid=current_task['ws_id'], progress=current_task['current_progress'], context=current_task['context']) + else: # send json data / text data + if event == "executing": + current_task['current_node'] = data['node'] + elif event == "executed": + self.logger.info(f"---> [memedeck]: executed event for {sid}") + prompt_id = data['prompt_id'] + if prompt_id in self.active_tasks_map: + del self.active_tasks_map[prompt_id] + elif event == "progress": + if current_task['current_node'] == current_task['websocket_node_id']: # if the node is the websocket node, then set the progress to 100 + current_task['current_progress'] = 100 + else: # if the node is not the websocket node, then set the progress to the progress from the node + current_task['current_progress'] = data['value'] / data['max'] * 100 + if current_task['current_progress'] == 100 and current_task['current_node'] != current_task['websocket_node_id']: + # in case the progress is 100 but the node is not the websocket node, then set the progress to 95 + current_task['current_progress'] = 95 # this allows the full resolution image to be sent on the 100 progress event + + if data['value'] == 1: # if the value is 1, then send started to api + start_data = { + "ws_id": current_task['ws_id'], + "status": "started", + "info": None, + } + await self.send_to_api(start_data) + + elif event == "status": + self.logger.info(f"[memedeck]: sending status event: {data}") + + self.active_tasks_map[sid] = current_task + + + async def send_preview(self, image_data, sid=None, progress=None, context=None): + # if self.current_progress is odd, then don't send the preview + if progress % 2 == 1: + return + + image_type = image_data[0] + image = image_data[1] + max_size = image_data[2] + if max_size is not None: + if hasattr(Image, 'Resampling'): + resampling = Image.Resampling.BILINEAR + else: + resampling = Image.ANTIALIAS + + image = ImageOps.contain(image, (max_size, max_size), resampling) + + bytesIO = BytesIO() + image.save(bytesIO, format=image_type, quality=100 if progress == 96 else 75, compress_level=1) + preview_bytes = bytesIO.getvalue() + + ai_queue_progress = { + "ws_id": sid, + "kind": "image_generating" if progress < 100 else "image_generated", + "data": list(preview_bytes), + "progress": int(progress), + "context": context + } + + await self.send_to_api(ai_queue_progress) + + async def send_to_api(self, data): + if self.websocket_node_id is None: # check if the node is still running + logging.error(f"[memedeck]: websocket_node_id is None for {data['ws_id']}") + return + + try: + 
post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data) + await self.loop.run_in_executor(None, post_func) + except Exception as e: + self.logger.error(f"[memedeck]: error sending to api: {e}") diff --git a/memedeck.py b/memedeck.py new file mode 100644 index 00000000..7a9db766 --- /dev/null +++ b/memedeck.py @@ -0,0 +1,568 @@ +import asyncio +import base64 +from io import BytesIO +import os +import logging +import signal +import struct +import time +from typing import Optional +import uuid +from PIL import Image, ImageOps +from functools import partial + +import pika +import json + +import requests +import aiohttp + +from dotenv import load_dotenv +load_dotenv() + +amqp_addr = os.getenv('AMQP_ADDR') or 'amqp://api:gacdownatravKekmy9@51.8.120.154:5672/dev' + +from enum import Enum + +class QueueProgressKind(Enum): + ImageGenerated = "image_generated" + ImageGenerating = "image_generating" + SamePrompt = "same_prompt" + FaceswapGenerated = "faceswap_generated" + FaceswapGenerating = "faceswap_generating" + Failed = "failed" + +class MemedeckWorker: + class BinaryEventTypes: + PREVIEW_IMAGE = 1 + UNENCODED_PREVIEW_IMAGE = 2 + + class JsonEventTypes(Enum): + PROGRESS = "progress" + EXECUTING = "executing" + EXECUTED = "executed" + ERROR = "error" + STATUS = "status" + + """ + MemedeckWorker is a class that is responsible for relaying messages between comfy and the memedeck backend api + it is used to send images to the memedeck backend api and to receive prompts from the memedeck backend api + """ + def __init__(self, loop): + MemedeckWorker.instance = self + logging.getLogger().setLevel(logging.INFO) + + self.loop = loop + self.messages = asyncio.Queue() + + self.http_client = None + self.prompt_queue = None + self.validate_prompt = None + self.last_prompt_id = None + + self.amqp_url = amqp_addr + self.queue_name = os.getenv('QUEUE_NAME') or 'generic-queue' + self.api_url = os.getenv('API_ADDRESS') or 'http://0.0.0.0:8079/v2' + self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383' + + # Internal job queue + self.internal_job_queue = asyncio.Queue() + + # Dictionary to keep track of tasks by ws_id + self.tasks_by_ws_id = {} + + # Configure logging + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + self.logger = logging.getLogger(__name__) + self.logger.info(f"\n[memedeck]: initialized with API URL: {self.api_url} and API Key: {self.api_key}\n") + + def on_connection_open(self, connection): + self.connection = connection + self.connection.channel(on_open_callback=self.on_channel_open) + + def on_channel_open(self, channel): + self.channel = channel + + # only consume one message at a time + self.channel.basic_qos(prefetch_size=0, prefetch_count=1) + self.channel.queue_declare(queue=self.queue_name, durable=True) + self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received, auto_ack=False) + + def start(self, prompt_queue, validate_prompt): + self.prompt_queue = prompt_queue + self.validate_prompt = validate_prompt + + # Start the process_job_queue task **after** prompt_queue is set + self.loop.create_task(self.process_job_queue()) + + parameters = pika.URLParameters(self.amqp_url) + logging.getLogger('pika').setLevel(logging.WARNING) # suppress all logs from pika + self.connection = pika.SelectConnection(parameters, on_open_callback=self.on_connection_open) + + try: + self.connection.ioloop.start() + except KeyboardInterrupt: + self.connection.close() + 
self.connection.ioloop.start() + + def on_message_received(self, channel, method, properties, body): + decoded_string = body.decode('utf-8') + json_object = json.loads(decoded_string) + payload = json_object[1] + + # execute the task + prompt = payload["nodes"] + valid = self.validate_prompt(prompt) + + # Prepare task_info + prompt_id = str(uuid.uuid4()) + outputs_to_execute = valid[2] + task_info = { + "prompt_id": prompt_id, + "prompt": prompt, + "outputs_to_execute": outputs_to_execute, + "client_id": "memedeck-1", + "is_memedeck": True, + "websocket_node_id": None, + "ws_id": payload["source_ws_id"], + "context": payload["req_ctx"], + "current_node": None, + "current_progress": 0, + "delivery_tag": method.delivery_tag, + "task_status": "waiting", + } + + # Find the websocket_node_id + for node in prompt: + if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket": + task_info['websocket_node_id'] = node + break + + if valid[0]: + # Enqueue the task into the internal job queue + self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info)) + self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}") + else: + channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message + + async def process_job_queue(self): + while True: + prompt_id, prompt, task_info = await self.internal_job_queue.get() + # Start a new coroutine for each task + self.loop.create_task(self.process_task(prompt_id, prompt, task_info)) + + async def process_task(self, prompt_id, prompt, task_info): + ws_id = task_info['ws_id'] + # Add the task to tasks_by_ws_id + self.tasks_by_ws_id[ws_id] = task_info + # Put the prompt into the prompt_queue + self.prompt_queue.put((0, prompt_id, prompt, { + "client_id": task_info["client_id"], + 'is_memedeck': task_info['is_memedeck'], + 'websocket_node_id': task_info['websocket_node_id'], + 'ws_id': task_info['ws_id'], + 'context': task_info['context'] + }, task_info['outputs_to_execute'])) + # Acknowledge the message + self.channel.basic_ack(delivery_tag=task_info["delivery_tag"]) # ack the task + self.logger.info(f"[memedeck]: Acked task {prompt_id} {ws_id}") + + self.logger.info(f"[memedeck]: Started processing prompt {prompt_id}") + # Wait until the current task is completed + await self.wait_for_task_completion(ws_id) + # Task is done + self.internal_job_queue.task_done() + + async def wait_for_task_completion(self, ws_id): + """ + Wait until the task with the given ws_id is completed. 
+ """ + while ws_id in self.tasks_by_ws_id: + await asyncio.sleep(0.5) + + # -------------------------------------------------- + # callbacks for the prompt queue + # -------------------------------------------------- + def queue_updated(self): + info = self.get_queue_info() + # self.send_sync("status", { "status": self.get_queue_info() }) + + def get_queue_info(self): + prompt_info = {} + exec_info = {} + exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining() + prompt_info['exec_info'] = exec_info + return prompt_info + + def send_sync(self, event, data, sid=None): + self.loop.call_soon_threadsafe( + self.messages.put_nowait, (event, data, sid)) + + async def publish_loop(self): + while True: + msg = await self.messages.get() + await self.send(*msg) + + async def send(self, event, data, sid=None): + if sid is None: + self.logger.warning("Received event without sid") + return + + # Retrieve the task based on sid + task = self.tasks_by_ws_id.get(sid) + if not task: + self.logger.warning(f"Received event {event} for unknown sid: {sid}") + return + + if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: + await self.send_preview( + data, + sid=sid, + progress=task['current_progress'], + context=task['context'] + ) + else: + # Send JSON data / text data + if event == "executing": + task['current_node'] = data['node'] + task["task_status"] = "executing" + elif event == "progress": + if task['current_node'] == task['websocket_node_id']: + # If the node is the websocket node, then set the progress to 100 + task['current_progress'] = 100 + else: + # If the node is not the websocket node, then set the progress based on the node's progress + task['current_progress'] = (data['value'] / data['max']) * 100 + if task['current_progress'] == 100 and task['current_node'] != task['websocket_node_id']: + # In case the progress is 100 but the node is not the websocket node, set progress to 95 + task['current_progress'] = 95 # Allows the full resolution image to be sent on the 100 progress event + + if data['value'] == 1: + # If the value is 1, send started to API + start_data = { + "ws_id": task['ws_id'], + "status": "started", + "info": None, + } + task["task_status"] = "executing" + await self.send_to_api(start_data) + + elif event == "status": + self.logger.info(f"[memedeck]: sending status event: {data}") + + # Update the task in tasks_by_ws_id + self.tasks_by_ws_id[sid] = task + + async def send_preview(self, image_data, sid=None, progress=None, context=None): + if sid is None: + self.logger.warning("Received preview without sid") + return + + task = self.tasks_by_ws_id.get(sid) + if not task: + self.logger.warning(f"Received preview for unknown sid: {sid}") + return + + if progress is None: + progress = task['current_progress'] + + # if progress is odd, then don't send the preview + if int(progress) % 2 == 1: + return + + image_type = image_data[0] + image = image_data[1] + max_size = image_data[2] + if max_size is not None: + if hasattr(Image, 'Resampling'): + resampling = Image.Resampling.BILINEAR + else: + resampling = Image.ANTIALIAS + + image = ImageOps.contain(image, (max_size, max_size), resampling) + + bytesIO = BytesIO() + image.save(bytesIO, format=image_type, quality=100 if progress == 95 else 75, compress_level=1) + preview_bytes = bytesIO.getvalue() + + ai_queue_progress = { + "ws_id": sid, + "kind": "image_generating" if progress < 100 else "image_generated", + "data": list(preview_bytes), + "progress": int(progress), + "context": context + } + + await 
self.send_to_api(ai_queue_progress) + + if progress == 100: + del self.tasks_by_ws_id[sid] # Remove the task from tasks_by_ws_id + self.logger.info(f"[memedeck]: Task {sid} completed") + + async def send_to_api(self, data): + ws_id = data.get('ws_id') + if not ws_id: + self.logger.error("[memedeck]: Missing ws_id in data") + return + task = self.tasks_by_ws_id.get(ws_id) + if not task: + self.logger.error(f"[memedeck]: No task found for ws_id {ws_id}") + return + if task['websocket_node_id'] is None: + self.logger.error(f"[memedeck]: websocket_node_id is None for {ws_id}") + return + try: + post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data) + await self.loop.run_in_executor(None, post_func) + except Exception as e: + self.logger.error(f"[memedeck]: error sending to api: {e}") + +# -------------------------------------------------------------------------- +# MemedeckAzureStorage +# -------------------------------------------------------------------------- +# from azure.storage.blob.aio import BlobClient, BlobServiceClient +# from azure.storage.blob import ContentSettings +# from typing import Optional, Tuple +# import cairosvg + +# WATERMARK = '' +# WATERMARK_SIZE = 40 + +# class MemedeckAzureStorage: +# def __init__(self, connection_string): +# # get environment variables +# self.storage_account = os.getenv('STORAGE_ACCOUNT') +# self.storage_access_key = os.getenv('STORAGE_ACCESS_KEY') +# self.storage_container = os.getenv('STORAGE_CONTAINER') +# self.logger = logging.getLogger(__name__) + +# self.blob_service_client = BlobServiceClient.from_connection_string(conn_str=connection_string) + +# async def upload_image( +# self, +# by: str, +# image_id: str, +# source_url: Optional[str], +# bytes_data: Optional[bytes], +# filetype: Optional[str], +# ) -> Tuple[str, Tuple[int, int]]: +# """ +# Uploads an image to Azure Blob Storage. + +# Args: +# by (str): Identifier for the uploader. +# image_id (str): Unique identifier for the image. +# source_url (Optional[str]): URL to fetch the image from. +# bytes_data (Optional[bytes]): Image data in bytes. +# filetype (Optional[str]): Desired file type (e.g., 'jpeg', 'png'). + +# Returns: +# Tuple[str, Tuple[int, int]]: URL of the uploaded image and its dimensions. 
+# """ +# # Retrieve image bytes either from the provided bytes_data or by fetching from source_url +# if source_url is None: +# if bytes_data is None: +# raise ValueError("Could not get image bytes") +# image_bytes = bytes_data +# else: +# self.logger.info(f"Requesting image from URL: {source_url}") +# async with aiohttp.ClientSession() as session: +# try: +# async with session.get(source_url) as response: +# if response.status != 200: +# raise Exception(f"Failed to fetch image, status code {response.status}") +# image_bytes = await response.read() +# except Exception as e: +# raise Exception(f"Error fetching image from URL: {e}") + +# # Open image using Pillow to get dimensions and format +# try: +# img = Image.open(BytesIO(image_bytes)) +# width, height = img.size +# inferred_filetype = img.format.lower() +# except Exception as e: +# raise Exception(f"Failed to decode image: {e}") + +# # Determine the final file type +# final_filetype = filetype.lower() if filetype else inferred_filetype + +# # Construct the blob name +# blob_name = f"{by}/{image_id.replace('image:', '')}.{final_filetype}" + +# # Upload the image to Azure Blob Storage +# try: +# image_url = await self.save_image(blob_name, img.format, image_bytes) +# return image_url, (width, height) +# except Exception as e: +# self.logger.error(f"Trouble saving image: {e}") +# raise Exception(f"Trouble saving image: {e}") + +# async def save_image( +# self, +# blob_name: str, +# content_type: str, +# bytes_data: bytes +# ) -> str: +# """ +# Saves image bytes to Azure Blob Storage. + +# Args: +# blob_name (str): Name of the blob in Azure Storage. +# content_type (str): MIME type of the content. +# bytes_data (bytes): Image data in bytes. + +# Returns: +# str: URL of the uploaded blob. +# """ +# # Retrieve environment variables +# account = os.getenv("STORAGE_ACCOUNT") +# access_key = os.getenv("STORAGE_ACCESS_KEY") +# container = os.getenv("STORAGE_CONTAINER") + +# if not all([account, access_key, container]): +# raise EnvironmentError("Missing STORAGE_ACCOUNT, STORAGE_ACCESS_KEY, or STORAGE_CONTAINER environment variables") + +# # Initialize BlobServiceClient +# blob_service_client = BlobServiceClient( +# account_url=f"https://{account}.blob.core.windows.net", +# credential=access_key +# ) +# blob_client = blob_service_client.get_blob_client(container=container, blob=blob_name) + +# # Upload the blob +# try: +# await blob_client.upload_blob( +# bytes_data, +# overwrite=True, +# content_settings=ContentSettings(content_type=content_type) +# ) +# except Exception as e: +# raise Exception(f"Failed to upload blob: {e}") + +# self.logger.debug(f"Blob uploaded: name={blob_name}, content_type={content_type}") + +# # Construct and return the blob URL +# blob_url = f"https://media.memedeck.xyz//{container}/{blob_name}" +# return blob_url + +# async def add_watermark( +# self, +# base_blob_name: str, +# base_image: bytes +# ) -> str: +# """ +# Adds a watermark to the provided image and uploads the watermarked image. + +# Args: +# base_blob_name (str): Original blob name of the image. +# base_image (bytes): Image data in bytes. + +# Returns: +# str: URL of the watermarked image. 
+# """ +# # Load the input image +# try: +# img = Image.open(BytesIO(base_image)).convert("RGBA") +# except Exception as e: +# raise Exception(f"Failed to load image: {e}") + +# # Calculate position for the watermark (bottom right corner with padding) +# padding = 12 +# x = img.width - WATERMARK_SIZE - padding +# y = img.height - WATERMARK_SIZE - padding + +# # Analyze background brightness where the watermark will be placed +# background_brightness = self.analyze_background_brightness(img, x, y, WATERMARK_SIZE) +# self.logger.info(f"Background brightness: {background_brightness}") + +# # Render SVG watermark to PNG bytes using cairosvg +# try: +# watermark_png_bytes = cairosvg.svg2png(bytestring=WATERMARK.encode('utf-8'), output_width=WATERMARK_SIZE, output_height=WATERMARK_SIZE) +# watermark = Image.open(BytesIO(watermark_png_bytes)).convert("RGBA") +# except Exception as e: +# raise Exception(f"Failed to render watermark SVG: {e}") + +# # Determine watermark color based on background brightness +# if background_brightness > 128: +# # Dark watermark for light backgrounds +# watermark_color = (0, 0, 0, int(255 * 0.65)) # Black with 65% opacity +# else: +# # Light watermark for dark backgrounds +# watermark_color = (255, 255, 255, int(255 * 0.65)) # White with 65% opacity + +# # Apply the watermark color by blending +# solid_color = Image.new("RGBA", watermark.size, watermark_color) +# watermark = Image.alpha_composite(watermark, solid_color) + +# # Overlay the watermark onto the original image +# img.paste(watermark, (x, y), watermark) + +# # Save the watermarked image to bytes +# buffer = BytesIO() +# img = img.convert("RGB") # Convert back to RGB for JPEG format +# img.save(buffer, format="JPEG") +# buffer.seek(0) +# jpeg_bytes = buffer.read() + +# # Modify the blob name to include '_watermarked' +# try: +# if "memes/" in base_blob_name: +# base_blob_name_right = base_blob_name.split("memes/", 1)[1] +# else: +# base_blob_name_right = base_blob_name +# base_blob_name_split = base_blob_name_right.rsplit(".", 1) +# base_blob_name_without_extension = base_blob_name_split[0] +# extension = base_blob_name_split[1] +# except Exception as e: +# raise Exception(f"Failed to process blob name: {e}") + +# watermarked_blob_name = f"{base_blob_name_without_extension}_watermarked.{extension}" + +# # Upload the watermarked image +# try: +# watermarked_blob_url = await self.save_image( +# watermarked_blob_name, +# "image/jpeg", +# jpeg_bytes +# ) +# return watermarked_blob_url +# except Exception as e: +# raise Exception(f"Failed to upload watermarked image: {e}") + +# def analyze_background_brightness( +# self, +# img: Image.Image, +# x: int, +# y: int, +# size: int +# ) -> int: +# """ +# Analyzes the brightness of a specific region in the image. + +# Args: +# img (Image.Image): The image to analyze. +# x (int): X-coordinate of the top-left corner of the region. +# y (int): Y-coordinate of the top-left corner of the region. +# size (int): Size of the square region to analyze. + +# Returns: +# int: Average brightness (0-255) of the region. 
+# """ +# # Crop the specified region +# sub_image = img.crop((x, y, x + size, y + size)).convert("RGB") + +# # Calculate average brightness using the luminance formula +# total_brightness = 0 +# pixel_count = 0 +# for pixel in sub_image.getdata(): +# r, g, b = pixel +# brightness = (r * 299 + g * 587 + b * 114) // 1000 +# total_brightness += brightness +# pixel_count += 1 + +# if pixel_count == 0: +# return 0 + +# average_brightness = total_brightness // pixel_count +# return average_brightness + diff --git a/models/checkpoints/put_checkpoints_here b/models/checkpoints/put_checkpoints_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/clip/put_clip_or_text_encoder_models_here b/models/clip/put_clip_or_text_encoder_models_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/clip_vision/put_clip_vision_models_here b/models/clip_vision/put_clip_vision_models_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/configs/anything_v3.yaml b/models/configs/anything_v3.yaml deleted file mode 100644 index 8bcfe584..00000000 --- a/models/configs/anything_v3.yaml +++ /dev/null @@ -1,73 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. 
] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - params: - layer: "hidden" - layer_idx: -2 diff --git a/models/configs/v1-inference.yaml b/models/configs/v1-inference.yaml deleted file mode 100644 index d4effe56..00000000 --- a/models/configs/v1-inference.yaml +++ /dev/null @@ -1,70 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. 
] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/models/configs/v1-inference_clip_skip_2.yaml b/models/configs/v1-inference_clip_skip_2.yaml deleted file mode 100644 index 8bcfe584..00000000 --- a/models/configs/v1-inference_clip_skip_2.yaml +++ /dev/null @@ -1,73 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. 
] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - params: - layer: "hidden" - layer_idx: -2 diff --git a/models/configs/v1-inference_clip_skip_2_fp16.yaml b/models/configs/v1-inference_clip_skip_2_fp16.yaml deleted file mode 100644 index 7eca31c7..00000000 --- a/models/configs/v1-inference_clip_skip_2_fp16.yaml +++ /dev/null @@ -1,74 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. 
] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_fp16: True - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - params: - layer: "hidden" - layer_idx: -2 diff --git a/models/configs/v1-inference_fp16.yaml b/models/configs/v1-inference_fp16.yaml deleted file mode 100644 index 147f42b1..00000000 --- a/models/configs/v1-inference_fp16.yaml +++ /dev/null @@ -1,71 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_fp16: True - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/models/configs/v1-inpainting-inference.yaml b/models/configs/v1-inpainting-inference.yaml deleted file mode 100644 index 45f3f82d..00000000 --- a/models/configs/v1-inpainting-inference.yaml +++ /dev/null @@ -1,71 +0,0 @@ -model: - base_learning_rate: 7.5e-05 - target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: hybrid # important - monitor: val/loss_simple_ema - scale_factor: 0.18215 - finetune_keys: null - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 2500 ] # NOTE for resuming. 
use 10000 if starting from scratch - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 9 # 4 data + 4 downscaled image + 1 mask - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - diff --git a/models/configs/v2-inference-v.yaml b/models/configs/v2-inference-v.yaml deleted file mode 100644 index 8ec8dfbf..00000000 --- a/models/configs/v2-inference-v.yaml +++ /dev/null @@ -1,68 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - parameterization: "v" - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False # we set this to false because this is an inference only config - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - use_fp16: True - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" diff --git a/models/configs/v2-inference-v_fp32.yaml b/models/configs/v2-inference-v_fp32.yaml deleted file mode 100644 index d5c9b9cb..00000000 --- a/models/configs/v2-inference-v_fp32.yaml +++ /dev/null @@ -1,68 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - parameterization: "v" - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False # we set this to false because this is an inference only config - - unet_config: - target: 
ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - use_fp16: False - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" diff --git a/models/configs/v2-inference.yaml b/models/configs/v2-inference.yaml deleted file mode 100644 index 152c4f3c..00000000 --- a/models/configs/v2-inference.yaml +++ /dev/null @@ -1,67 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False # we set this to false because this is an inference only config - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - use_fp16: True - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" diff --git a/models/configs/v2-inference_fp32.yaml b/models/configs/v2-inference_fp32.yaml deleted file mode 100644 index 0d03231f..00000000 --- a/models/configs/v2-inference_fp32.yaml +++ /dev/null @@ -1,67 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False # we set this to false because this is an inference only config - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - use_fp16: False - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - 
model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" diff --git a/models/configs/v2-inpainting-inference.yaml b/models/configs/v2-inpainting-inference.yaml deleted file mode 100644 index 32a9471d..00000000 --- a/models/configs/v2-inpainting-inference.yaml +++ /dev/null @@ -1,158 +0,0 @@ -model: - base_learning_rate: 5.0e-05 - target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: hybrid - scale_factor: 0.18215 - monitor: val/loss_simple_ema - finetune_keys: null - use_ema: False - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - image_size: 32 # unused - in_channels: 9 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [ ] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" - - -data: - target: ldm.data.laion.WebDataModuleFromConfig - params: - tar_base: null # for concat as in LAION-A - p_unsafe_threshold: 0.1 - filter_word_list: "data/filters.yaml" - max_pwatermark: 0.45 - batch_size: 8 - num_workers: 6 - multinode: True - min_size: 512 - train: - shards: - - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" - - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" - - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" - - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" - - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" - shuffle: 10000 - image_key: jpg - image_transforms: - - target: torchvision.transforms.Resize - params: - size: 512 - interpolation: 3 - - target: torchvision.transforms.RandomCrop - params: - size: 512 - postprocess: - target: ldm.data.laion.AddMask - params: - mode: "512train-large" - p_drop: 0.25 - # NOTE use enough shards to avoid empty validation 
loops in workers - validation: - shards: - - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " - shuffle: 0 - image_key: jpg - image_transforms: - - target: torchvision.transforms.Resize - params: - size: 512 - interpolation: 3 - - target: torchvision.transforms.CenterCrop - params: - size: 512 - postprocess: - target: ldm.data.laion.AddMask - params: - mode: "512train-large" - p_drop: 0.25 - -lightning: - find_unused_parameters: True - modelcheckpoint: - params: - every_n_train_steps: 5000 - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 10000 - - image_logger: - target: main.ImageLogger - params: - enable_autocast: False - disabled: False - batch_frequency: 1000 - max_images: 4 - increase_log_steps: False - log_first_step: False - log_images_kwargs: - use_ema_scope: False - inpaint: False - plot_progressive_rows: False - plot_diffusion_rows: False - N: 4 - unconditional_guidance_scale: 5.0 - unconditional_guidance_label: [""] - ddim_steps: 50 # todo check these out for depth2img, - ddim_eta: 0.0 # todo check these out for depth2img, - - trainer: - benchmark: True - val_check_interval: 5000000 - num_sanity_val_steps: 0 - accumulate_grad_batches: 1 diff --git a/models/controlnet/put_controlnets_and_t2i_here b/models/controlnet/put_controlnets_and_t2i_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/diffusers/put_diffusers_models_here b/models/diffusers/put_diffusers_models_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/diffusion_models/put_diffusion_model_files_here b/models/diffusion_models/put_diffusion_model_files_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/embeddings/put_embeddings_or_textual_inversion_concepts_here b/models/embeddings/put_embeddings_or_textual_inversion_concepts_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/gligen/put_gligen_models_here b/models/gligen/put_gligen_models_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/hypernetworks/put_hypernetworks_here b/models/hypernetworks/put_hypernetworks_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/loras/put_loras_here b/models/loras/put_loras_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/photomaker/put_photomaker_models_here b/models/photomaker/put_photomaker_models_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/style_models/put_t2i_style_model_here b/models/style_models/put_t2i_style_model_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/unet/put_unet_files_here b/models/unet/put_unet_files_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/upscale_models/put_esrgan_and_other_upscale_models_here b/models/upscale_models/put_esrgan_and_other_upscale_models_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/vae/put_vae_here b/models/vae/put_vae_here deleted file mode 100644 index e69de29b..00000000 diff --git a/models/vae_approx/put_taesd_encoder_pth_and_taesd_decoder_pth_here b/models/vae_approx/put_taesd_encoder_pth_and_taesd_decoder_pth_here deleted file mode 100644 index e69de29b..00000000 diff --git a/pysssss-workflows/training.json b/pysssss-workflows/training.json new file mode 100644 index 00000000..ba347d92 --- /dev/null +++ b/pysssss-workflows/training.json @@ -0,0 +1 @@ +{"last_node_id": 141, "last_link_id": 256, "nodes": [{"id": 82, "type": 
"SomethingToString", "pos": {"0": 3867.1982421875, "1": 867.84130859375}, "size": {"0": 315, "1": 82}, "flags": {"collapsed": true, "pinned": true}, "order": 53, "mode": 0, "inputs": [{"name": "input", "type": "*", "link": 234}], "outputs": [{"name": "STRING", "type": "STRING", "links": [121], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "SomethingToString"}, "widgets_values": ["steps ", ""]}, {"id": 83, "type": "AddLabel", "pos": {"0": 4088.1982421875, "1": 867.84130859375}, "size": {"0": 315, "1": 274}, "flags": {"collapsed": true, "pinned": true}, "order": 59, "mode": 0, "inputs": [{"name": "image", "type": "IMAGE", "link": 122}, {"name": "caption", "type": "STRING", "link": null, "widget": {"name": "caption"}}, {"name": "text", "type": "STRING", "link": 121, "widget": {"name": "text"}}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [204], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "AddLabel"}, "widgets_values": [10, 2, 48, 32, "white", "black", "FreeMono.ttf", "Text", "up", ""]}, {"id": 4, "type": "FluxTrainLoop", "pos": {"0": 1416.79443359375, "1": -104.92213439941406}, "size": {"0": 393, "1": 78}, "flags": {"pinned": true}, "order": 31, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 181}], "outputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "links": [162, 218], "slot_index": 0, "shape": 3}, {"name": "steps", "type": "INT", "links": [220], "slot_index": 1, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainLoop"}, "widgets_values": [250], "color": "#232", "bgcolor": "#353"}, {"id": 79, "type": "SomethingToString", "pos": {"0": 1777.79443359375, "1": 857.0768432617188}, "size": {"0": 315, "1": 82}, "flags": {"collapsed": true, "pinned": true}, "order": 36, "mode": 0, "inputs": [{"name": "input", "type": "*", "link": 220}], "outputs": [{"name": "STRING", "type": "STRING", "links": [111], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "SomethingToString"}, "widgets_values": ["steps ", ""]}, {"id": 44, "type": "FluxTrainLoop", "pos": {"0": 2567.5634765625, "1": -99.79415893554688}, "size": {"0": 361.21923828125, "1": 78}, "flags": {"pinned": true}, "order": 39, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 219}], "outputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "links": [164, 222], "slot_index": 0, "shape": 3}, {"name": "steps", "type": "INT", "links": [235], "slot_index": 1, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainLoop"}, "widgets_values": [250], "color": "#232", "bgcolor": "#353"}, {"id": 45, "type": "FluxTrainValidate", "pos": {"0": 3302.5634765625, "1": -104.79415893554688}, "size": {"0": 312.3999938964844, "1": 49.711063385009766}, "flags": {"pinned": true}, "order": 47, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 227}, {"name": "validation_settings", "type": "VALSETTINGS", "link": 244}], "outputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "links": [223], "slot_index": 0, "shape": 3}, {"name": "validation_images", "type": "IMAGE", "links": [70, 119], "slot_index": 1, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainValidate"}, "widgets_values": [], "color": "#232", "bgcolor": "#353"}, {"id": 62, "type": "FluxTrainSave", "pos": {"0": 4053.1982421875, "1": -111.1568374633789}, "size": {"0": 335.8072204589844, "1": 122}, "flags": {"pinned": true}, "order": 52, "mode": 0, "inputs": [{"name": "network_trainer", 
"type": "NETWORKTRAINER", "link": 224}], "outputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "links": [225], "slot_index": 0, "shape": 3}, {"name": "lora_path", "type": "STRING", "links": null, "shape": 3}, {"name": "steps", "type": "INT", "links": [], "slot_index": 2, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainSave"}, "widgets_values": [true, false], "color": "#232", "bgcolor": "#353"}, {"id": 60, "type": "FluxTrainValidate", "pos": {"0": 4404.1982421875, "1": -109.1568374633789}, "size": {"0": 312.3999938964844, "1": 55.36396026611328}, "flags": {"pinned": true}, "order": 56, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 225}, {"name": "validation_settings", "type": "VALSETTINGS", "link": 245}], "outputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "links": [226], "slot_index": 0, "shape": 3}, {"name": "validation_images", "type": "IMAGE", "links": [90, 122], "slot_index": 1, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainValidate"}, "widgets_values": [], "color": "#232", "bgcolor": "#353"}, {"id": 61, "type": "PreviewImage", "pos": {"0": 3652.1982421875, "1": 91.84286499023438}, "size": {"0": 1064.4329833984375, "1": 704.880615234375}, "flags": {"pinned": true}, "order": 58, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 90}], "outputs": [], "properties": {"Node name for S&R": "PreviewImage"}, "widgets_values": [], "color": "#2a363b", "bgcolor": "#3f5159"}, {"id": 70, "type": "VisualizeLoss", "pos": {"0": 4950.255859375, "1": 946.3529663085938}, "size": {"0": 254.40000915527344, "1": 198}, "flags": {"pinned": true}, "order": 64, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 217}], "outputs": [{"name": "plot", "type": "IMAGE", "links": [138], "slot_index": 0, "shape": 3}, {"name": "loss_list", "type": "FLOAT", "links": null, "shape": 3}], "properties": {"Node name for S&R": "VisualizeLoss"}, "widgets_values": ["ggplot", 1000, true, 768, 512, false]}, {"id": 48, "type": "GetNode", "pos": {"0": 3342.5634765625, "1": 3.2055797576904297}, "size": {"0": 210, "1": 58}, "flags": {"collapsed": true, "pinned": true}, "order": 0, "mode": 0, "inputs": [], "outputs": [{"name": "VALSETTINGS", "type": "VALSETTINGS", "links": [244], "slot_index": 0}], "title": "Get_validation_settings", "properties": {}, "widgets_values": ["validation_settings"], "color": "#232", "bgcolor": "#353"}, {"id": 63, "type": "GetNode", "pos": {"0": 4466.3583984375, "1": 0.10303689539432526}, "size": {"0": 210, "1": 58}, "flags": {"collapsed": true, "pinned": true}, "order": 1, "mode": 0, "inputs": [], "outputs": [{"name": "VALSETTINGS", "type": "VALSETTINGS", "links": [245], "slot_index": 0}], "title": "Get_validation_settings", "properties": {}, "widgets_values": ["validation_settings"], "color": "#232", "bgcolor": "#353"}, {"id": 59, "type": "FluxTrainLoop", "pos": {"0": 3648.1982421875, "1": -107.15684509277344}, "size": {"0": 381.93206787109375, "1": 78}, "flags": {"pinned": true}, "order": 48, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 223}], "outputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "links": [166, 224], "slot_index": 0, "shape": 3}, {"name": "steps", "type": "INT", "links": [234], "slot_index": 1, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainLoop"}, "widgets_values": [250], "color": "#232", "bgcolor": "#353"}, {"id": 38, "type": "SetNode", "pos": {"0": 1178, "1": 77}, "size": {"0": 210, "1": 
58}, "flags": {"collapsed": true, "pinned": true}, "order": 22, "mode": 0, "inputs": [{"name": "VALSETTINGS", "type": "VALSETTINGS", "link": 243}], "outputs": [{"name": "*", "type": "*", "links": null}], "title": "Set_validation_settings", "properties": {"previousName": "validation_settings"}, "widgets_values": ["validation_settings"], "color": "#232", "bgcolor": "#353"}, {"id": 133, "type": "FluxTrainEnd", "pos": {"0": 5903.7578125, "1": -89.38908386230469}, "size": {"0": 254.98214721679688, "1": 98}, "flags": {"pinned": true}, "order": 65, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 229}], "outputs": [{"name": "lora_name", "type": "STRING", "links": [231], "slot_index": 0, "shape": 3}, {"name": "metadata", "type": "STRING", "links": null, "shape": 3}, {"name": "lora_path", "type": "STRING", "links": [236], "slot_index": 2, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainEnd"}, "widgets_values": [false], "color": "#322", "bgcolor": "#533"}, {"id": 138, "type": "Fast Bypasser (rgthree)", "pos": {"0": -838, "1": 261}, "size": {"0": 304, "1": 78}, "flags": {"pinned": true}, "order": 25, "mode": 0, "inputs": [{"name": "Train 512x512 Dataset", "type": "*", "link": 252, "dir": 3}, {"name": "", "type": "*", "link": null, "dir": 3}], "outputs": [{"name": "OPT_CONNECTION", "type": "*", "links": [253], "slot_index": 0, "dir": 4}], "title": "512x512 Dataset switch (on/off)", "properties": {"toggleRestriction": "default"}, "color": "#232", "bgcolor": "#353"}, {"id": 137, "type": "Fast Bypasser (rgthree)", "pos": {"0": -499, "1": 262}, "size": {"0": 304, "1": 78}, "flags": {"pinned": true}, "order": 27, "mode": 0, "inputs": [{"name": "Train 768x768 Dataset", "type": "*", "link": 249, "dir": 3}, {"name": "", "type": "*", "link": null, "dir": 3}], "outputs": [{"name": "OPT_CONNECTION", "type": "*", "links": [250], "slot_index": 0, "dir": 4}], "title": "768x768 Dataset switch (on/off)", "properties": {"toggleRestriction": "default"}, "color": "#149454", "bgcolor": "#008040"}, {"id": 139, "type": "Fast Bypasser (rgthree)", "pos": {"0": -159, "1": 258}, "size": {"0": 320.79998779296875, "1": 78}, "flags": {"pinned": true}, "order": 29, "mode": 0, "inputs": [{"name": "Train 1024x1024 Dataset", "type": "*", "link": 256, "dir": 3, "label": ""}, {"name": "", "type": "*", "link": null, "dir": 3}], "outputs": [{"name": "OPT_CONNECTION", "type": "*", "links": [255], "slot_index": 0, "dir": 4}], "title": "1024x1024 Dataset switch (on/off)", "properties": {"toggleRestriction": "default"}, "color": "#149494", "bgcolor": "#008080"}, {"id": 123, "type": "GetNode", "pos": {"0": 5961.5439453125, "1": 67.45687866210938}, "size": {"0": 210, "1": 58}, "flags": {"collapsed": true, "pinned": true}, "order": 2, "mode": 0, "inputs": [], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [202, 209], "slot_index": 0}], "title": "Get_Sampler1", "properties": {}, "widgets_values": ["Sampler1"], "color": "#322", "bgcolor": "#533"}, {"id": 117, "type": "ImageConcatFromBatch", "pos": {"0": 6119.5439453125, "1": 283.45684814453125}, "size": {"0": 315, "1": 106}, "flags": {"pinned": true}, "order": 24, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 195}, {"name": "num_columns", "type": "INT", "link": 199, "widget": {"name": "num_columns"}}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [210], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "ImageConcatFromBatch"}, "widgets_values": [3, false, 4096], "color": "#322", "bgcolor": 
"#533"}, {"id": 124, "type": "GetNode", "pos": {"0": 5964.5439453125, "1": 118.45687866210938}, "size": {"0": 210, "1": 58}, "flags": {"collapsed": true, "pinned": true}, "order": 3, "mode": 0, "inputs": [], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [203], "slot_index": 0}], "title": "Get_Sampler2", "properties": {}, "widgets_values": ["Sampler2"], "color": "#322", "bgcolor": "#533"}, {"id": 126, "type": "GetNode", "pos": {"0": 5967.5439453125, "1": 169.45687866210938}, "size": {"0": 210, "1": 58}, "flags": {"collapsed": true, "pinned": true}, "order": 4, "mode": 0, "inputs": [], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [206], "slot_index": 0}], "title": "Get_Sampler3", "properties": {}, "widgets_values": ["Sampler3"], "color": "#322", "bgcolor": "#533"}, {"id": 128, "type": "GetNode", "pos": {"0": 5966.5439453125, "1": 215.45687866210938}, "size": {"0": 210, "1": 58}, "flags": {"collapsed": true, "pinned": true}, "order": 5, "mode": 0, "inputs": [], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [208], "slot_index": 0}], "title": "Get_Sampler4", "properties": {}, "widgets_values": ["Sampler4"], "color": "#322", "bgcolor": "#533"}, {"id": 119, "type": "ImageBatchMulti", "pos": {"0": 6213.5439453125, "1": 70.45687866210938}, "size": {"0": 210, "1": 142}, "flags": {"pinned": true}, "order": 20, "mode": 0, "inputs": [{"name": "image_1", "type": "IMAGE", "link": 202}, {"name": "image_2", "type": "IMAGE", "link": 203}, {"name": "image_3", "type": "IMAGE", "link": 206}, {"name": "image_4", "type": "IMAGE", "link": 208}], "outputs": [{"name": "images", "type": "IMAGE", "links": [195], "slot_index": 0, "shape": 3}], "properties": {}, "widgets_values": [4, null], "color": "#322", "bgcolor": "#533"}, {"id": 120, "type": "GetImageSizeAndCount", "pos": {"0": 6224.5439453125, "1": -30.543121337890625}, "size": {"0": 210, "1": 86}, "flags": {"collapsed": true, "pinned": true}, "order": 19, "mode": 0, "inputs": [{"name": "image", "type": "IMAGE", "link": 209}], "outputs": [{"name": "image", "type": "IMAGE", "links": [], "slot_index": 0, "shape": 3}, {"name": "1024 width", "type": "INT", "links": null, "shape": 3}, {"name": "1072 height", "type": "INT", "links": null, "shape": 3}, {"name": "1 count", "type": "INT", "links": [199], "slot_index": 3, "shape": 3}], "properties": {"Node name for S&R": "GetImageSizeAndCount"}, "widgets_values": [], "color": "#322", "bgcolor": "#533"}, {"id": 129, "type": "AddLabel", "pos": {"0": 6248.5439453125, "1": -95.54312133789062}, "size": {"0": 315, "1": 274}, "flags": {"collapsed": true, "pinned": true}, "order": 69, "mode": 0, "inputs": [{"name": "image", "type": "IMAGE", "link": 210}, {"name": "caption", "type": "STRING", "link": null, "widget": {"name": "caption"}}, {"name": "text", "type": "STRING", "link": 231, "widget": {"name": "text"}}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [214], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "AddLabel"}, "widgets_values": [10, 2, 48, 32, "white", "black", "FreeMono.ttf", "Text", "up", ""], "color": "#322", "bgcolor": "#533"}, {"id": 130, "type": "SaveImage", "pos": {"0": 6443.5439453125, "1": -99.54312133789062}, "size": {"0": 671.9758911132812, "1": 719.2059936523438}, "flags": {"pinned": true}, "order": 72, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 214}], "outputs": [], "properties": {"Node name for S&R": "SaveImage"}, "widgets_values": ["flux_lora_trainer_sheet"], "color": "#322", "bgcolor": "#533"}, {"id": 131, "type": 
"Note", "pos": {"0": -1180, "1": 449}, "size": {"0": 370.4340515136719, "1": 60.30167007446289}, "flags": {"pinned": true}, "order": 6, "mode": 0, "inputs": [], "outputs": [], "properties": {"text": ""}, "widgets_values": ["sanity check that all the args are chosen correctly"], "color": "#2a363b", "bgcolor": "#3f5159"}, {"id": 105, "type": "Display Any (rgthree)", "pos": {"0": -1182, "1": 569}, "size": {"0": 1346.8626708984375, "1": 348.1687927246094}, "flags": {"pinned": true}, "order": 33, "mode": 0, "inputs": [{"name": "source", "type": "*", "link": 183, "dir": 3}], "outputs": [], "properties": {"Node name for S&R": "Display Any (rgthree)"}, "widgets_values": ["Namespace(console_log_level=None, console_log_file=None, console_log_simple=False, v2=False, v_parameterization=False, pretrained_model_name_or_path='/home/holium/ComfyUI/models/unet/FLUX1/flux1Dev_v10.safetensors', tokenizer_cache_dir=None, train_data_dir=None, cache_info=False, shuffle_caption=False, caption_separator=',', caption_extension='.caption', caption_extention=None, keep_tokens=0, keep_tokens_separator='', secondary_separator=None, enable_wildcard=False, caption_prefix=None, caption_suffix=None, color_aug=False, flip_aug=False, face_crop_aug_range=None, random_crop=False, debug_dataset=False, resolution=None, cache_latents=True, vae_batch_size=1, cache_latents_to_disk=True, enable_bucket=False, min_bucket_reso=256, max_bucket_reso=1024, bucket_reso_steps=64, bucket_no_upscale=False, token_warmup_min=1, token_warmup_step=0.0, alpha_mask=False, dataset_class=None, caption_dropout_rate=0.0, caption_dropout_every_n_epochs=0, caption_tag_dropout_rate=0.0, reg_data_dir=None, in_json=None, dataset_repeats=1, output_dir='flux_loras', output_name='flux_bombo_rank16_bf16', huggingface_repo_id=None, huggingface_repo_type=None, huggingface_path_in_repo=None, huggingface_token=None, huggingface_repo_visibility=None, save_state_to_huggingface=False, resume_from_huggingface=False, async_upload=False, save_precision='bf16', save_every_n_epochs=None, save_every_n_steps=None, save_n_epoch_ratio=None, save_last_n_epochs=None, save_last_n_epochs_state=None, save_last_n_steps=None, save_last_n_steps_state=None, save_state=False, save_state_on_train_end=False, resume=None, train_batch_size=1, max_token_length=None, mem_eff_attn=True, torch_compile=False, dynamo_backend='inductor', xformers=True, sdpa=False, vae=None, max_train_steps=1000, max_train_epochs=None, max_data_loader_n_workers=0, persistent_data_loader_workers=False, seed=42, gradient_checkpointing=True, gradient_accumulation_steps=1, mixed_precision='bf16', full_fp16=False, full_bf16=True, fp8_base=False, ddp_timeout=None, ddp_gradient_as_bucket_view=False, ddp_static_graph=False, clip_skip=None, logging_dir=None, log_with=None, log_prefix=None, log_tracker_name=None, wandb_run_name=None, log_tracker_config=None, wandb_api_key=None, log_config=False, noise_offset=None, noise_offset_random_strength=False, multires_noise_iterations=None, ip_noise_gamma=None, ip_noise_gamma_random_strength=False, multires_noise_discount=0.3, adaptive_noise_scale=None, zero_terminal_snr=False, min_timestep=None, max_timestep=None, loss_type='l2', huber_schedule='snr', huber_c=0.1, lowram=False, highvram=False, sample_every_n_steps=None, sample_at_first=False, sample_every_n_epochs=None, sample_prompts=['h0d1 he is smoking in a nightclub with pistol pointed at the ceiling. 
he has a private booth with bottles of whiskey and vodka and several female blonde women cats'], sample_sampler='ddim', config_file=None, output_config=False, metadata_title=None, metadata_author=None, metadata_description=None, metadata_license=None, metadata_tags=None, prior_loss_weight=1.0, conditioning_data_dir=None, masked_loss=False, deepspeed=False, zero_stage=2, offload_optimizer_device=None, offload_optimizer_nvme_path=None, offload_param_device=None, offload_param_nvme_path=None, zero3_init_flag=False, zero3_save_16bit_model=False, fp16_master_weights_and_gradients=False, optimizer_type='adafactor', use_8bit_adam=False, use_lion_optimizer=False, learning_rate=0.00039999999999999996, max_grad_norm=1.0, optimizer_args=['relative_step=False', 'scale_parameter=False', 'warmup_init=False', 'clip_threshold=1.0'], lr_scheduler_type='', lr_scheduler_args=None, lr_scheduler='constant', lr_warmup_steps=0, lr_scheduler_num_cycles=3, lr_scheduler_power=1.0, fused_backward_pass=False, dataset_config='[[datasets]]\\nresolution = [ 512, 512,]\\nbatch_size = 4\\nenable_bucket = true\\nbucket_no_upscale = false\\nmin_bucket_reso = 256\\nmax_bucket_reso = 1024\\n[[datasets.subsets]]\\nimage_dir = \"input/hodi\"\\nclass_tokens = \"h0d1\"\\nnum_repeats = 1\\n\\n\\n[general]\\nshuffle_caption = false\\ncaption_extension = \".txt\"\\nkeep_tokens_separator = \"|||\"\\ncaption_dropout_rate = 0.0\\ncolor_aug = false\\nflip_aug = true\\n', min_snr_gamma=5.0, scale_v_pred_loss_like_noise_pred=False, v_pred_like_loss=None, debiased_estimation_loss=False, weighted_captions=False, no_metadata=False, save_model_as='safetensors', unet_lr=None, text_encoder_lr=0.0, network_weights=None, network_module='.networks.lora_flux', network_dim=16, network_alpha=80.0, network_dropout=None, network_args=['train_blocks=all'], network_train_unet_only=False, network_train_text_encoder_only=False, training_comment=None, dim_from_weights=False, scale_weight_norms=None, base_weights=None, base_weights_multiplier=None, no_half_vae=False, skip_until_initial_step=False, initial_epoch=None, initial_step=None, fp8_base_unet=False, cpu_offload_checkpointing=False, num_cpu_threads_per_process=1, clip_l='/home/holium/ComfyUI/models/clip/clip_l.safetensors', t5xxl='/home/holium/ComfyUI/models/clip/t5/google_t5-v1_1-xxl_encoderonly-fp16.safetensors', ae='/home/holium/ComfyUI/models/vae/FLUX1/ae.sft', t5xxl_max_token_length=512, spda=False, split_mode=False, apply_t5_attn_mask=True, cache_text_encoder_outputs=True, weighting_scheme='logit_normal', logit_mean=0.0, logit_std=1.0, mode_scale=1.29, timestep_sampling='shift', sigmoid_scale=1.0, model_prediction_type='raw', guidance_scale=1.0, discrete_flow_shift=3.1582000000000003, cache_text_encoder_outputs_to_disk=True)"], "color": "#2a363b", "bgcolor": "#3f5159"}, {"id": 2, "type": "FluxTrainModelSelect", "pos": {"0": 252, "1": 45}, "size": {"0": 430, "1": 130}, "flags": {"pinned": true}, "order": 7, "mode": 0, "inputs": [], "outputs": [{"name": "flux_models", "type": "TRAIN_FLUX_MODELS", "links": [179], "shape": 3}], "properties": {"Node name for S&R": "FluxTrainModelSelect"}, "widgets_values": ["FLUX1/flux1Dev_v10.safetensors", "FLUX1/ae.sft", "clip_l.safetensors", "t5/google_t5-v1_1-xxl_encoderonly-fp16.safetensors"], "color": "#1414ff", "bgcolor": "#0000ff"}, {"id": 114, "type": "OptimizerConfigAdafactor", "pos": {"0": 274, "1": 416}, "size": {"0": 315, "1": 316}, "flags": {"pinned": true}, "order": 8, "mode": 0, "inputs": [], "outputs": [{"name": "optimizer_settings", "type": "ARGS", 
"links": [247], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "OptimizerConfigAdafactor"}, "widgets_values": [1, "constant", 0, 3, 1, false, false, false, 1, 5, ""], "color": "#941414", "bgcolor": "#800000"}, {"id": 136, "type": "Switch any [Crystools]", "pos": {"0": 273, "1": 795}, "size": {"0": 315, "1": 78}, "flags": {"pinned": true}, "order": 21, "mode": 0, "inputs": [{"name": "on_true", "type": "*", "link": 247}, {"name": "on_false", "type": "*", "link": 246}], "outputs": [{"name": "*", "type": "*", "links": [248], "slot_index": 0, "shape": 3}], "title": "Optimizer switch (true = Adafactor)", "properties": {"Node name for S&R": "Switch any [Crystools]"}, "widgets_values": [true], "color": "#941414", "bgcolor": "#800000"}, {"id": 95, "type": "OptimizerConfig", "pos": {"0": 278, "1": 933}, "size": {"0": 315, "1": 244}, "flags": {"pinned": true}, "order": 9, "mode": 0, "inputs": [], "outputs": [{"name": "optimizer_settings", "type": "ARGS", "links": [246], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "OptimizerConfig"}, "widgets_values": ["adamw8bit", 1, "cosine_with_restarts", 4, 3, 1, 5, ""], "color": "#941414", "bgcolor": "#800000"}, {"id": 37, "type": "FluxTrainValidationSettings", "pos": {"0": 831, "1": 44}, "size": {"0": 315, "1": 250}, "flags": {"pinned": true}, "order": 10, "mode": 0, "inputs": [], "outputs": [{"name": "validation_settings", "type": "VALSETTINGS", "links": [243], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainValidationSettings"}, "widgets_values": [20, 1024, 1024, 3, 1055665447947593, "randomize", false, 0.5, 1.15], "color": "#232", "bgcolor": "#353"}, {"id": 9, "type": "PreviewImage", "pos": {"0": 1425, "1": 80}, "size": {"0": 1099.023193359375, "1": 710.3143920898438}, "flags": {"pinned": true}, "order": 40, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 8}], "outputs": [], "properties": {"Node name for S&R": "PreviewImage"}, "widgets_values": [], "color": "#2a363b", "bgcolor": "#3f5159"}, {"id": 14, "type": "FluxTrainSave", "pos": {"0": 1828, "1": -112}, "size": {"0": 341.3186340332031, "1": 122}, "flags": {"pinned": true}, "order": 35, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 218}], "outputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "links": [221], "slot_index": 0, "shape": 3}, {"name": "lora_path", "type": "STRING", "links": [], "slot_index": 1, "shape": 3}, {"name": "steps", "type": "INT", "links": [], "slot_index": 2, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainSave"}, "widgets_values": [true, false], "color": "#232", "bgcolor": "#353"}, {"id": 8, "type": "FluxTrainValidate", "pos": {"0": 2202, "1": -116}, "size": {"0": 321.6932373046875, "1": 51.817989349365234}, "flags": {"collapsed": false, "pinned": true}, "order": 38, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 221}, {"name": "validation_settings", "type": "VALSETTINGS", "link": 242}], "outputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "links": [219], "slot_index": 0, "shape": 3}, {"name": "validation_images", "type": "IMAGE", "links": [8, 112], "slot_index": 1, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainValidate"}, "widgets_values": [], "color": "#232", "bgcolor": "#353"}, {"id": 40, "type": "GetNode", "pos": {"0": 2257, "1": -2}, "size": {"0": 277.0899353027344, "1": 58}, "flags": {"collapsed": true, "pinned": true}, "order": 11, "mode": 0, "inputs": [], "outputs": 
[{"name": "VALSETTINGS", "type": "VALSETTINGS", "links": [242], "slot_index": 0}], "title": "Get_validation_settings", "properties": {}, "widgets_values": ["validation_settings"], "color": "#232", "bgcolor": "#353"}, {"id": 98, "type": "SaveImage", "pos": {"0": 1888, "1": 915}, "size": {"0": 645.9608764648438, "1": 439.37261962890625}, "flags": {"pinned": true}, "order": 37, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 161}], "outputs": [], "properties": {"Node name for S&R": "SaveImage"}, "widgets_values": ["flux_lora_loss_plot"]}, {"id": 97, "type": "VisualizeLoss", "pos": {"0": 1495, "1": 1000}, "size": {"0": 303.6300048828125, "1": 198}, "flags": {"pinned": true}, "order": 34, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 162}], "outputs": [{"name": "plot", "type": "IMAGE", "links": [161], "slot_index": 0, "shape": 3}, {"name": "loss_list", "type": "FLOAT", "links": null, "shape": 3}], "properties": {"Node name for S&R": "VisualizeLoss"}, "widgets_values": ["seaborn-v0_8-dark-palette", 100, true, 768, 512, false]}, {"id": 78, "type": "AddLabel", "pos": {"0": 1993, "1": 857}, "size": {"0": 315, "1": 274}, "flags": {"collapsed": true, "pinned": true}, "order": 41, "mode": 0, "inputs": [{"name": "image", "type": "IMAGE", "link": 112}, {"name": "caption", "type": "STRING", "link": null, "widget": {"name": "caption"}}, {"name": "text", "type": "STRING", "link": 111, "widget": {"name": "text"}}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [200], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "AddLabel"}, "widgets_values": [10, 2, 48, 32, "white", "black", "FreeMono.ttf", "Text", "up", ""]}, {"id": 121, "type": "SetNode", "pos": {"0": 2150, "1": 857}, "size": {"0": 210, "1": 58}, "flags": {"collapsed": true, "pinned": true}, "order": 45, "mode": 0, "inputs": [{"name": "IMAGE", "type": "IMAGE", "link": 200}], "outputs": [{"name": "*", "type": "*", "links": null}], "title": "Set_Sampler1", "properties": {"previousName": "Sampler1"}, "widgets_values": ["Sampler1"], "color": "#2a363b", "bgcolor": "#3f5159"}, {"id": 99, "type": "VisualizeLoss", "pos": {"0": 2697, "1": 969}, "size": {"0": 254.40000915527344, "1": 198}, "flags": {"pinned": true}, "order": 42, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 164}], "outputs": [{"name": "plot", "type": "IMAGE", "links": [163], "slot_index": 0, "shape": 3}, {"name": "loss_list", "type": "FLOAT", "links": null, "shape": 3}], "properties": {"Node name for S&R": "VisualizeLoss"}, "widgets_values": ["seaborn-v0_8-dark-palette", 100, true, 768, 512, false]}, {"id": 100, "type": "SaveImage", "pos": {"0": 3035, "1": 929}, "size": {"0": 574.23046875, "1": 414.46881103515625}, "flags": {"pinned": true}, "order": 46, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 163}], "outputs": [], "properties": {"Node name for S&R": "SaveImage"}, "widgets_values": ["flux_lora_loss_plot"]}, {"id": 81, "type": "SomethingToString", "pos": {"0": 2782, "1": 851}, "size": {"0": 315, "1": 82}, "flags": {"collapsed": true, "pinned": true}, "order": 44, "mode": 0, "inputs": [{"name": "input", "type": "*", "link": 235}], "outputs": [{"name": "STRING", "type": "STRING", "links": [117], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "SomethingToString"}, "widgets_values": ["steps ", ""]}, {"id": 80, "type": "AddLabel", "pos": {"0": 3011, "1": 854}, "size": {"0": 315, "1": 274}, "flags": {"collapsed": true, "pinned": true}, 
"order": 50, "mode": 0, "inputs": [{"name": "image", "type": "IMAGE", "link": 119}, {"name": "caption", "type": "STRING", "link": null, "widget": {"name": "caption"}}, {"name": "text", "type": "STRING", "link": 117, "widget": {"name": "text"}}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [201], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "AddLabel"}, "widgets_values": [10, 2, 48, 32, "white", "black", "FreeMono.ttf", "Text", "up", ""]}, {"id": 122, "type": "SetNode", "pos": {"0": 3209, "1": 856}, "size": {"0": 210, "1": 58}, "flags": {"collapsed": true, "pinned": true}, "order": 54, "mode": 0, "inputs": [{"name": "IMAGE", "type": "IMAGE", "link": 201}], "outputs": [{"name": "*", "type": "*", "links": null}], "title": "Set_Sampler2", "properties": {"previousName": "Sampler2"}, "widgets_values": ["Sampler2"], "color": "#2a363b", "bgcolor": "#3f5159"}, {"id": 46, "type": "PreviewImage", "pos": {"0": 2563, "1": 75}, "size": {"0": 1041.0751953125, "1": 707.7618408203125}, "flags": {"pinned": true}, "order": 49, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 70}], "outputs": [], "properties": {"Node name for S&R": "PreviewImage"}, "widgets_values": [], "color": "#2a363b", "bgcolor": "#3f5159"}, {"id": 47, "type": "FluxTrainSave", "pos": {"0": 2950, "1": -112}, "size": {"0": 337.2362976074219, "1": 122}, "flags": {"pinned": true}, "order": 43, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 222}], "outputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "links": [227], "slot_index": 0, "shape": 3}, {"name": "lora_path", "type": "STRING", "links": null, "shape": 3}, {"name": "steps", "type": "INT", "links": [], "slot_index": 2, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainSave"}, "widgets_values": [true, false], "color": "#232", "bgcolor": "#353"}, {"id": 125, "type": "SetNode", "pos": {"0": 4236, "1": 868}, "size": {"0": 210, "1": 58}, "flags": {"collapsed": true, "pinned": true}, "order": 62, "mode": 0, "inputs": [{"name": "IMAGE", "type": "IMAGE", "link": 204}], "outputs": [{"name": "*", "type": "*", "links": [], "slot_index": 0}], "title": "Set_Sampler3", "properties": {"previousName": "Sampler3"}, "widgets_values": ["Sampler3"], "color": "#2a363b", "bgcolor": "#3f5159"}, {"id": 101, "type": "VisualizeLoss", "pos": {"0": 3803, "1": 955}, "size": {"0": 254.40000915527344, "1": 198}, "flags": {"pinned": true}, "order": 51, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 166}], "outputs": [{"name": "plot", "type": "IMAGE", "links": [165], "slot_index": 0, "shape": 3}, {"name": "loss_list", "type": "FLOAT", "links": null, "shape": 3}], "properties": {"Node name for S&R": "VisualizeLoss"}, "widgets_values": ["seaborn-v0_8-dark-palette", 100, true, 768, 512, false]}, {"id": 102, "type": "SaveImage", "pos": {"0": 4138, "1": 926}, "size": {"0": 574.23046875, "1": 414.46881103515625}, "flags": {"pinned": true}, "order": 55, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 165}], "outputs": [], "properties": {"Node name for S&R": "SaveImage"}, "widgets_values": ["flux_lora_loss_plot"]}, {"id": 90, "type": "SaveImage", "pos": {"0": 5272, "1": 933.4569091796875}, "size": {"0": 574.23046875, "1": 414.46881103515625}, "flags": {"pinned": true}, "order": 68, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 138}], "outputs": [], "properties": {"Node name for S&R": "SaveImage"}, "widgets_values": ["flux_lora_loss_plot"]}, {"id": 
84, "type": "SomethingToString", "pos": {"0": 5000, "1": 860.4569091796875}, "size": {"0": 315, "1": 82}, "flags": {"collapsed": true, "pinned": true}, "order": 61, "mode": 0, "inputs": [{"name": "input", "type": "*", "link": 215}], "outputs": [{"name": "STRING", "type": "STRING", "links": [124], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "SomethingToString"}, "widgets_values": ["steps ", ""]}, {"id": 85, "type": "AddLabel", "pos": {"0": 5210, "1": 857.4569091796875}, "size": {"0": 315, "1": 274}, "flags": {"collapsed": true, "pinned": true}, "order": 67, "mode": 0, "inputs": [{"name": "image", "type": "IMAGE", "link": 126}, {"name": "caption", "type": "STRING", "link": null, "widget": {"name": "caption"}}, {"name": "text", "type": "STRING", "link": 124, "widget": {"name": "text"}}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [207], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "AddLabel"}, "widgets_values": [10, 2, 48, 32, "white", "black", "FreeMono.ttf", "Text", "up", ""]}, {"id": 127, "type": "SetNode", "pos": {"0": 5380, "1": 858.4569091796875}, "size": {"0": 210, "1": 58}, "flags": {"collapsed": true, "pinned": true}, "order": 71, "mode": 0, "inputs": [{"name": "IMAGE", "type": "IMAGE", "link": 207}], "outputs": [{"name": "*", "type": "*", "links": null}], "title": "Set_Sampler4", "properties": {"previousName": "Sampler4"}, "widgets_values": ["Sampler4"], "color": "#2a363b", "bgcolor": "#3f5159"}, {"id": 66, "type": "PreviewImage", "pos": {"0": 4759, "1": 95.45687866210938}, "size": {"0": 1080.4327392578125, "1": 711.6444702148438}, "flags": {"pinned": true}, "order": 66, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 95}], "outputs": [], "properties": {"Node name for S&R": "PreviewImage"}, "widgets_values": [], "color": "#2a363b", "bgcolor": "#3f5159"}, {"id": 64, "type": "FluxTrainLoop", "pos": {"0": 4756, "1": -106.54312133789062}, "size": {"0": 348.72076416015625, "1": 78}, "flags": {"pinned": true}, "order": 57, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 226}], "outputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "links": [232], "slot_index": 0, "shape": 3}, {"name": "steps", "type": "INT", "links": [215], "slot_index": 1, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainLoop"}, "widgets_values": [250], "color": "#232", "bgcolor": "#353"}, {"id": 134, "type": "FluxTrainSave", "pos": {"0": 5139, "1": -105.54312133789062}, "size": {"0": 305.3969421386719, "1": 122}, "flags": {"pinned": true}, "order": 60, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 232}], "outputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "links": [233], "slot_index": 0, "shape": 3}, {"name": "lora_path", "type": "STRING", "links": null, "shape": 3}, {"name": "steps", "type": "INT", "links": [], "slot_index": 2, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainSave"}, "widgets_values": [true, false], "color": "#232", "bgcolor": "#353"}, {"id": 65, "type": "FluxTrainValidate", "pos": {"0": 5488, "1": -99.54312133789062}, "size": {"0": 312.3999938964844, "1": 46}, "flags": {"pinned": true}, "order": 63, "mode": 0, "inputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "link": 233}, {"name": "validation_settings", "type": "VALSETTINGS", "link": 94}], "outputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "links": [217, 229], "slot_index": 0, "shape": 3}, {"name": "validation_images", "type": 
"IMAGE", "links": [95, 126], "slot_index": 1, "shape": 3}], "properties": {"Node name for S&R": "FluxTrainValidate"}, "widgets_values": [], "color": "#232", "bgcolor": "#353"}, {"id": 68, "type": "GetNode", "pos": {"0": 5561, "1": 9.456878662109375}, "size": {"0": 210, "1": 58}, "flags": {"collapsed": true, "pinned": true}, "order": 12, "mode": 0, "inputs": [], "outputs": [{"name": "VALSETTINGS", "type": "VALSETTINGS", "links": [94], "slot_index": 0}], "title": "Get_validation_settings", "properties": {}, "widgets_values": ["validation_settings"], "color": "#232", "bgcolor": "#353"}, {"id": 111, "type": "TrainDatasetAdd", "pos": {"0": -481, "1": -111}, "size": {"0": 267.5897521972656, "1": 318}, "flags": {"pinned": true}, "order": 26, "mode": 4, "inputs": [{"name": "dataset_config", "type": "JSON", "link": 253}], "outputs": [{"name": "dataset", "type": "JSON", "links": [249], "slot_index": 0, "shape": 3}], "title": "Train 768x768 Dataset", "properties": {"Node name for S&R": "TrainDatasetAdd"}, "widgets_values": [768, 768, 1, "../training/input/", "LoraTrigger", true, false, 1, 256, 1024], "color": "#149454", "bgcolor": "#008040"}, {"id": 108, "type": "TrainDatasetGeneralConfig", "pos": {"0": -1178, "1": 177}, "size": {"0": 315, "1": 154}, "flags": {"pinned": true}, "order": 13, "mode": 0, "inputs": [], "outputs": [{"name": "dataset_general", "type": "JSON", "links": [185], "slot_index": 0, "shape": 3}], "properties": {"Node name for S&R": "TrainDatasetGeneralConfig"}, "widgets_values": [false, true, false, 0, false], "color": "#232", "bgcolor": "#353"}, {"id": 113, "type": "Note", "pos": {"0": -1189.3692626953125, "1": -106.29979705810547}, "size": {"0": 327.63427734375, "1": 168.70933532714844}, "flags": {"pinned": true}, "order": 14, "mode": 0, "inputs": [], "outputs": [], "title": "Datasets Note", "properties": {"text": ""}, "widgets_values": ["For multiresolution training, input same source directory with different dataset resolution. From what I hear, Flux likes multiple resolutions.\n\nFor single resolution training, just add single dataset.\n\nVery important: remember to set the directory where input images is located (../training/input/ by default) and the LoraTrigger word if you want one."], "color": "#ff9414", "bgcolor": "#ff8000"}, {"id": 115, "type": "Note", "pos": {"0": 228, "1": -114}, "size": {"0": 464.1640930175781, "1": 101.32028198242188}, "flags": {"pinned": true}, "order": 15, "mode": 0, "inputs": [], "outputs": [], "title": "Note on FLUX model", "properties": {"text": ""}, "widgets_values": ["You can use same models as you use for inference in Comfy. 
When fp8_base is enabled, the model is downcasted to torch.float_e4m3fn on initialize, meaning if you load fp8 model here it should also be in same format.\n\nDownload the flux1-dev-fp8.safetensors transformer from this link:\nhttps://huggingface.co/Kijai/flux-fp8/tree/main "], "color": "#ff9414", "bgcolor": "#ff8000"}, {"id": 135, "type": "Note", "pos": {"0": 226, "1": 300}, "size": {"0": 401.9402160644531, "1": 63.765438079833984}, "flags": {"pinned": true}, "order": 16, "mode": 0, "inputs": [], "outputs": [], "title": "Note on Optimizers", "properties": {}, "widgets_values": ["You can use Adafactor Optimizer node (suggested) or use the other \"Optimizer Config\" node that allows you to choose the following optimizers: Adamw8bit, Adamw, Prodigy and Came Optimizers."], "color": "#ff9414", "bgcolor": "#ff8000"}, {"id": 116, "type": "Note", "pos": {"0": 802, "1": -113}, "size": {"0": 572.6136474609375, "1": 105.09221649169922}, "flags": {"pinned": true}, "order": 17, "mode": 0, "inputs": [], "outputs": [], "title": "Note on Training and Validation", "properties": {"text": "\n"}, "widgets_values": ["Validation sampling settings are set here for all the 4 sampler nodes.\nRemeber to write a prompt in the \"Init Flux LoRA Training\" node (at the bottom). You can generate more than one image just separating each image's prompt with \"|\".\nIn the 4 Train-groups, the Steps in each Train Loop must be 1/4 of what you set in \"max_train_steps\".\n\nFor training settings in the \"Init Flux LoRA Training\" node visit: https://github.com/kohya-ss/sd-scripts"], "color": "#ff9414", "bgcolor": "#ff8000"}, {"id": 140, "type": "Note Plus (mtb)", "pos": {"0": -1780.3626708984375, "1": -174.3422088623047}, "size": {"0": 564.8421020507812, "1": 583.4563598632812}, "flags": {"pinned": true}, "order": 18, "mode": 0, "inputs": [], "outputs": [], "title": "Unnamed", "properties": {}, "widgets_values": ["

FLUX LoRA Trainer on ComfyUI

\n
\n

This workflow is based on the incredible work by Kijai (https://github.com/kijai/ComfyUI-FluxTrainer), who created the training nodes for ComfyUI on top of the kohya-ss scripts (https://github.com/kohya-ss/sd-scripts). All credit goes to them. Thanks also to u/tom83_be on Reddit, who posted installation and basic settings tips.

\n
----------------
\n

To train a LoRA (Low-Rank Adaptation) for FLUX, these are the steps you should follow before clicking Queue:

\n1) Prepare the training data - that is, a set of images (a minimum of 10; 20-30 is fine, but for some specific LoRAs more is better).
\n2) You don't need to create caption .txt files; FLUX LoRAs can be trained on images only.
\n3) Check that you have set the input (training images) and output (saved LoRAs) folders correctly.
\n4) Set your LoraTrigger word (optional).
\n5) Add a prompt (or multiple prompts) for Training Validation in the \"Init Flux LoRA Training\" node, at the bottom.
\n6) Adjust the training settings (or leave the defaults).

\nNow click \"Queue\" and wait a few hours...
\nAt the end of the training you will have a few different LoRAs; choose the best one (usually the second or the third, in my experience) and enjoy it in your next workflow!
\n

", "markdown", "", "one_dark"], "color": "#f38414", "bgcolor": "#df7000", "shape": 1}, {"id": 88, "type": "Display Any (rgthree)", "pos": {"0": 1172, "1": 181}, "size": {"0": 210, "1": 76}, "flags": {"pinned": true}, "order": 32, "mode": 0, "inputs": [{"name": "source", "type": "*", "link": 182, "dir": 3}], "outputs": [], "title": "Number of epochs", "properties": {"Node name for S&R": "Display Any (rgthree)"}, "widgets_values": ["91"], "color": "#232", "bgcolor": "#353"}, {"id": 112, "type": "TrainDatasetAdd", "pos": {"0": -130, "1": -113}, "size": {"0": 259.5897521972656, "1": 318}, "flags": {"pinned": true}, "order": 28, "mode": 4, "inputs": [{"name": "dataset_config", "type": "JSON", "link": 250}], "outputs": [{"name": "dataset", "type": "JSON", "links": [256], "slot_index": 0, "shape": 3}], "title": "Train 1024x1024 Dataset", "properties": {"Node name for S&R": "TrainDatasetAdd"}, "widgets_values": [1024, 1024, 1, "input/shih_training", "shihtoken", true, false, 1, 256, 1024], "color": "#149494", "bgcolor": "#008080"}, {"id": 74, "type": "Display Any (rgthree)", "pos": {"0": 5990.5439453125, "1": 519.4569091796875}, "size": {"0": 358.62896728515625, "1": 76}, "flags": {"pinned": true}, "order": 70, "mode": 0, "inputs": [{"name": "source", "type": "*", "link": 236, "dir": 3}], "outputs": [], "properties": {"Node name for S&R": "Display Any (rgthree)"}, "widgets_values": ["\"flux_loras/flux_bombo_rank16_bf16.safetensors\""]}, {"id": 109, "type": "TrainDatasetAdd", "pos": {"0": -829.0809326171875, "1": -109.45999908447266}, "size": {"0": 281.5897521972656, "1": 318}, "flags": {"pinned": true}, "order": 23, "mode": 0, "inputs": [{"name": "dataset_config", "type": "JSON", "link": 185}], "outputs": [{"name": "dataset", "type": "JSON", "links": [252], "slot_index": 0, "shape": 3}], "title": "Train 512x512 Dataset", "properties": {"Node name for S&R": "TrainDatasetAdd"}, "widgets_values": [512, 512, 4, "input/hodi", "h0d1", true, false, 1, 256, 1024], "color": "#232", "bgcolor": "#353"}, {"id": 107, "type": "InitFluxLoRATraining", "pos": {"0": 891, "1": 343}, "size": {"0": 449.94586181640625, "1": 853.0341186523438}, "flags": {"pinned": true}, "order": 30, "mode": 0, "inputs": [{"name": "flux_models", "type": "TRAIN_FLUX_MODELS", "link": 179}, {"name": "dataset", "type": "JSON", "link": 255}, {"name": "optimizer_settings", "type": "ARGS", "link": 248}, {"name": "resume_args", "type": "ARGS", "link": null}], "outputs": [{"name": "network_trainer", "type": "NETWORKTRAINER", "links": [181], "shape": 3}, {"name": "epochs_count", "type": "INT", "links": [182], "shape": 3}, {"name": "args", "type": "KOHYA_ARGS", "links": [183], "shape": 3}], "properties": {"Node name for S&R": "InitFluxLoRATraining"}, "widgets_values": ["flux_bombo", "flux_loras", 16, 80, 0.00039999999999999996, 1000, true, "disk", "disk", false, "logit_normal", 0, 1, 1.29, "shift", 1, "raw", 1, 3.1582000000000003, false, false, "bf16", "bf16", "xformers", "h0d1 he is smoking in a nightclub with pistol pointed at the ceiling. 
he has a private booth with bottles of whiskey and vodka and several female blonde women cats", "", "use_fp8", 0], "color": "#232", "bgcolor": "#353"}], "links": [[8, 8, 1, 9, 0, "IMAGE"], [70, 45, 1, 46, 0, "IMAGE"], [90, 60, 1, 61, 0, "IMAGE"], [94, 68, 0, 65, 1, "VALSETTINGS"], [95, 65, 1, 66, 0, "IMAGE"], [111, 79, 0, 78, 2, "STRING"], [112, 8, 1, 78, 0, "IMAGE"], [117, 81, 0, 80, 2, "STRING"], [119, 45, 1, 80, 0, "IMAGE"], [121, 82, 0, 83, 2, "STRING"], [122, 60, 1, 83, 0, "IMAGE"], [124, 84, 0, 85, 2, "STRING"], [126, 65, 1, 85, 0, "IMAGE"], [138, 70, 0, 90, 0, "IMAGE"], [161, 97, 0, 98, 0, "IMAGE"], [162, 4, 0, 97, 0, "NETWORKTRAINER"], [163, 99, 0, 100, 0, "IMAGE"], [164, 44, 0, 99, 0, "NETWORKTRAINER"], [165, 101, 0, 102, 0, "IMAGE"], [166, 59, 0, 101, 0, "NETWORKTRAINER"], [179, 2, 0, 107, 0, "TRAIN_FLUX_MODELS"], [181, 107, 0, 4, 0, "NETWORKTRAINER"], [182, 107, 1, 88, 0, "*"], [183, 107, 2, 105, 0, "*"], [185, 108, 0, 109, 0, "JSON"], [195, 119, 0, 117, 0, "IMAGE"], [199, 120, 3, 117, 1, "INT"], [200, 78, 0, 121, 0, "*"], [201, 80, 0, 122, 0, "*"], [202, 123, 0, 119, 0, "IMAGE"], [203, 124, 0, 119, 1, "IMAGE"], [204, 83, 0, 125, 0, "*"], [206, 126, 0, 119, 2, "IMAGE"], [207, 85, 0, 127, 0, "*"], [208, 128, 0, 119, 3, "IMAGE"], [209, 123, 0, 120, 0, "IMAGE"], [210, 117, 0, 129, 0, "IMAGE"], [214, 129, 0, 130, 0, "IMAGE"], [215, 64, 1, 84, 0, "*"], [217, 65, 0, 70, 0, "NETWORKTRAINER"], [218, 4, 0, 14, 0, "NETWORKTRAINER"], [219, 8, 0, 44, 0, "NETWORKTRAINER"], [220, 4, 1, 79, 0, "*"], [221, 14, 0, 8, 0, "NETWORKTRAINER"], [222, 44, 0, 47, 0, "NETWORKTRAINER"], [223, 45, 0, 59, 0, "NETWORKTRAINER"], [224, 59, 0, 62, 0, "NETWORKTRAINER"], [225, 62, 0, 60, 0, "NETWORKTRAINER"], [226, 60, 0, 64, 0, "NETWORKTRAINER"], [227, 47, 0, 45, 0, "NETWORKTRAINER"], [229, 65, 0, 133, 0, "NETWORKTRAINER"], [231, 133, 0, 129, 2, "STRING"], [232, 64, 0, 134, 0, "NETWORKTRAINER"], [233, 134, 0, 65, 0, "NETWORKTRAINER"], [234, 59, 1, 82, 0, "*"], [235, 44, 1, 81, 0, "*"], [236, 133, 2, 74, 0, "*"], [242, 40, 0, 8, 1, "VALSETTINGS"], [243, 37, 0, 38, 0, "VALSETTINGS"], [244, 48, 0, 45, 1, "VALSETTINGS"], [245, 63, 0, 60, 1, "VALSETTINGS"], [246, 95, 0, 136, 1, "*"], [247, 114, 0, 136, 0, "*"], [248, 136, 0, 107, 2, "ARGS"], [249, 111, 0, 137, 0, "*"], [250, 137, 0, 112, 0, "JSON"], [252, 109, 0, 138, 0, "*"], [253, 138, 0, 111, 0, "JSON"], [255, 139, 0, 107, 1, "JSON"], [256, 112, 0, 139, 0, "*"]], "groups": [{"title": "End of LoRA Training", "bounding": [5882, -179, 1248, 813], "color": "#ff8080", "font_size": 24, "flags": {}}, {"title": "Sanity check", "bounding": [-1198, 376, 1379, 560], "color": "#3f789e", "font_size": 24, "flags": {}}, {"title": "Dataset", "bounding": [-1201, -187, 1389, 545], "color": "#00ff00", "font_size": 24, "flags": {}}, {"title": "LoRA Training 04", "bounding": [4743, -181, 1118, 1541], "color": "#3f789e", "font_size": 24, "flags": {}}, {"title": "LoRA Training 03", "bounding": [3639, -181, 1094, 1538], "color": "#3f789e", "font_size": 24, "flags": {}}, {"title": "LoRA Training 02", "bounding": [2554, -184, 1074, 1543], "color": "#3f789e", "font_size": 24, "flags": {}}, {"title": "FLUX model - Optimizer select - Training and Validate settings", "bounding": [195, -187, 1199, 1405], "color": "#b06634", "font_size": 24, "flags": {}}, {"title": "LoRA Training 01", "bounding": [1403, -188, 1140, 1549], "color": "#3f789e", "font_size": 24, "flags": {}}], "config": {}, "extra": {"ds": {"scale": 0.5054470284993074, "offset": [711.8774949129566, 346.6150349665473]}}, "version": 
0.4}
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 4ad5f3b8..23ac0295 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,3 +22,10 @@ kornia>=0.7.1
 spandrel
 soundfile
 av
+
+# memedeck dependencies
+pika
+python-dotenv
+pillow
+azure-storage-blob
+cairosvg
diff --git a/styles/default.csv b/styles/default.csv
new file mode 100644
index 00000000..b7977e59
--- /dev/null
+++ b/styles/default.csv
@@ -0,0 +1,3 @@
+name,prompt,negative_prompt
+❌Low Token,,"embedding:EasyNegative, NSFW, Cleavage, Pubic Hair, Nudity, Naked, censored"
+✅Line Art / Manga,"(Anime Scene, Toonshading, Satoshi Kon, Ken Sugimori, Hiromu Arakawa:1.2), (Anime Style, Manga Style:1.3), Low detail, sketch, concept art, line art, webtoon, manhua, hand drawn, defined lines, simple shades, minimalistic, High contrast, Linear compositions, Scalable artwork, Digital art, High Contrast Shadows, glow effects, humorous illustration, big depth of field, Masterpiece, colors, concept art, trending on artstation, Vivid colors, dramatic",
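
The `# memedeck dependencies` block in requirements.txt above pulls in pika (AMQP client), python-dotenv, pillow, azure-storage-blob, and cairosvg, but this diff does not show how they are wired together. As a minimal sketch only, assuming a worker that reads credentials from the `.env` file excluded by the updated .gitignore and consumes jobs over RabbitMQ (the `AMQP_URL` variable name, the `memedeck-jobs` queue, and the callback are all assumptions, not the repo's actual code):

```python
import os

import pika
from dotenv import load_dotenv

# Load broker credentials from the .env file ignored by git.
load_dotenv()
amqp_url = os.getenv("AMQP_URL", "amqp://guest:guest@localhost:5672/%2F")  # assumed variable name

def on_message(channel, method, properties, body):
    """Placeholder handler; a real worker would kick off a ComfyUI job here."""
    print(f"received {len(body)} bytes")
    channel.basic_ack(delivery_tag=method.delivery_tag)

connection = pika.BlockingConnection(pika.URLParameters(amqp_url))
channel = connection.channel()
channel.queue_declare(queue="memedeck-jobs", durable=True)  # queue name is a guess
channel.basic_consume(queue="memedeck-jobs", on_message_callback=on_message)
channel.start_consuming()
```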
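
The new styles/default.csv follows a simple `name,prompt,negative_prompt` layout, where either prompt column may be empty. The repo's actual style loader is not part of this diff; the sketch below only illustrates how a row could be merged with a user prompt, and the `load_styles`/`apply_style` helpers are hypothetical names:

```python
import csv

def load_styles(path="styles/default.csv"):
    """Read name -> (prompt, negative_prompt) pairs from the styles CSV."""
    styles = {}
    with open(path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            styles[row["name"]] = (row.get("prompt", ""), row.get("negative_prompt", ""))
    return styles

def apply_style(styles, name, user_prompt):
    """Append the style's prompt text to the user prompt; return positive and negative prompts."""
    style_prompt, negative = styles.get(name, ("", ""))
    positive = ", ".join(p for p in (user_prompt, style_prompt) if p)
    return positive, negative

# Example: combine a user prompt with the "✅Line Art / Manga" row.
if __name__ == "__main__":
    styles = load_styles()
    pos, neg = apply_style(styles, "✅Line Art / Manga", "a corgi surfing a wave")
    print(pos)
    print(neg)
```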