added MemedeckWorker logic to the repo

Authored by Ubuntu on 2024-10-15 19:11:27 +00:00; committed by Ubuntu
parent f86c724ef2
commit f93fe6b3fc
39 changed files with 1027 additions and 1079 deletions

9
.gitignore vendored
View File

@ -21,3 +21,12 @@ venv/
*.log
web_custom_versions/
.DS_Store
.env
models
custom_nodes
models/
flux_loras/
custom_nodes/

50
MEMEDECK.md Normal file
View File

@ -0,0 +1,50 @@
# Commands to set up the environment.
## Enable MIG on an NVIDIA GPU
<!-- link to the guide -->
https://www.seimaxim.com/kb/gpu/nvidia-a100-mig-cheat-sheat
```zsh
sudo nvidia-smi -i 0 -mig 1
sudo nvidia-smi mig -cgi 9,9 -C
```
Start ComfyUI pinned to a MIG instance
```zsh
CUDA_VISIBLE_DEVICES=MIG-0 python main.py --port 5000 --listen 0.0.0.0 --cuda-device 0 --preview-method auto
```
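To confirm the ComfyUI process only sees the intended MIG slice, a quick check with PyTorch (already a ComfyUI dependency) can be run in the same environment. This is a sketch; the MIG identifier is whatever UUID `nvidia-smi -L` reports for your instance.
```python
# sketch: verify which CUDA devices this process can actually see
import os
import torch

# CUDA_VISIBLE_DEVICES should already be set to the MIG UUID by the launch command above
print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("Visible CUDA devices:", torch.cuda.device_count())  # expect 1 when pinned to a single MIG slice
if torch.cuda.is_available():
    print("Device 0:", torch.cuda.get_device_name(0))
```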
## Tunnel remote server port to local machine port
### On the server
```zsh
sudo nano /etc/ssh/sshd_config
# uncomment this line in the sshd_config file
GatewayPorts yes
```
Then restart the machine.
### On your local machine
```zsh
sudo ssh -i ~/.ssh/memedeck-monolith.pem -N -R 9090:localhost:8079 holium@172.206.15.40
```
### Install autossh (may not be needed)
```zsh
brew install autossh
autossh -M 0 -o "ServerAliveInterval 30" -o "ServerAliveCountMax 3" -i ~/.ssh/memedeck-monolith.pem -N -R 9090:localhost:8079 holium@172.206.15.40
```
### On the server
This allows port 9090 on the server to be reached from your local machine.
```zsh
sudo iptables -A INPUT -p tcp --dport 9090 -j ACCEPT
```
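As a quick sanity check that the forwarded port accepts connections, a small sketch (using the server address and port from the tunnel command above; it only tests TCP reachability, not the service behind it):
```python
# sketch: confirm the reverse-tunneled port on the server accepts connections
import socket

HOST = "172.206.15.40"  # server address from the ssh command above
PORT = 9090             # remote port forwarded back to the local machine

try:
    with socket.create_connection((HOST, PORT), timeout=5):
        print(f"{HOST}:{PORT} is reachable")
except OSError as e:
    print(f"{HOST}:{PORT} is not reachable: {e}")
```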

View File

@ -213,10 +213,9 @@ def load_lora(lora, to_load, log_missing=True):
patch_dict[to_load[x]] = ("set", (set_weight,))
loaded_keys.add(set_weight_name)
if log_missing:
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
# for x in lora.keys():
# if x not in loaded_keys:
# logging.warning("lora key not loaded: {}".format(x))
return patch_dict

View File

@ -1,155 +0,0 @@
class Example:
"""
An example node
Class methods
-------------
INPUT_TYPES (dict):
Tell the main program input parameters of nodes.
IS_CHANGED:
Optional method to control when the node is re-executed.
Attributes
----------
RETURN_TYPES (`tuple`):
The type of each element in the output tuple.
RETURN_NAMES (`tuple`):
Optional: The name of each output in the output tuple.
FUNCTION (`str`):
The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
OUTPUT_NODE ([`bool`]):
If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example.
The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected.
Assumed to be False if not present.
CATEGORY (`str`):
The category the node should appear in the UI.
DEPRECATED (`bool`):
Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
functional in existing workflows that use them.
EXPERIMENTAL (`bool`):
Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
significant changes or removal in future versions. Use with caution in production workflows.
execute(s) -> tuple || None:
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
"""
Return a dictionary which contains config for all input fields.
Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
The type can be a list for selection.
Returns: `dict`:
- Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
- Value input_fields (`dict`): Contains input fields config:
* Key field_name (`string`): Name of an entry-point method's argument
* Value field_config (`tuple`):
+ First value is a string indicating the type of field, or a list for selection.
+ Second value is a config for type "INT", "STRING" or "FLOAT".
"""
return {
"required": {
"image": ("IMAGE",),
"int_field": ("INT", {
"default": 0,
"min": 0, #Minimum value
"max": 4096, #Maximum value
"step": 64, #Slider's step
"display": "number", # Cosmetic only: display as "number" or "slider"
"lazy": True # Will only be evaluated if check_lazy_status requires it
}),
"float_field": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 10.0,
"step": 0.01,
"round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
"display": "number",
"lazy": True
}),
"print_to_screen": (["enable", "disable"],),
"string_field": ("STRING", {
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
"default": "Hello World!",
"lazy": True
}),
},
}
RETURN_TYPES = ("IMAGE",)
#RETURN_NAMES = ("image_output_name",)
FUNCTION = "test"
#OUTPUT_NODE = False
CATEGORY = "Example"
def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
"""
Return a list of input names that need to be evaluated.
This function will be called if there are any lazy inputs which have not yet been
evaluated. As long as you return at least one field which has not yet been evaluated
(and more exist), this function will be called again once the value of the requested
field is available.
Any evaluated inputs will be passed as arguments to this function. Any unevaluated
inputs will have the value None.
"""
if print_to_screen == "enable":
return ["int_field", "float_field", "string_field"]
else:
return []
def test(self, image, string_field, int_field, float_field, print_to_screen):
if print_to_screen == "enable":
print(f"""Your input contains:
string_field aka input text: {string_field}
int_field: {int_field}
float_field: {float_field}
""")
#do some processing on the image, in this example I just invert it
image = 1.0 - image
return (image,)
"""
The node will always be re-executed if any of the inputs change, but
this method can be used to force the node to execute again even when the inputs don't change.
You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
executed; if it is different, the node will be executed again.
This method is used in the core repo for the LoadImage node, where it returns the image hash as a string; if the image hash
changes between executions, the LoadImage node is executed again.
"""
#@classmethod
#def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
# return ""
# Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension
# WEB_DIRECTORY = "./somejs"
# Add custom API routes, using router
from aiohttp import web
from server import PromptServer
@PromptServer.instance.routes.get("/hello")
async def get_hello(request):
return web.json_response("hello")
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"Example": Example
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"Example": "Example Node"
}

View File

@ -1,44 +0,0 @@
from PIL import Image
import numpy as np
import comfy.utils
import time
#You can use this node to save full-size images through the websocket. The
#images will be sent in exactly the same format as the image previews: as
#binary images on the websocket with an 8-byte header indicating the type
#of binary message (first 4 bytes) and the image format (next 4 bytes).
#Note that no metadata will be put in the images saved with this node.
class SaveImageWebsocket:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),}
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "api/image"
def save_images(self, images):
pbar = comfy.utils.ProgressBar(images.shape[0])
step = 0
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pbar.update_absolute(step, images.shape[0], ("PNG", img, None))
step += 1
return {}
@classmethod
def IS_CHANGED(s, images):
return time.time()
NODE_CLASS_MAPPINGS = {
"SaveImageWebsocket": SaveImageWebsocket,
}
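For reference, the 8-byte header described in the comments above can be unpacked on the client side roughly as follows. This is a sketch with a hypothetical helper name, assuming the two 4-byte fields are big-endian unsigned integers, which matches how ComfyUI packs its binary preview frames.
```python
# sketch: client-side parsing of a binary preview frame with the 8-byte header
import struct

def parse_preview_frame(frame: bytes):
    # hypothetical helper; assumes two big-endian uint32 fields: message type, then image format
    if len(frame) < 8:
        raise ValueError("frame too short to contain the 8-byte header")
    message_type, image_format = struct.unpack(">II", frame[:8])
    image_bytes = frame[8:]  # raw PNG/JPEG payload, no metadata embedded
    return message_type, image_format, image_bytes
```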

View File

@ -245,7 +245,7 @@ def format_value(x):
else:
return str(x)
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
def execute(server, memedeck_worker, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
@ -253,10 +253,15 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
is_memedeck = extra_data.get("is_memedeck", False)
if caches.outputs.get(unique_id) is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
if memedeck_worker.ws_id is not None:
memedeck_worker.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, memedeck_worker.ws_id)
# elif is_memedeck:
return (ExecutionResult.SUCCESS, None, None)
input_data_all = None
@ -284,10 +289,13 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
has_subgraph = False
else:
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None:
if server.client_id is not None and not is_memedeck:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
elif is_memedeck:
memedeck_worker.last_node_id = display_node_id
memedeck_worker.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, memedeck_worker.ws_id)
obj = caches.objects.get(unique_id)
if obj is None:
obj = class_def()
@ -319,6 +327,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
"current_outputs": [],
}
server.send_sync("execution_error", mes, server.client_id)
memedeck_worker.send_sync("execution_error", mes, memedeck_worker.ws_id)
return ExecutionBlocker(None)
else:
return block
@ -335,8 +344,11 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
},
"output": output_ui
})
if server.client_id is not None:
if server.client_id is not None and not is_memedeck:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
elif is_memedeck:
memedeck_worker.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, memedeck_worker.ws_id)
if has_subgraph:
cached_outputs = []
new_node_ids = []
@ -414,9 +426,12 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor:
def __init__(self, server, lru_size=None):
def __init__(self, server, memedeck_worker, lru_size=None):
self.lru_size = lru_size
self.server = server
self.memedeck_worker = memedeck_worker
# set logging level
logging.basicConfig(level=logging.INFO)
self.reset()
def reset(self):
@ -432,6 +447,9 @@ class PromptExecutor:
self.status_messages.append((event, data))
if self.server.client_id is not None or broadcast:
self.server.send_sync(event, data, self.server.client_id)
elif self.memedeck_worker.ws_id is not None:
self.memedeck_worker.send_sync(event, data, self.memedeck_worker.ws_id)
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
node_id = error["node_id"]
@ -466,8 +484,12 @@ class PromptExecutor:
if "client_id" in extra_data:
self.server.client_id = extra_data["client_id"]
if "ws_id" in extra_data:
self.memedeck_worker.ws_id = extra_data["ws_id"]
self.memedeck_worker.websocket_node_id = extra_data["websocket_node_id"]
else:
self.server.client_id = None
self.memedeck_worker.ws_id = None
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
@ -501,7 +523,7 @@ class PromptExecutor:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
result, error, ex = execute(self.server, self.memedeck_worker, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@ -512,7 +534,7 @@ class PromptExecutor:
execution_list.complete_node_execution()
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
self.add_message("execution_success", { "prompt_id": prompt_id, 'ws_id': self.memedeck_worker.ws_id }, broadcast=False)
ui_outputs = {}
meta_outputs = {}
@ -527,6 +549,7 @@ class PromptExecutor:
"meta": meta_outputs,
}
self.server.last_node_id = None
self.memedeck_worker.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
@ -869,8 +892,9 @@ def validate_prompt(prompt):
MAXIMUM_HISTORY_SIZE = 10000
class PromptQueue:
def __init__(self, server):
def __init__(self, server, memedeck_worker):
self.server = server
self.memedeck_worker = memedeck_worker
self.mutex = threading.RLock()
self.not_empty = threading.Condition(self.mutex)
self.task_counter = 0
@ -884,6 +908,7 @@ class PromptQueue:
with self.mutex:
heapq.heappush(self.queue, item)
self.server.queue_updated()
self.memedeck_worker.queue_updated()
self.not_empty.notify()
def get(self, timeout=None):
@ -897,6 +922,7 @@ class PromptQueue:
self.currently_running[i] = copy.deepcopy(item)
self.task_counter += 1
self.server.queue_updated()
self.memedeck_worker.queue_updated()
return (item, i)
class ExecutionStatus(NamedTuple):
@ -922,6 +948,7 @@ class PromptQueue:
}
self.history[prompt[1]].update(history_result)
self.server.queue_updated()
self.memedeck_worker.queue_updated()
def get_current_queue(self):
with self.mutex:
@ -938,6 +965,7 @@ class PromptQueue:
with self.mutex:
self.queue = []
self.server.queue_updated()
self.memedeck_worker.queue_updated()
def delete_queue_item(self, function):
with self.mutex:
@ -949,6 +977,8 @@ class PromptQueue:
self.queue.pop(x)
heapq.heapify(self.queue)
self.server.queue_updated()
self.memedeck_worker.queue_updated()
return True
return False

90
main.py
View File

@ -53,6 +53,14 @@ def apply_custom_paths():
logging.info(f"Setting user directory to: {user_dir}")
folder_paths.set_user_directory(user_dir)
# ---------------------------------------------------------------------------------------
# memedeck imports
# ---------------------------------------------------------------------------------------
from memedeck import MemedeckWorker
import sys
sys.stdout = open(os.devnull, 'w') # disable all print statements
# ---------------------------------------------------------------------------------------
def execute_prestartup_script():
def execute_script(script_path):
@ -153,9 +161,11 @@ def cuda_malloc_warning():
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server_instance):
def prompt_worker(q, server, memedeck_worker):
current_time: float = 0.0
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
e = execution.PromptExecutor(server, memedeck_worker, lru_size=args.cache_lru)
# threading.Thread(target=memedeck_worker.start, daemon=True, args=(q, execution.validate_prompt)).start()
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
@ -209,12 +219,12 @@ def prompt_worker(q, server_instance):
need_gc = False
async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
async def run(server_instance, memedeck_worker, address='', port=8188, verbose=True, call_on_start=None):
addresses = []
for addr in address.split(","):
addresses.append((addr, port))
await asyncio.gather(
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop(), memedeck_worker.publish_loop()
)
@ -224,11 +234,34 @@ def hijack_progress(server_instance):
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
server_instance.send_sync("progress", progress, server_instance.client_id)
if memedeck_worker.ws_id is not None:
memedeck_worker.send_sync("progress", {"value": value, "max": total, "prompt_id": memedeck_worker.last_prompt_id, "node": server_instance.last_node_id}, memedeck_worker.ws_id)
if preview_image is not None:
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
if memedeck_worker.ws_id is not None:
memedeck_worker.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, memedeck_worker.ws_id)
comfy.utils.set_progress_bar_global_hook(hook)
# async def run(server, memedeck_worker, address='', port=8188, verbose=True, call_on_start=None):
# addresses = []
# for addr in address.split(","):
# addresses.append((addr, port))
# # add memedeck worker publish loop
# await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop(), memedeck_worker.publish_loop())
# def hijack_progress(server, memedeck_worker):
# def hook(value, total, preview_image):
# comfy.model_management.throw_exception_if_processing_interrupted()
# server.send_sync("progress", {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}, server.client_id)
# if memedeck_worker.ws_id is not None:
# memedeck_worker.send_sync("progress", {"value": value, "max": total, "prompt_id": memedeck_worker.last_prompt_id, "node": server.last_node_id}, memedeck_worker.ws_id)
# if preview_image is not None:
# server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
# if memedeck_worker.ws_id is not None:
# memedeck_worker.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, memedeck_worker.ws_id)
# comfy.utils.set_progress_bar_global_hook(hook)
def cleanup_temp():
temp_dir = folder_paths.get_temp_directory()
@ -258,16 +291,54 @@ def start_comfyui(asyncio_loop=None):
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
prompt_server = server.PromptServer(asyncio_loop)
q = execution.PromptQueue(prompt_server)
memedeck_worker = MemedeckWorker(asyncio_loop)
q = execution.PromptQueue(prompt_server, memedeck_worker)
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# server = server.PromptServer(loop)
# q = execution.PromptQueue(server, memedeck_worker)
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
utils.extra_config.load_extra_path_config(config_path)
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
cuda_malloc_warning()
prompt_server.add_routes()
hijack_progress(prompt_server)
hijack_progress(server, memedeck_worker)
threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
threading.Thread(target=prompt_worker, daemon=True, args=(q, server, memedeck_worker)).start()
threading.Thread(target=memedeck_worker.start, daemon=True, args=(q, execution.validate_prompt)).start()
# set logging level to info
if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
logging.info(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)
#These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
logging.info(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)
if args.user_directory:
user_dir = os.path.abspath(args.user_directory)
logging.info(f"Setting user directory to: {user_dir}")
folder_paths.set_user_directory(user_dir)
if args.quick_test_for_ci:
exit(0)
@ -287,6 +358,7 @@ def start_comfyui(asyncio_loop=None):
async def start_all():
await prompt_server.setup()
await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
await run(server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
# Returning these so that other code can integrate with the ComfyUI loop and server
return asyncio_loop, prompt_server, start_all
@ -298,6 +370,8 @@ if __name__ == "__main__":
event_loop, _, start_all_func = start_comfyui()
try:
event_loop.run_until_complete(start_all_func())
# event_loop.run_until_complete(run(server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt:
logging.info("\nStopped server")

266
memedeck-v1.py Normal file
View File

@ -0,0 +1,266 @@
import asyncio
import base64
from io import BytesIO
import os
import logging
import signal
import struct
from typing import Optional
import uuid
from PIL import Image, ImageOps
from functools import partial
import pika
import json
import requests
import aiohttp
# load environment variables from the .env file
from dotenv import load_dotenv
load_dotenv()
amqp_addr = os.getenv('AMQP_ADDR') or 'amqp://api:gacdownatravKekmy9@51.8.120.154:5672/dev'
# define the enum in python
from enum import Enum
class QueueProgressKind(Enum):
# make json serializable
ImageGenerated = "image_generated"
ImageGenerating = "image_generating"
SamePrompt = "same_prompt"
FaceswapGenerated = "faceswap_generated"
FaceswapGenerating = "faceswap_generating"
Failed = "failed"
class MemedeckWorker:
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
class JsonEventTypes(Enum):
PROGRESS = "progress"
EXECUTING = "executing"
EXECUTED = "executed"
ERROR = "error"
STATUS = "status"
"""
MemedeckWorker is responsible for relaying messages between ComfyUI and the Memedeck backend API.
It receives prompts from the backend and sends generated images back to it.
"""
def __init__(self, loop):
MemedeckWorker.instance = self
# set logging level to info
logging.getLogger().setLevel(logging.INFO)
self.active_tasks_map = {}
self.current_task = None
self.client_id = None
self.ws_id = None
self.websocket_node_id = None
self.current_node = None
self.current_progress = 0
self.current_context = None
self.loop = loop
self.messages = asyncio.Queue()
self.http_client = None
self.prompt_queue = None
self.validate_prompt = None
self.last_prompt_id = None
self.amqp_url = amqp_addr
self.queue_name = os.getenv('QUEUE_NAME') or 'generic-queue'
self.api_url = os.getenv('API_ADDRESS') or 'http://0.0.0.0:8079/v2'
self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383'
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger(__name__)
self.logger.info(f"\n[memedeck]: initialized with API URL: {self.api_url} and API Key: {self.api_key}\n")
def on_connection_open(self, connection):
self.connection = connection
self.connection.channel(on_open_callback=self.on_channel_open)
def on_channel_open(self, channel):
self.channel = channel
# only consume one message at a time
self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
self.channel.queue_declare(queue=self.queue_name, durable=True)
self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received)
def start(self, prompt_queue, validate_prompt):
self.prompt_queue = prompt_queue
self.validate_prompt = validate_prompt
parameters = pika.URLParameters(self.amqp_url)
logging.getLogger('pika').setLevel(logging.WARNING) # suppress all logs from pika
self.connection = pika.SelectConnection(parameters, on_open_callback=self.on_connection_open)
try:
self.connection.ioloop.start()
except KeyboardInterrupt:
self.connection.close()
self.connection.ioloop.start()
def on_message_received(self, channel, method, properties, body):
decoded_string = body.decode('utf-8')
json_object = json.loads(decoded_string)
payload = json_object[1]
# execute the task
prompt = payload["nodes"]
valid = self.validate_prompt(prompt)
self.current_node = None
self.current_progress = 0
self.websocket_node_id = None
self.ws_id = payload["source_ws_id"]
self.current_context = payload["req_ctx"]
for node in prompt: # search through prompt nodes for websocket_node_id
if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
self.websocket_node_id = node
break
if valid[0]:
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
self.active_tasks_map[payload["source_ws_id"]] = {
"prompt_id": prompt_id,
"prompt": prompt,
"outputs_to_execute": outputs_to_execute,
"client_id": "memedeck-1",
"is_memedeck": True,
"websocket_node_id": self.websocket_node_id,
"ws_id": payload["source_ws_id"],
"context": payload["req_ctx"],
"current_node": None,
"current_progress": 0,
}
self.prompt_queue.put((0, prompt_id, prompt, {
"client_id": "memedeck-1",
'is_memedeck': True,
'websocket_node_id': self.websocket_node_id,
'ws_id': payload["source_ws_id"],
'context': payload["req_ctx"]
}, outputs_to_execute))
self.set_last_prompt_id(prompt_id)
channel.basic_ack(delivery_tag=method.delivery_tag) # ack the task
else:
channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # reject the message without requeueing
# --------------------------------------------------
# callbacks for the prompt queue
# --------------------------------------------------
def queue_updated(self):
# print json of the queue info but only print the first 100 lines
info = self.get_queue_info()
# update_type = info['']
# self.send_sync("status", { "status": self.get_queue_info() })
def get_queue_info(self):
prompt_info = {}
exec_info = {}
exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info
return prompt_info
def send_sync(self, event, data, sid=None):
self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid))
def set_last_prompt_id(self, prompt_id):
self.last_prompt_id = prompt_id
async def publish_loop(self):
while True:
msg = await self.messages.get()
await self.send(*msg)
async def send(self, event, data, sid=None):
current_task = self.active_tasks_map.get(sid)
if current_task is None or current_task['ws_id'] != sid:
return
if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: # preview and unencoded images are sent here
self.logger.info(f"[memedeck]: sending image preview for {sid}")
await self.send_preview(data, sid=current_task['ws_id'], progress=current_task['current_progress'], context=current_task['context'])
else: # send json data / text data
if event == "executing":
current_task['current_node'] = data['node']
elif event == "executed":
self.logger.info(f"---> [memedeck]: executed event for {sid}")
prompt_id = data['prompt_id']
if prompt_id in self.active_tasks_map:
del self.active_tasks_map[prompt_id]
elif event == "progress":
if current_task['current_node'] == current_task['websocket_node_id']: # if the node is the websocket node, then set the progress to 100
current_task['current_progress'] = 100
else: # if the node is not the websocket node, then set the progress to the progress from the node
current_task['current_progress'] = data['value'] / data['max'] * 100
if current_task['current_progress'] == 100 and current_task['current_node'] != current_task['websocket_node_id']:
# in case the progress is 100 but the node is not the websocket node, then set the progress to 95
current_task['current_progress'] = 95 # this allows the full resolution image to be sent on the 100 progress event
if data['value'] == 1: # if the value is 1, then send started to api
start_data = {
"ws_id": current_task['ws_id'],
"status": "started",
"info": None,
}
await self.send_to_api(start_data)
elif event == "status":
self.logger.info(f"[memedeck]: sending status event: {data}")
self.active_tasks_map[sid] = current_task
async def send_preview(self, image_data, sid=None, progress=None, context=None):
# if progress is odd, then don't send the preview
if progress % 2 == 1:
return
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
if max_size is not None:
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
image = ImageOps.contain(image, (max_size, max_size), resampling)
bytesIO = BytesIO()
image.save(bytesIO, format=image_type, quality=100 if progress == 96 else 75, compress_level=1)
preview_bytes = bytesIO.getvalue()
ai_queue_progress = {
"ws_id": sid,
"kind": "image_generating" if progress < 100 else "image_generated",
"data": list(preview_bytes),
"progress": int(progress),
"context": context
}
await self.send_to_api(ai_queue_progress)
async def send_to_api(self, data):
if self.websocket_node_id is None: # check if the node is still running
logging.error(f"[memedeck]: websocket_node_id is None for {data['ws_id']}")
return
try:
post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data)
await self.loop.run_in_executor(None, post_func)
except Exception as e:
self.logger.error(f"[memedeck]: error sending to api: {e}")
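The worker above consumes AMQP messages whose body is a JSON array and only reads the second element, expecting the `nodes`, `source_ws_id`, and `req_ctx` fields. A minimal sketch of publishing a test job with pika (the queue name and field names come from the worker code; the first array element and the empty prompt graph are placeholders):
```python
# sketch: publish a test job to the queue the MemedeckWorker consumes
import json
import os
import pika

amqp_addr = os.getenv("AMQP_ADDR")                   # same variable the worker reads (set in .env)
queue_name = os.getenv("QUEUE_NAME", "generic-queue")

payload = {
    "nodes": {},                  # a ComfyUI prompt graph, keyed by node id
    "source_ws_id": "test-ws-1",  # becomes the worker's ws_id
    "req_ctx": {},                # request context passed back to the API
}

connection = pika.BlockingConnection(pika.URLParameters(amqp_addr))
channel = connection.channel()
channel.queue_declare(queue=queue_name, durable=True)  # matches the worker's declaration
# the worker only reads element [1]; the first element is ignored here
channel.basic_publish(exchange="", routing_key=queue_name,
                      body=json.dumps([None, payload]))
connection.close()
```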

568
memedeck.py Normal file
View File

@ -0,0 +1,568 @@
import asyncio
import base64
from io import BytesIO
import os
import logging
import signal
import struct
import time
from typing import Optional
import uuid
from PIL import Image, ImageOps
from functools import partial
import pika
import json
import requests
import aiohttp
from dotenv import load_dotenv
load_dotenv()
amqp_addr = os.getenv('AMQP_ADDR') or 'amqp://api:gacdownatravKekmy9@51.8.120.154:5672/dev'
from enum import Enum
class QueueProgressKind(Enum):
ImageGenerated = "image_generated"
ImageGenerating = "image_generating"
SamePrompt = "same_prompt"
FaceswapGenerated = "faceswap_generated"
FaceswapGenerating = "faceswap_generating"
Failed = "failed"
class MemedeckWorker:
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
class JsonEventTypes(Enum):
PROGRESS = "progress"
EXECUTING = "executing"
EXECUTED = "executed"
ERROR = "error"
STATUS = "status"
"""
MemedeckWorker is responsible for relaying messages between ComfyUI and the Memedeck backend API.
It receives prompts from the backend and sends generated images back to it.
"""
def __init__(self, loop):
MemedeckWorker.instance = self
logging.getLogger().setLevel(logging.INFO)
self.loop = loop
self.messages = asyncio.Queue()
self.http_client = None
self.prompt_queue = None
self.validate_prompt = None
self.last_prompt_id = None
self.amqp_url = amqp_addr
self.queue_name = os.getenv('QUEUE_NAME') or 'generic-queue'
self.api_url = os.getenv('API_ADDRESS') or 'http://0.0.0.0:8079/v2'
self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383'
# Internal job queue
self.internal_job_queue = asyncio.Queue()
# Dictionary to keep track of tasks by ws_id
self.tasks_by_ws_id = {}
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger(__name__)
self.logger.info(f"\n[memedeck]: initialized with API URL: {self.api_url} and API Key: {self.api_key}\n")
def on_connection_open(self, connection):
self.connection = connection
self.connection.channel(on_open_callback=self.on_channel_open)
def on_channel_open(self, channel):
self.channel = channel
# only consume one message at a time
self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
self.channel.queue_declare(queue=self.queue_name, durable=True)
self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received, auto_ack=False)
def start(self, prompt_queue, validate_prompt):
self.prompt_queue = prompt_queue
self.validate_prompt = validate_prompt
# Start the process_job_queue task **after** prompt_queue is set
self.loop.create_task(self.process_job_queue())
parameters = pika.URLParameters(self.amqp_url)
logging.getLogger('pika').setLevel(logging.WARNING) # suppress all logs from pika
self.connection = pika.SelectConnection(parameters, on_open_callback=self.on_connection_open)
try:
self.connection.ioloop.start()
except KeyboardInterrupt:
self.connection.close()
self.connection.ioloop.start()
def on_message_received(self, channel, method, properties, body):
decoded_string = body.decode('utf-8')
json_object = json.loads(decoded_string)
payload = json_object[1]
# execute the task
prompt = payload["nodes"]
valid = self.validate_prompt(prompt)
# Prepare task_info
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
task_info = {
"prompt_id": prompt_id,
"prompt": prompt,
"outputs_to_execute": outputs_to_execute,
"client_id": "memedeck-1",
"is_memedeck": True,
"websocket_node_id": None,
"ws_id": payload["source_ws_id"],
"context": payload["req_ctx"],
"current_node": None,
"current_progress": 0,
"delivery_tag": method.delivery_tag,
"task_status": "waiting",
}
# Find the websocket_node_id
for node in prompt:
if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
task_info['websocket_node_id'] = node
break
if valid[0]:
# Enqueue the task into the internal job queue
self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info))
self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}")
else:
channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # reject the message without requeueing
async def process_job_queue(self):
while True:
prompt_id, prompt, task_info = await self.internal_job_queue.get()
# Start a new coroutine for each task
self.loop.create_task(self.process_task(prompt_id, prompt, task_info))
async def process_task(self, prompt_id, prompt, task_info):
ws_id = task_info['ws_id']
# Add the task to tasks_by_ws_id
self.tasks_by_ws_id[ws_id] = task_info
# Put the prompt into the prompt_queue
self.prompt_queue.put((0, prompt_id, prompt, {
"client_id": task_info["client_id"],
'is_memedeck': task_info['is_memedeck'],
'websocket_node_id': task_info['websocket_node_id'],
'ws_id': task_info['ws_id'],
'context': task_info['context']
}, task_info['outputs_to_execute']))
# Acknowledge the message
self.channel.basic_ack(delivery_tag=task_info["delivery_tag"]) # ack the task
self.logger.info(f"[memedeck]: Acked task {prompt_id} {ws_id}")
self.logger.info(f"[memedeck]: Started processing prompt {prompt_id}")
# Wait until the current task is completed
await self.wait_for_task_completion(ws_id)
# Task is done
self.internal_job_queue.task_done()
async def wait_for_task_completion(self, ws_id):
"""
Wait until the task with the given ws_id is completed.
"""
while ws_id in self.tasks_by_ws_id:
await asyncio.sleep(0.5)
# --------------------------------------------------
# callbacks for the prompt queue
# --------------------------------------------------
def queue_updated(self):
info = self.get_queue_info()
# self.send_sync("status", { "status": self.get_queue_info() })
def get_queue_info(self):
prompt_info = {}
exec_info = {}
exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info
return prompt_info
def send_sync(self, event, data, sid=None):
self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid))
async def publish_loop(self):
while True:
msg = await self.messages.get()
await self.send(*msg)
async def send(self, event, data, sid=None):
if sid is None:
self.logger.warning("Received event without sid")
return
# Retrieve the task based on sid
task = self.tasks_by_ws_id.get(sid)
if not task:
self.logger.warning(f"Received event {event} for unknown sid: {sid}")
return
if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
await self.send_preview(
data,
sid=sid,
progress=task['current_progress'],
context=task['context']
)
else:
# Send JSON data / text data
if event == "executing":
task['current_node'] = data['node']
task["task_status"] = "executing"
elif event == "progress":
if task['current_node'] == task['websocket_node_id']:
# If the node is the websocket node, then set the progress to 100
task['current_progress'] = 100
else:
# If the node is not the websocket node, then set the progress based on the node's progress
task['current_progress'] = (data['value'] / data['max']) * 100
if task['current_progress'] == 100 and task['current_node'] != task['websocket_node_id']:
# In case the progress is 100 but the node is not the websocket node, set progress to 95
task['current_progress'] = 95 # Allows the full resolution image to be sent on the 100 progress event
if data['value'] == 1:
# If the value is 1, send started to API
start_data = {
"ws_id": task['ws_id'],
"status": "started",
"info": None,
}
task["task_status"] = "executing"
await self.send_to_api(start_data)
elif event == "status":
self.logger.info(f"[memedeck]: sending status event: {data}")
# Update the task in tasks_by_ws_id
self.tasks_by_ws_id[sid] = task
async def send_preview(self, image_data, sid=None, progress=None, context=None):
if sid is None:
self.logger.warning("Received preview without sid")
return
task = self.tasks_by_ws_id.get(sid)
if not task:
self.logger.warning(f"Received preview for unknown sid: {sid}")
return
if progress is None:
progress = task['current_progress']
# if progress is odd, then don't send the preview
if int(progress) % 2 == 1:
return
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
if max_size is not None:
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
image = ImageOps.contain(image, (max_size, max_size), resampling)
bytesIO = BytesIO()
image.save(bytesIO, format=image_type, quality=100 if progress == 95 else 75, compress_level=1)
preview_bytes = bytesIO.getvalue()
ai_queue_progress = {
"ws_id": sid,
"kind": "image_generating" if progress < 100 else "image_generated",
"data": list(preview_bytes),
"progress": int(progress),
"context": context
}
await self.send_to_api(ai_queue_progress)
if progress == 100:
del self.tasks_by_ws_id[sid] # Remove the task from tasks_by_ws_id
self.logger.info(f"[memedeck]: Task {sid} completed")
async def send_to_api(self, data):
ws_id = data.get('ws_id')
if not ws_id:
self.logger.error("[memedeck]: Missing ws_id in data")
return
task = self.tasks_by_ws_id.get(ws_id)
if not task:
self.logger.error(f"[memedeck]: No task found for ws_id {ws_id}")
return
if task['websocket_node_id'] is None:
self.logger.error(f"[memedeck]: websocket_node_id is None for {ws_id}")
return
try:
post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data)
await self.loop.run_in_executor(None, post_func)
except Exception as e:
self.logger.error(f"[memedeck]: error sending to api: {e}")
# --------------------------------------------------------------------------
# MemedeckAzureStorage
# --------------------------------------------------------------------------
# from azure.storage.blob.aio import BlobClient, BlobServiceClient
# from azure.storage.blob import ContentSettings
# from typing import Optional, Tuple
# import cairosvg
# WATERMARK = '<svg width="256" height="256" viewBox="0 0 256 256" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M60.0859 196.8C65.9526 179.067 71.5526 161.667 76.8859 144.6C79.1526 137.4 81.4859 129.867 83.8859 122C86.2859 114.133 88.6859 106.333 91.0859 98.6C93.4859 90.8667 95.6859 83.4 97.6859 76.2C99.8193 69 101.686 62.3333 103.286 56.2C110.619 56.2 117.553 55.8 124.086 55C130.619 54.2 137.686 53.4667 145.286 52.8C144.886 55.7333 144.419 59.0667 143.886 62.8C143.486 66.4 142.953 70.2 142.286 74.2C141.753 78.2 141.153 82.3333 140.486 86.6C139.819 90.8667 139.019 96.3333 138.086 103C137.153 109.667 135.886 118 134.286 128H136.886C140.753 117.867 143.953 109.467 146.486 102.8C149.019 96 151.086 90.4667 152.686 86.2C154.286 81.9333 155.886 77.8 157.486 73.8C159.219 69.6667 160.819 65.8 162.286 62.2C163.886 58.4667 165.353 55.2 166.686 52.4C170.019 52.1333 173.153 51.8 176.086 51.4C179.019 51 181.953 50.6 184.886 50.2C187.819 49.6667 190.753 49.2 193.686 48.8C196.753 48.2667 200.086 47.6667 203.686 47C202.353 54.7333 201.086 62.6667 199.886 70.8C198.686 78.9333 197.619 87.0667 196.686 95.2C195.753 103.2 194.819 111.133 193.886 119C193.086 126.867 192.353 134.333 191.686 141.4C190.086 157.933 188.686 174.067 187.486 189.8L152.686 196C152.686 195.333 152.753 193.533 152.886 190.6C153.153 187.667 153.419 184.067 153.686 179.8C154.086 175.533 154.553 170.8 155.086 165.6C155.753 160.4 156.353 155.2 156.886 150C157.553 144.8 158.219 139.8 158.886 135C159.553 130.067 160.219 125.867 160.886 122.4H159.086C157.219 128 155.153 133.933 152.886 140.2C150.619 146.333 148.286 152.6 145.886 159C143.619 165.4 141.353 171.667 139.086 177.8C136.819 183.933 134.819 189.8 133.086 195.4C128.419 195.533 124.419 195.733 121.086 196C117.753 196.133 113.886 196.333 109.486 196.6L115.886 122.4H112.886C112.619 124.133 111.953 127.067 110.886 131.2C109.819 135.2 108.553 139.867 107.086 145.2C105.753 150.4 104.286 155.867 102.686 161.6C101.086 167.2 99.5526 172.467 98.0859 177.4C96.7526 182.2 95.6193 186.2 94.6859 189.4C93.7526 192.467 93.2193 194.2 93.0859 194.6L60.0859 196.8Z" fill="white"/></svg>'
# WATERMARK_SIZE = 40
# class MemedeckAzureStorage:
# def __init__(self, connection_string):
# # get environment variables
# self.storage_account = os.getenv('STORAGE_ACCOUNT')
# self.storage_access_key = os.getenv('STORAGE_ACCESS_KEY')
# self.storage_container = os.getenv('STORAGE_CONTAINER')
# self.logger = logging.getLogger(__name__)
# self.blob_service_client = BlobServiceClient.from_connection_string(conn_str=connection_string)
# async def upload_image(
# self,
# by: str,
# image_id: str,
# source_url: Optional[str],
# bytes_data: Optional[bytes],
# filetype: Optional[str],
# ) -> Tuple[str, Tuple[int, int]]:
# """
# Uploads an image to Azure Blob Storage.
# Args:
# by (str): Identifier for the uploader.
# image_id (str): Unique identifier for the image.
# source_url (Optional[str]): URL to fetch the image from.
# bytes_data (Optional[bytes]): Image data in bytes.
# filetype (Optional[str]): Desired file type (e.g., 'jpeg', 'png').
# Returns:
# Tuple[str, Tuple[int, int]]: URL of the uploaded image and its dimensions.
# """
# # Retrieve image bytes either from the provided bytes_data or by fetching from source_url
# if source_url is None:
# if bytes_data is None:
# raise ValueError("Could not get image bytes")
# image_bytes = bytes_data
# else:
# self.logger.info(f"Requesting image from URL: {source_url}")
# async with aiohttp.ClientSession() as session:
# try:
# async with session.get(source_url) as response:
# if response.status != 200:
# raise Exception(f"Failed to fetch image, status code {response.status}")
# image_bytes = await response.read()
# except Exception as e:
# raise Exception(f"Error fetching image from URL: {e}")
# # Open image using Pillow to get dimensions and format
# try:
# img = Image.open(BytesIO(image_bytes))
# width, height = img.size
# inferred_filetype = img.format.lower()
# except Exception as e:
# raise Exception(f"Failed to decode image: {e}")
# # Determine the final file type
# final_filetype = filetype.lower() if filetype else inferred_filetype
# # Construct the blob name
# blob_name = f"{by}/{image_id.replace('image:', '')}.{final_filetype}"
# # Upload the image to Azure Blob Storage
# try:
# image_url = await self.save_image(blob_name, img.format, image_bytes)
# return image_url, (width, height)
# except Exception as e:
# self.logger.error(f"Trouble saving image: {e}")
# raise Exception(f"Trouble saving image: {e}")
# async def save_image(
# self,
# blob_name: str,
# content_type: str,
# bytes_data: bytes
# ) -> str:
# """
# Saves image bytes to Azure Blob Storage.
# Args:
# blob_name (str): Name of the blob in Azure Storage.
# content_type (str): MIME type of the content.
# bytes_data (bytes): Image data in bytes.
# Returns:
# str: URL of the uploaded blob.
# """
# # Retrieve environment variables
# account = os.getenv("STORAGE_ACCOUNT")
# access_key = os.getenv("STORAGE_ACCESS_KEY")
# container = os.getenv("STORAGE_CONTAINER")
# if not all([account, access_key, container]):
# raise EnvironmentError("Missing STORAGE_ACCOUNT, STORAGE_ACCESS_KEY, or STORAGE_CONTAINER environment variables")
# # Initialize BlobServiceClient
# blob_service_client = BlobServiceClient(
# account_url=f"https://{account}.blob.core.windows.net",
# credential=access_key
# )
# blob_client = blob_service_client.get_blob_client(container=container, blob=blob_name)
# # Upload the blob
# try:
# await blob_client.upload_blob(
# bytes_data,
# overwrite=True,
# content_settings=ContentSettings(content_type=content_type)
# )
# except Exception as e:
# raise Exception(f"Failed to upload blob: {e}")
# self.logger.debug(f"Blob uploaded: name={blob_name}, content_type={content_type}")
# # Construct and return the blob URL
# blob_url = f"https://media.memedeck.xyz//{container}/{blob_name}"
# return blob_url
# async def add_watermark(
# self,
# base_blob_name: str,
# base_image: bytes
# ) -> str:
# """
# Adds a watermark to the provided image and uploads the watermarked image.
# Args:
# base_blob_name (str): Original blob name of the image.
# base_image (bytes): Image data in bytes.
# Returns:
# str: URL of the watermarked image.
# """
# # Load the input image
# try:
# img = Image.open(BytesIO(base_image)).convert("RGBA")
# except Exception as e:
# raise Exception(f"Failed to load image: {e}")
# # Calculate position for the watermark (bottom right corner with padding)
# padding = 12
# x = img.width - WATERMARK_SIZE - padding
# y = img.height - WATERMARK_SIZE - padding
# # Analyze background brightness where the watermark will be placed
# background_brightness = self.analyze_background_brightness(img, x, y, WATERMARK_SIZE)
# self.logger.info(f"Background brightness: {background_brightness}")
# # Render SVG watermark to PNG bytes using cairosvg
# try:
# watermark_png_bytes = cairosvg.svg2png(bytestring=WATERMARK.encode('utf-8'), output_width=WATERMARK_SIZE, output_height=WATERMARK_SIZE)
# watermark = Image.open(BytesIO(watermark_png_bytes)).convert("RGBA")
# except Exception as e:
# raise Exception(f"Failed to render watermark SVG: {e}")
# # Determine watermark color based on background brightness
# if background_brightness > 128:
# # Dark watermark for light backgrounds
# watermark_color = (0, 0, 0, int(255 * 0.65)) # Black with 65% opacity
# else:
# # Light watermark for dark backgrounds
# watermark_color = (255, 255, 255, int(255 * 0.65)) # White with 65% opacity
# # Apply the watermark color by blending
# solid_color = Image.new("RGBA", watermark.size, watermark_color)
# watermark = Image.alpha_composite(watermark, solid_color)
# # Overlay the watermark onto the original image
# img.paste(watermark, (x, y), watermark)
# # Save the watermarked image to bytes
# buffer = BytesIO()
# img = img.convert("RGB") # Convert back to RGB for JPEG format
# img.save(buffer, format="JPEG")
# buffer.seek(0)
# jpeg_bytes = buffer.read()
# # Modify the blob name to include '_watermarked'
# try:
# if "memes/" in base_blob_name:
# base_blob_name_right = base_blob_name.split("memes/", 1)[1]
# else:
# base_blob_name_right = base_blob_name
# base_blob_name_split = base_blob_name_right.rsplit(".", 1)
# base_blob_name_without_extension = base_blob_name_split[0]
# extension = base_blob_name_split[1]
# except Exception as e:
# raise Exception(f"Failed to process blob name: {e}")
# watermarked_blob_name = f"{base_blob_name_without_extension}_watermarked.{extension}"
# # Upload the watermarked image
# try:
# watermarked_blob_url = await self.save_image(
# watermarked_blob_name,
# "image/jpeg",
# jpeg_bytes
# )
# return watermarked_blob_url
# except Exception as e:
# raise Exception(f"Failed to upload watermarked image: {e}")
# def analyze_background_brightness(
# self,
# img: Image.Image,
# x: int,
# y: int,
# size: int
# ) -> int:
# """
# Analyzes the brightness of a specific region in the image.
# Args:
# img (Image.Image): The image to analyze.
# x (int): X-coordinate of the top-left corner of the region.
# y (int): Y-coordinate of the top-left corner of the region.
# size (int): Size of the square region to analyze.
# Returns:
# int: Average brightness (0-255) of the region.
# """
# # Crop the specified region
# sub_image = img.crop((x, y, x + size, y + size)).convert("RGB")
# # Calculate average brightness using the luminance formula
# total_brightness = 0
# pixel_count = 0
# for pixel in sub_image.getdata():
# r, g, b = pixel
# brightness = (r * 299 + g * 587 + b * 114) // 1000
# total_brightness += brightness
# pixel_count += 1
# if pixel_count == 0:
# return 0
# average_brightness = total_brightness // pixel_count
# return average_brightness
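The `send_sync` / `publish_loop` pair above is the bridge between the executor thread and the worker's asyncio loop: the thread enqueues events with `call_soon_threadsafe`, and a coroutine drains the queue. A standalone sketch of that pattern (all names here are illustrative, not taken from the worker):
```python
# sketch: the thread -> asyncio bridge used by send_sync / publish_loop
import asyncio
import threading

async def main():
    loop = asyncio.get_running_loop()
    messages: asyncio.Queue = asyncio.Queue()

    def send_sync(event, data, sid=None):
        # safe to call from any thread; schedules the put on the event loop
        loop.call_soon_threadsafe(messages.put_nowait, (event, data, sid))

    async def publish_loop():
        while True:
            event, data, sid = await messages.get()
            print(f"publish {event} for {sid}: {data}")

    publisher = asyncio.create_task(publish_loop())

    # a worker thread emitting events, as the prompt executor thread does
    t = threading.Thread(target=lambda: send_sync("progress", {"value": 1, "max": 4}, "ws-1"))
    t.start()
    t.join()

    await asyncio.sleep(0.1)  # let the publisher drain the queue
    publisher.cancel()
    try:
        await publisher
    except asyncio.CancelledError:
        pass

asyncio.run(main())
```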

View File

@ -1,73 +0,0 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: "hidden"
layer_idx: -2

View File

@ -1,70 +0,0 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

View File

@ -1,73 +0,0 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: "hidden"
layer_idx: -2

View File

@ -1,74 +0,0 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_fp16: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: "hidden"
layer_idx: -2

View File

@ -1,71 +0,0 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_fp16: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

View File

@ -1,71 +0,0 @@
model:
base_learning_rate: 7.5e-05
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: hybrid # important
monitor: val/loss_simple_ema
scale_factor: 0.18215
finetune_keys: null
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

View File

@ -1,68 +0,0 @@
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View File

@ -1,68 +0,0 @@
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: False
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View File

@ -1,67 +0,0 @@
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View File

@ -1,67 +0,0 @@
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: False
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View File

@ -1,158 +0,0 @@
model:
base_learning_rate: 5.0e-05
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: hybrid
scale_factor: 0.18215
monitor: val/loss_simple_ema
finetune_keys: null
use_ema: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 9
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
data:
target: ldm.data.laion.WebDataModuleFromConfig
params:
tar_base: null # for concat as in LAION-A
p_unsafe_threshold: 0.1
filter_word_list: "data/filters.yaml"
max_pwatermark: 0.45
batch_size: 8
num_workers: 6
multinode: True
min_size: 512
train:
shards:
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
shuffle: 10000
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.RandomCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
# NOTE use enough shards to avoid empty validation loops in workers
validation:
shards:
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
shuffle: 0
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.CenterCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
lightning:
find_unused_parameters: True
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 10000
image_logger:
target: main.ImageLogger
params:
enable_autocast: False
disabled: False
batch_frequency: 1000
max_images: 4
increase_log_steps: False
log_first_step: False
log_images_kwargs:
use_ema_scope: False
inpaint: False
plot_progressive_rows: False
plot_diffusion_rows: False
N: 4
unconditional_guidance_scale: 5.0
unconditional_guidance_label: [""]
ddim_steps: 50 # todo check these out for depth2img,
ddim_eta: 0.0 # todo check these out for depth2img,
trainer:
benchmark: True
val_check_interval: 5000000
num_sanity_val_steps: 0
accumulate_grad_batches: 1

File diff suppressed because one or more lines are too long

View File

@ -22,3 +22,10 @@ kornia>=0.7.1
spandrel
soundfile
av
# memedeck dependencies
pika
python-dotenv
pillow
azure-storage-blob
cairosvg
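The dependencies added above (pika, python-dotenv, azure-storage-blob; pillow and cairosvg presumably cover image and SVG handling) suggest an AMQP-driven worker that reads settings from a `.env` file and uploads results to Azure Blob Storage. A hypothetical sketch of that wiring; the queue name, environment variable names, and container name are assumptions, not taken from this commit:

```python
import os

import pika
from dotenv import load_dotenv
from azure.storage.blob import BlobServiceClient

load_dotenv()  # python-dotenv: pull settings from .env

blob_service = BlobServiceClient.from_connection_string(
    os.environ["AZURE_STORAGE_CONNECTION_STRING"]
)

def on_message(channel, method, properties, body):
    # Treat the message body as the finished artifact and push it to blob storage.
    blob_client = blob_service.get_blob_client(container="generations", blob="result.png")
    blob_client.upload_blob(body, overwrite=True)

connection = pika.BlockingConnection(
    pika.ConnectionParameters(host=os.getenv("AMQP_HOST", "localhost"))
)
channel = connection.channel()
channel.queue_declare(queue="generation-jobs")
channel.basic_consume(queue="generation-jobs", on_message_callback=on_message, auto_ack=True)
channel.start_consuming()
```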

3 styles/default.csv Normal file
View File

@ -0,0 +1,3 @@
name,prompt,negative_prompt
❌Low Token,,"embedding:EasyNegative, NSFW, Cleavage, Pubic Hair, Nudity, Naked, censored"
✅Line Art / Manga,"(Anime Scene, Toonshading, Satoshi Kon, Ken Sugimori, Hiromu Arakawa:1.2), (Anime Style, Manga Style:1.3), Low detail, sketch, concept art, line art, webtoon, manhua, hand drawn, defined lines, simple shades, minimalistic, High contrast, Linear compositions, Scalable artwork, Digital art, High Contrast Shadows, glow effects, humorous illustration, big depth of field, Masterpiece, colors, concept art, trending on artstation, Vivid colors, dramatic",
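The new `styles/default.csv` maps a style name to a prompt and a negative prompt. A small sketch (assumed usage, not from this commit) of loading it with the standard csv module and composing a styled prompt:

```python
import csv

def load_styles(path="styles/default.csv"):
    # Build {style name -> {"name", "prompt", "negative_prompt"}} from the CSV.
    with open(path, newline="", encoding="utf-8") as f:
        return {row["name"]: row for row in csv.DictReader(f)}

styles = load_styles()
style = styles["✅Line Art / Manga"]
positive = ", ".join(filter(None, [style["prompt"], "a cat wearing sunglasses"]))
negative = style["negative_prompt"]
```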