mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-04-20 03:13:30 +00:00

commit f93fe6b3fc (parent f86c724ef2): added MemedeckWorker logic to the repo

.gitignore (vendored): 9 changes
@@ -21,3 +21,12 @@ venv/
 *.log
 web_custom_versions/
 .DS_Store
+
+.env
+
+
+models
+custom_nodes
+models/
+flux_loras/
+custom_nodes/
MEMEDECK.md (new file): 50 lines

@@ -0,0 +1,50 @@
# Commands to set up the environment.

## Enable MIGs on NVIDIA GPU

<!-- link to the guide -->
https://www.seimaxim.com/kb/gpu/nvidia-a100-mig-cheat-sheat

```zsh
sudo nvidia-smi -i 0 -mig 1
sudo nvidia-smi mig -cgi 9,9 -C
```
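Note: on an A100-40GB, GPU-instance profile 9 is the 3g.20gb slice, so `-cgi 9,9 -C` should split the card into two equal MIG instances and create their compute instances in one step. Profile IDs vary by GPU model; `sudo nvidia-smi mig -lgip` lists the ones your card supports, and enabling MIG mode may require the GPU to be idle (or a reboot) before it takes effect.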
Start Comfy on a MIG instance:

```zsh
CUDA_VISIBLE_DEVICES=MIG-0 python main.py --port 5000 --listen 0.0.0.0 --cuda-device 0 --preview-method auto
```
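The `MIG-0` value is a placeholder. Depending on driver version, `CUDA_VISIBLE_DEVICES` may need the full MIG device UUID (the `MIG-GPU-...` identifiers printed by `nvidia-smi -L`) rather than an index-style alias.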
## Tunnel a remote server port to a local machine port

### On the server

```zsh
sudo nano /etc/ssh/sshd_config

# uncomment this line in the sshd_config file
GatewayPorts yes
```

Then restart the machine (restarting the `sshd` service is usually enough for the change to take effect).

### On your local machine

```zsh
sudo ssh -i ~/.ssh/memedeck-monolith.pem -N -R 9090:localhost:8079 holium@172.206.15.40
```

### Install autossh (may not be needed)

```zsh
brew install autossh
autossh -M 0 -o "ServerAliveInterval 30" -o "ServerAliveCountMax 3" -i ~/.ssh/memedeck-monolith.pem -N -R 9090:localhost:8079 holium@172.206.15.40
```

<!-- sudo ssh -i ~/.ssh/memedeck-monolith.pem -N -R 9090:localhost:8079 holium@172.206.15.40 -->

### On the server

This allows the tunneled port to be reached from outside the server:

```zsh
sudo iptables -A INPUT -p tcp --dport 9090 -j ACCEPT
```
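To sanity-check the setup, `curl http://localhost:9090` on the server should reach whatever is listening on port 8079 of the local machine, and with the iptables rule in place the same port becomes reachable externally at `http://<server-ip>:9090`.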
@@ -213,10 +213,9 @@ def load_lora(lora, to_load, log_missing=True):
             patch_dict[to_load[x]] = ("set", (set_weight,))
             loaded_keys.add(set_weight_name)

     if log_missing:
         for x in lora.keys():
             if x not in loaded_keys:
                 logging.warning("lora key not loaded: {}".format(x))
     # for x in lora.keys():
     #     if x not in loaded_keys:
     #         logging.warning("lora key not loaded: {}".format(x))

     return patch_dict
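The only behavioral knob this hunk touches is the `log_missing` gate. A minimal usage sketch, assuming this is the `load_lora` helper in `comfy.lora` (the file header for this hunk was lost; the path and variables below are illustrative):

```python
import comfy.lora
import comfy.utils

# Hypothetical inputs: a LoRA state dict and the model-key map that
# load_lora's callers normally build before applying patches.
lora_sd = comfy.utils.load_torch_file("loras/example.safetensors", safe_load=True)
key_map = {}  # normally produced by comfy.lora.model_lora_keys_unet / _clip
patches = comfy.lora.load_lora(lora_sd, key_map, log_missing=False)
```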
Deleted file (155 lines):

@@ -1,155 +0,0 @@
class Example:
    """
    A example node

    Class methods
    -------------
    INPUT_TYPES (dict):
        Tell the main program input parameters of nodes.
    IS_CHANGED:
        optional method to control when the node is re executed.

    Attributes
    ----------
    RETURN_TYPES (`tuple`):
        The type of each element in the output tuple.
    RETURN_NAMES (`tuple`):
        Optional: The name of each output in the output tuple.
    FUNCTION (`str`):
        The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
    OUTPUT_NODE ([`bool`]):
        If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example.
        The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected.
        Assumed to be False if not present.
    CATEGORY (`str`):
        The category the node should appear in the UI.
    DEPRECATED (`bool`):
        Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
        functional in existing workflows that use them.
    EXPERIMENTAL (`bool`):
        Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
        significant changes or removal in future versions. Use with caution in production workflows.
    execute(s) -> tuple || None:
        The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
        For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
    """
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """
        Return a dictionary which contains config for all input fields.
        Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
        Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
        The type can be a list for selection.

        Returns: `dict`:
            - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
            - Value input_fields (`dict`): Contains input fields config:
                * Key field_name (`string`): Name of a entry-point method's argument
                * Value field_config (`tuple`):
                    + First value is a string indicate the type of field or a list for selection.
                    + Second value is a config for type "INT", "STRING" or "FLOAT".
        """
        return {
            "required": {
                "image": ("IMAGE",),
                "int_field": ("INT", {
                    "default": 0,
                    "min": 0, #Minimum value
                    "max": 4096, #Maximum value
                    "step": 64, #Slider's step
                    "display": "number", # Cosmetic only: display as "number" or "slider"
                    "lazy": True # Will only be evaluated if check_lazy_status requires it
                }),
                "float_field": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.0,
                    "max": 10.0,
                    "step": 0.01,
                    "round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
                    "display": "number",
                    "lazy": True
                }),
                "print_to_screen": (["enable", "disable"],),
                "string_field": ("STRING", {
                    "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
                    "default": "Hello World!",
                    "lazy": True
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    #RETURN_NAMES = ("image_output_name",)

    FUNCTION = "test"

    #OUTPUT_NODE = False

    CATEGORY = "Example"

    def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
        """
        Return a list of input names that need to be evaluated.

        This function will be called if there are any lazy inputs which have not yet been
        evaluated. As long as you return at least one field which has not yet been evaluated
        (and more exist), this function will be called again once the value of the requested
        field is available.

        Any evaluated inputs will be passed as arguments to this function. Any unevaluated
        inputs will have the value None.
        """
        if print_to_screen == "enable":
            return ["int_field", "float_field", "string_field"]
        else:
            return []

    def test(self, image, string_field, int_field, float_field, print_to_screen):
        if print_to_screen == "enable":
            print(f"""Your input contains:
                string_field aka input text: {string_field}
                int_field: {int_field}
                float_field: {float_field}
            """)
        #do some processing on the image, in this example I just invert it
        image = 1.0 - image
        return (image,)

    """
        The node will always be re executed if any of the inputs change but
        this method can be used to force the node to execute again even when the inputs don't change.
        You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
        executed, if it is different the node will be executed again.
        This method is used in the core repo for the LoadImage node where they return the image hash as a string, if the image hash
        changes between executions the LoadImage node is executed again.
    """
    #@classmethod
    #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
    #    return ""

# Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension
# WEB_DIRECTORY = "./somejs"


# Add custom API routes, using router
from aiohttp import web
from server import PromptServer

@PromptServer.instance.routes.get("/hello")
async def get_hello(request):
    return web.json_response("hello")


# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "Example": Example
}

# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "Example": "Example Node"
}
Deleted file (44 lines):

@@ -1,44 +0,0 @@
from PIL import Image
import numpy as np
import comfy.utils
import time

#You can use this node to save full size images through the websocket, the
#images will be sent in exactly the same format as the image previews: as
#binary images on the websocket with a 8 byte header indicating the type
#of binary message (first 4 bytes) and the image format (next 4 bytes).

#Note that no metadata will be put in the images saved with this node.

class SaveImageWebsocket:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"images": ("IMAGE", ),}
                }

    RETURN_TYPES = ()
    FUNCTION = "save_images"

    OUTPUT_NODE = True

    CATEGORY = "api/image"

    def save_images(self, images):
        pbar = comfy.utils.ProgressBar(images.shape[0])
        step = 0
        for image in images:
            i = 255. * image.cpu().numpy()
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
            pbar.update_absolute(step, images.shape[0], ("PNG", img, None))
            step += 1

        return {}

    @classmethod
    def IS_CHANGED(s, images):
        return time.time()

NODE_CLASS_MAPPINGS = {
    "SaveImageWebsocket": SaveImageWebsocket,
}
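The comments above pin down the wire format for these frames. A minimal client-side decoder sketch, assuming the two header fields are packed as big-endian 4-byte unsigned ints (the constant meanings, e.g. 1 = PREVIEW_IMAGE as in the MemedeckWorker.BinaryEventTypes class below, and the format codes 1 = JPEG, 2 = PNG, are this codebase's convention, not a standard):

```python
import struct

def parse_binary_message(payload: bytes):
    # First 4 bytes: binary event type (1 = PREVIEW_IMAGE in this codebase);
    # next 4 bytes: image format code (assumed: 1 = JPEG, 2 = PNG).
    event_type, image_format = struct.unpack(">II", payload[:8])
    return event_type, image_format, payload[8:]
```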
execution.py: 46 changes

@@ -245,7 +245,7 @@ def format_value(x):
     else:
         return str(x)

-def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
+def execute(server, memedeck_worker, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)

@@ -253,10 +253,15 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+    is_memedeck = extra_data.get("is_memedeck", False)

     if caches.outputs.get(unique_id) is not None:
         if server.client_id is not None:
             cached_output = caches.ui.get(unique_id) or {}
             server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
+        if memedeck_worker.ws_id is not None:
+            memedeck_worker.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, memedeck_worker.ws_id)
+        # elif is_memedeck:
         return (ExecutionResult.SUCCESS, None, None)

     input_data_all = None

@@ -284,10 +289,13 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
             has_subgraph = False
         else:
             input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
-            if server.client_id is not None:
+            if server.client_id is not None and not is_memedeck:
                 server.last_node_id = display_node_id
                 server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
+            elif is_memedeck:
+                memedeck_worker.last_node_id = display_node_id
+                memedeck_worker.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, memedeck_worker.ws_id)

             obj = caches.objects.get(unique_id)
             if obj is None:
                 obj = class_def()

@@ -319,6 +327,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
                     "current_outputs": [],
                 }
                 server.send_sync("execution_error", mes, server.client_id)
+                memedeck_worker.send_sync("execution_error", mes, memedeck_worker.ws_id)
                 return ExecutionBlocker(None)
             else:
                 return block

@@ -335,8 +344,11 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
             },
             "output": output_ui
         })
-        if server.client_id is not None:
+        if server.client_id is not None and not is_memedeck:
             server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
+        elif is_memedeck:
+            memedeck_worker.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, memedeck_worker.ws_id)

         if has_subgraph:
             cached_outputs = []
             new_node_ids = []

@@ -414,9 +426,12 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
     return (ExecutionResult.SUCCESS, None, None)

 class PromptExecutor:
-    def __init__(self, server, lru_size=None):
+    def __init__(self, server, memedeck_worker, lru_size=None):
         self.lru_size = lru_size
         self.server = server
+        self.memedeck_worker = memedeck_worker
+        # set logging level
+        logging.basicConfig(level=logging.INFO)
         self.reset()

     def reset(self):

@@ -432,6 +447,9 @@ class PromptExecutor:
         self.status_messages.append((event, data))
         if self.server.client_id is not None or broadcast:
             self.server.send_sync(event, data, self.server.client_id)
+        elif self.memedeck_worker.ws_id is not None:
+            self.memedeck_worker.send_sync(event, data, self.memedeck_worker.ws_id)

     def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
         node_id = error["node_id"]

@@ -466,8 +484,12 @@ class PromptExecutor:

         if "client_id" in extra_data:
             self.server.client_id = extra_data["client_id"]
+        if "ws_id" in extra_data:
+            self.memedeck_worker.ws_id = extra_data["ws_id"]
+            self.memedeck_worker.websocket_node_id = extra_data["websocket_node_id"]
         else:
             self.server.client_id = None
+            self.memedeck_worker.ws_id = None

         self.status_messages = []
         self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)

@@ -501,7 +523,7 @@ class PromptExecutor:
                     self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
                     break

-                result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
+                result, error, ex = execute(self.server, self.memedeck_worker, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
                 self.success = result != ExecutionResult.FAILURE
                 if result == ExecutionResult.FAILURE:
                     self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)

@@ -512,7 +534,7 @@ class PromptExecutor:
                     execution_list.complete_node_execution()
             else:
                 # Only execute when the while-loop ends without break
-                self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+                self.add_message("execution_success", { "prompt_id": prompt_id, 'ws_id': self.memedeck_worker.ws_id }, broadcast=False)

             ui_outputs = {}
             meta_outputs = {}

@@ -527,6 +549,7 @@ class PromptExecutor:
                 "meta": meta_outputs,
             }
             self.server.last_node_id = None
+            self.memedeck_worker.last_node_id = None
             if comfy.model_management.DISABLE_SMART_MEMORY:
                 comfy.model_management.unload_all_models()

@@ -869,8 +892,9 @@ def validate_prompt(prompt):
 MAXIMUM_HISTORY_SIZE = 10000

 class PromptQueue:
-    def __init__(self, server):
+    def __init__(self, server, memedeck_worker):
         self.server = server
+        self.memedeck_worker = memedeck_worker
         self.mutex = threading.RLock()
         self.not_empty = threading.Condition(self.mutex)
         self.task_counter = 0

@@ -884,6 +908,7 @@ class PromptQueue:
         with self.mutex:
             heapq.heappush(self.queue, item)
             self.server.queue_updated()
+            self.memedeck_worker.queue_updated()
             self.not_empty.notify()

     def get(self, timeout=None):

@@ -897,6 +922,7 @@ class PromptQueue:
             self.currently_running[i] = copy.deepcopy(item)
             self.task_counter += 1
             self.server.queue_updated()
+            self.memedeck_worker.queue_updated()
             return (item, i)

     class ExecutionStatus(NamedTuple):

@@ -922,6 +948,7 @@ class PromptQueue:
                 }
                 self.history[prompt[1]].update(history_result)
                 self.server.queue_updated()
+                self.memedeck_worker.queue_updated()

     def get_current_queue(self):
         with self.mutex:

@@ -938,6 +965,7 @@ class PromptQueue:
         with self.mutex:
             self.queue = []
             self.server.queue_updated()
+            self.memedeck_worker.queue_updated()

     def delete_queue_item(self, function):
         with self.mutex:

@@ -949,6 +977,8 @@ class PromptQueue:
                 self.queue.pop(x)
                 heapq.heapify(self.queue)
                 self.server.queue_updated()
+                self.memedeck_worker.queue_updated()

                 return True
         return False
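Taken together, these hunks pin down the entire surface execution.py expects from the new worker. A minimal stand-in sketch (attribute and method names are exactly those used in the diff; the bodies are illustrative):

```python
class MemedeckWorkerStub:
    """Smallest object satisfying what execution.py calls on memedeck_worker."""
    def __init__(self):
        self.ws_id = None             # set from extra_data["ws_id"] per prompt
        self.websocket_node_id = None # set from extra_data["websocket_node_id"]
        self.last_node_id = None
        self.last_prompt_id = None

    def send_sync(self, event, data, sid=None):
        pass  # relays "executing"/"executed"/"execution_error" events to the backend

    def queue_updated(self):
        pass  # called on every PromptQueue mutation
```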
main.py: 90 changes

@@ -53,6 +53,14 @@ def apply_custom_paths():
     logging.info(f"Setting user directory to: {user_dir}")
     folder_paths.set_user_directory(user_dir)

+# ---------------------------------------------------------------------------------------
+# memedeck imports
+# ---------------------------------------------------------------------------------------
+from memedeck import MemedeckWorker
+
+import sys
+sys.stdout = open(os.devnull, 'w') # disable all print statements
+# ---------------------------------------------------------------------------------------

 def execute_prestartup_script():
     def execute_script(script_path):

@@ -153,9 +161,11 @@ def cuda_malloc_warning():
         logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")


-def prompt_worker(q, server_instance):
+def prompt_worker(q, server, memedeck_worker):
     current_time: float = 0.0
-    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
+    e = execution.PromptExecutor(server, memedeck_worker, lru_size=args.cache_lru)

+    # threading.Thread(target=memedeck_worker.start, daemon=True, args=(q, execution.validate_prompt)).start()
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0

@@ -209,12 +219,12 @@ def prompt_worker(q, server_instance):
         need_gc = False


-async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
+async def run(server_instance, memedeck_worker, address='', port=8188, verbose=True, call_on_start=None):
     addresses = []
     for addr in address.split(","):
         addresses.append((addr, port))
     await asyncio.gather(
-        server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
+        server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop(), memedeck_worker.publish_loop()
     )

@@ -224,11 +234,34 @@ def hijack_progress(server_instance):
         progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}

         server_instance.send_sync("progress", progress, server_instance.client_id)
+        if memedeck_worker.ws_id is not None:
+            memedeck_worker.send_sync("progress", {"value": value, "max": total, "prompt_id": memedeck_worker.last_prompt_id, "node": server_instance.last_node_id}, memedeck_worker.ws_id)
         if preview_image is not None:
             server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
+            if memedeck_worker.ws_id is not None:
+                memedeck_worker.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, memedeck_worker.ws_id)
     comfy.utils.set_progress_bar_global_hook(hook)

+# async def run(server, memedeck_worker, address='', port=8188, verbose=True, call_on_start=None):
+#     addresses = []
+#     for addr in address.split(","):
+#         addresses.append((addr, port))
+#     # add memedeck worker publish loop
+#     await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop(), memedeck_worker.publish_loop())
+
+
+# def hijack_progress(server, memedeck_worker):
+#     def hook(value, total, preview_image):
+#         comfy.model_management.throw_exception_if_processing_interrupted()
+#         server.send_sync("progress", {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}, server.client_id)
+#         if memedeck_worker.ws_id is not None:
+#             memedeck_worker.send_sync("progress", {"value": value, "max": total, "prompt_id": memedeck_worker.last_prompt_id, "node": server.last_node_id}, memedeck_worker.ws_id)
+#         if preview_image is not None:
+#             server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
+#             if memedeck_worker.ws_id is not None:
+#                 memedeck_worker.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, memedeck_worker.ws_id)
+#     comfy.utils.set_progress_bar_global_hook(hook)


 def cleanup_temp():
     temp_dir = folder_paths.get_temp_directory()

@@ -258,16 +291,54 @@ def start_comfyui(asyncio_loop=None):
         asyncio_loop = asyncio.new_event_loop()
         asyncio.set_event_loop(asyncio_loop)
     prompt_server = server.PromptServer(asyncio_loop)
-    q = execution.PromptQueue(prompt_server)
+    memedeck_worker = MemedeckWorker(asyncio_loop)
+    q = execution.PromptQueue(prompt_server, memedeck_worker)

+    # loop = asyncio.new_event_loop()
+    # asyncio.set_event_loop(loop)
+    # server = server.PromptServer(loop)
+    # q = execution.PromptQueue(server, memedeck_worker)

     extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
     if os.path.isfile(extra_model_paths_config_path):
         utils.extra_config.load_extra_path_config(extra_model_paths_config_path)

     if args.extra_model_paths_config:
         for config_path in itertools.chain(*args.extra_model_paths_config):
             utils.extra_config.load_extra_path_config(config_path)

     nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)

     cuda_malloc_warning()

     prompt_server.add_routes()
-    hijack_progress(prompt_server)
+    hijack_progress(server, memedeck_worker)

-    threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
+    threading.Thread(target=prompt_worker, daemon=True, args=(q, server, memedeck_worker)).start()
+    threading.Thread(target=memedeck_worker.start, daemon=True, args=(q, execution.validate_prompt)).start()
+    # set logging level to info

     if args.output_directory:
         output_dir = os.path.abspath(args.output_directory)
         logging.info(f"Setting output directory to: {output_dir}")
         folder_paths.set_output_directory(output_dir)

     #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
     folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
     folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
     folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
     folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
     folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))

     if args.input_directory:
         input_dir = os.path.abspath(args.input_directory)
         logging.info(f"Setting input directory to: {input_dir}")
         folder_paths.set_input_directory(input_dir)

     if args.user_directory:
         user_dir = os.path.abspath(args.user_directory)
         logging.info(f"Setting user directory to: {user_dir}")
         folder_paths.set_user_directory(user_dir)

     if args.quick_test_for_ci:
         exit(0)

@@ -287,6 +358,7 @@ def start_comfyui(asyncio_loop=None):
     async def start_all():
         await prompt_server.setup()
-        await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
+        await run(server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)

     # Returning these so that other code can integrate with the ComfyUI loop and server
     return asyncio_loop, prompt_server, start_all

@@ -298,6 +370,8 @@ if __name__ == "__main__":
     event_loop, _, start_all_func = start_comfyui()
     try:
         event_loop.run_until_complete(start_all_func())
+        # event_loop.run_until_complete(run(server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))

     except KeyboardInterrupt:
         logging.info("\nStopped server")
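The worker introduced below is configured entirely through environment variables loaded with python-dotenv, which is why the .gitignore hunk starts excluding `.env`. A sketch of the variables it reads (names and fallbacks from memedeck.py; the example values are placeholders):

```python
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory

amqp_addr = os.getenv('AMQP_ADDR')    # RabbitMQ URL, e.g. amqp://user:pass@host:5672/vhost
queue_name = os.getenv('QUEUE_NAME')  # falls back to 'generic-queue'
api_url = os.getenv('API_ADDRESS')    # Memedeck backend base URL, e.g. http://0.0.0.0:8079/v2
api_key = os.getenv('API_KEY')        # key for that backend
```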
memedeck-v1.py (new file): 266 lines

@@ -0,0 +1,266 @@
import asyncio
import base64
from io import BytesIO
import os
import logging
import signal
import struct
from typing import Optional
import uuid
from PIL import Image, ImageOps
from functools import partial

import pika
import json

import requests
import aiohttp

# load from .env file
from dotenv import load_dotenv
load_dotenv()

amqp_addr = os.getenv('AMQP_ADDR') or 'amqp://api:gacdownatravKekmy9@51.8.120.154:5672/dev'

# define the enum in python
from enum import Enum

class QueueProgressKind(Enum):
    # make json serializable
    ImageGenerated = "image_generated"
    ImageGenerating = "image_generating"
    SamePrompt = "same_prompt"
    FaceswapGenerated = "faceswap_generated"
    FaceswapGenerating = "faceswap_generating"
    Failed = "failed"

class MemedeckWorker:
    class BinaryEventTypes:
        PREVIEW_IMAGE = 1
        UNENCODED_PREVIEW_IMAGE = 2

    class JsonEventTypes(Enum):
        PROGRESS = "progress"
        EXECUTING = "executing"
        EXECUTED = "executed"
        ERROR = "error"
        STATUS = "status"

    """
    MemedeckWorker is a class that is responsible for relaying messages between comfy and the memedeck backend api
    it is used to send images to the memedeck backend api and to receive prompts from the memedeck backend api
    """
    def __init__(self, loop):
        MemedeckWorker.instance = self
        # set logging level to info
        logging.getLogger().setLevel(logging.INFO)
        self.active_tasks_map = {}
        self.current_task = None

        self.client_id = None
        self.ws_id = None
        self.websocket_node_id = None
        self.current_node = None
        self.current_progress = 0
        self.current_context = None

        self.loop = loop
        self.messages = asyncio.Queue()

        self.http_client = None
        self.prompt_queue = None
        self.validate_prompt = None
        self.last_prompt_id = None

        self.amqp_url = amqp_addr
        self.queue_name = os.getenv('QUEUE_NAME') or 'generic-queue'
        self.api_url = os.getenv('API_ADDRESS') or 'http://0.0.0.0:8079/v2'
        self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383'

        # Configure logging
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)
        self.logger.info(f"\n[memedeck]: initialized with API URL: {self.api_url} and API Key: {self.api_key}\n")

    def on_connection_open(self, connection):
        self.connection = connection
        self.connection.channel(on_open_callback=self.on_channel_open)

    def on_channel_open(self, channel):
        self.channel = channel

        # only consume one message at a time
        self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
        self.channel.queue_declare(queue=self.queue_name, durable=True)
        self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received)

    def start(self, prompt_queue, validate_prompt):
        self.prompt_queue = prompt_queue
        self.validate_prompt = validate_prompt

        parameters = pika.URLParameters(self.amqp_url)
        logging.getLogger('pika').setLevel(logging.WARNING) # suppress all logs from pika
        self.connection = pika.SelectConnection(parameters, on_open_callback=self.on_connection_open)

        try:
            self.connection.ioloop.start()
        except KeyboardInterrupt:
            self.connection.close()
            self.connection.ioloop.start()

    def on_message_received(self, channel, method, properties, body):
        decoded_string = body.decode('utf-8')
        json_object = json.loads(decoded_string)
        payload = json_object[1]

        # execute the task
        prompt = payload["nodes"]
        valid = self.validate_prompt(prompt)

        self.current_node = None
        self.current_progress = 0
        self.websocket_node_id = None
        self.ws_id = payload["source_ws_id"]
        self.current_context = payload["req_ctx"]

        for node in prompt: # search through prompt nodes for websocket_node_id
            if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
                self.websocket_node_id = node
                break

        if valid[0]:
            prompt_id = str(uuid.uuid4())
            outputs_to_execute = valid[2]
            self.active_tasks_map[payload["source_ws_id"]] = {
                "prompt_id": prompt_id,
                "prompt": prompt,
                "outputs_to_execute": outputs_to_execute,
                "client_id": "memedeck-1",
                "is_memedeck": True,
                "websocket_node_id": self.websocket_node_id,
                "ws_id": payload["source_ws_id"],
                "context": payload["req_ctx"],
                "current_node": None,
                "current_progress": 0,
            }
            self.prompt_queue.put((0, prompt_id, prompt, {
                "client_id": "memedeck-1",
                'is_memedeck': True,
                'websocket_node_id': self.websocket_node_id,
                'ws_id': payload["source_ws_id"],
                'context': payload["req_ctx"]
            }, outputs_to_execute))
            self.set_last_prompt_id(prompt_id)
            channel.basic_ack(delivery_tag=method.delivery_tag) # ack the task
        else:
            channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message

    # --------------------------------------------------
    # callbacks for the prompt queue
    # --------------------------------------------------
    def queue_updated(self):
        # print json of the queue info but only print the first 100 lines
        info = self.get_queue_info()
        # update_type = info['']
        # self.send_sync("status", { "status": self.get_queue_info() })

    def get_queue_info(self):
        prompt_info = {}
        exec_info = {}
        exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
        prompt_info['exec_info'] = exec_info
        return prompt_info

    def send_sync(self, event, data, sid=None):
        self.loop.call_soon_threadsafe(
            self.messages.put_nowait, (event, data, sid))

    def set_last_prompt_id(self, prompt_id):
        self.last_prompt_id = prompt_id

    async def publish_loop(self):
        while True:
            msg = await self.messages.get()
            await self.send(*msg)

    async def send(self, event, data, sid=None):
        current_task = self.active_tasks_map.get(sid)
        if current_task is None or current_task['ws_id'] != sid:
            return

        if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: # preview and unencoded images are sent here
            self.logger.info(f"[memedeck]: sending image preview for {sid}")
            await self.send_preview(data, sid=current_task['ws_id'], progress=current_task['current_progress'], context=current_task['context'])
        else: # send json data / text data
            if event == "executing":
                current_task['current_node'] = data['node']
            elif event == "executed":
                self.logger.info(f"---> [memedeck]: executed event for {sid}")
                prompt_id = data['prompt_id']
                if prompt_id in self.active_tasks_map:
                    del self.active_tasks_map[prompt_id]
            elif event == "progress":
                if current_task['current_node'] == current_task['websocket_node_id']: # if the node is the websocket node, then set the progress to 100
                    current_task['current_progress'] = 100
                else: # if the node is not the websocket node, then set the progress to the progress from the node
                    current_task['current_progress'] = data['value'] / data['max'] * 100
                if current_task['current_progress'] == 100 and current_task['current_node'] != current_task['websocket_node_id']:
                    # in case the progress is 100 but the node is not the websocket node, then set the progress to 95
                    current_task['current_progress'] = 95 # this allows the full resolution image to be sent on the 100 progress event

                if data['value'] == 1: # if the value is 1, then send started to api
                    start_data = {
                        "ws_id": current_task['ws_id'],
                        "status": "started",
                        "info": None,
                    }
                    await self.send_to_api(start_data)

            elif event == "status":
                self.logger.info(f"[memedeck]: sending status event: {data}")

            self.active_tasks_map[sid] = current_task

    async def send_preview(self, image_data, sid=None, progress=None, context=None):
        # if self.current_progress is odd, then don't send the preview
        if progress % 2 == 1:
            return

        image_type = image_data[0]
        image = image_data[1]
        max_size = image_data[2]
        if max_size is not None:
            if hasattr(Image, 'Resampling'):
                resampling = Image.Resampling.BILINEAR
            else:
                resampling = Image.ANTIALIAS

            image = ImageOps.contain(image, (max_size, max_size), resampling)

        bytesIO = BytesIO()
        image.save(bytesIO, format=image_type, quality=100 if progress == 96 else 75, compress_level=1)
        preview_bytes = bytesIO.getvalue()

        ai_queue_progress = {
            "ws_id": sid,
            "kind": "image_generating" if progress < 100 else "image_generated",
            "data": list(preview_bytes),
            "progress": int(progress),
            "context": context
        }

        await self.send_to_api(ai_queue_progress)

    async def send_to_api(self, data):
        if self.websocket_node_id is None: # check if the node is still running
            logging.error(f"[memedeck]: websocket_node_id is None for {data['ws_id']}")
            return

        try:
            post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data)
            await self.loop.run_in_executor(None, post_func)
        except Exception as e:
            self.logger.error(f"[memedeck]: error sending to api: {e}")
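on_message_received above defines the message contract: the AMQP body is a JSON array whose second element carries the job. A minimal publisher sketch for exercising the worker (field names come from the code; the broker URL, queue name, and values are placeholders):

```python
import json
import pika

job = [None, {                  # the worker reads json.loads(body)[1]
    "nodes": {},                # ComfyUI prompt graph, keyed by node id
    "source_ws_id": "ws-123",   # identifies the requesting client/task
    "req_ctx": {},              # opaque context echoed back to the API
}]

conn = pika.BlockingConnection(pika.URLParameters("amqp://guest:guest@localhost:5672/%2F"))
channel = conn.channel()
channel.queue_declare(queue="generic-queue", durable=True)
channel.basic_publish(exchange="", routing_key="generic-queue", body=json.dumps(job))
conn.close()
```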
memedeck.py (new file): 568 lines

@@ -0,0 +1,568 @@
import asyncio
import base64
from io import BytesIO
import os
import logging
import signal
import struct
import time
from typing import Optional
import uuid
from PIL import Image, ImageOps
from functools import partial

import pika
import json

import requests
import aiohttp

from dotenv import load_dotenv
load_dotenv()

amqp_addr = os.getenv('AMQP_ADDR') or 'amqp://api:gacdownatravKekmy9@51.8.120.154:5672/dev'

from enum import Enum

class QueueProgressKind(Enum):
    ImageGenerated = "image_generated"
    ImageGenerating = "image_generating"
    SamePrompt = "same_prompt"
    FaceswapGenerated = "faceswap_generated"
    FaceswapGenerating = "faceswap_generating"
    Failed = "failed"

class MemedeckWorker:
    class BinaryEventTypes:
        PREVIEW_IMAGE = 1
        UNENCODED_PREVIEW_IMAGE = 2

    class JsonEventTypes(Enum):
        PROGRESS = "progress"
        EXECUTING = "executing"
        EXECUTED = "executed"
        ERROR = "error"
        STATUS = "status"

    """
    MemedeckWorker is a class that is responsible for relaying messages between comfy and the memedeck backend api
    it is used to send images to the memedeck backend api and to receive prompts from the memedeck backend api
    """
    def __init__(self, loop):
        MemedeckWorker.instance = self
        logging.getLogger().setLevel(logging.INFO)

        self.loop = loop
        self.messages = asyncio.Queue()

        self.http_client = None
        self.prompt_queue = None
        self.validate_prompt = None
        self.last_prompt_id = None

        self.amqp_url = amqp_addr
        self.queue_name = os.getenv('QUEUE_NAME') or 'generic-queue'
        self.api_url = os.getenv('API_ADDRESS') or 'http://0.0.0.0:8079/v2'
        self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383'

        # Internal job queue
        self.internal_job_queue = asyncio.Queue()

        # Dictionary to keep track of tasks by ws_id
        self.tasks_by_ws_id = {}

        # Configure logging
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)
        self.logger.info(f"\n[memedeck]: initialized with API URL: {self.api_url} and API Key: {self.api_key}\n")

    def on_connection_open(self, connection):
        self.connection = connection
        self.connection.channel(on_open_callback=self.on_channel_open)

    def on_channel_open(self, channel):
        self.channel = channel

        # only consume one message at a time
        self.channel.basic_qos(prefetch_size=0, prefetch_count=1)
        self.channel.queue_declare(queue=self.queue_name, durable=True)
        self.channel.basic_consume(queue=self.queue_name, on_message_callback=self.on_message_received, auto_ack=False)

    def start(self, prompt_queue, validate_prompt):
        self.prompt_queue = prompt_queue
        self.validate_prompt = validate_prompt

        # Start the process_job_queue task **after** prompt_queue is set
        self.loop.create_task(self.process_job_queue())

        parameters = pika.URLParameters(self.amqp_url)
        logging.getLogger('pika').setLevel(logging.WARNING) # suppress all logs from pika
        self.connection = pika.SelectConnection(parameters, on_open_callback=self.on_connection_open)

        try:
            self.connection.ioloop.start()
        except KeyboardInterrupt:
            self.connection.close()
            self.connection.ioloop.start()

    def on_message_received(self, channel, method, properties, body):
        decoded_string = body.decode('utf-8')
        json_object = json.loads(decoded_string)
        payload = json_object[1]

        # execute the task
        prompt = payload["nodes"]
        valid = self.validate_prompt(prompt)

        # Prepare task_info
        prompt_id = str(uuid.uuid4())
        outputs_to_execute = valid[2]
        task_info = {
            "prompt_id": prompt_id,
            "prompt": prompt,
            "outputs_to_execute": outputs_to_execute,
            "client_id": "memedeck-1",
            "is_memedeck": True,
            "websocket_node_id": None,
            "ws_id": payload["source_ws_id"],
            "context": payload["req_ctx"],
            "current_node": None,
            "current_progress": 0,
            "delivery_tag": method.delivery_tag,
            "task_status": "waiting",
        }

        # Find the websocket_node_id
        for node in prompt:
            if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
                task_info['websocket_node_id'] = node
                break

        if valid[0]:
            # Enqueue the task into the internal job queue
            self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info))
            self.logger.info(f"[memedeck]: Enqueued task for {task_info['ws_id']}")
        else:
            channel.basic_nack(delivery_tag=method.delivery_tag, requeue=False) # unack the message

    async def process_job_queue(self):
        while True:
            prompt_id, prompt, task_info = await self.internal_job_queue.get()
            # Start a new coroutine for each task
            self.loop.create_task(self.process_task(prompt_id, prompt, task_info))

    async def process_task(self, prompt_id, prompt, task_info):
        ws_id = task_info['ws_id']
        # Add the task to tasks_by_ws_id
        self.tasks_by_ws_id[ws_id] = task_info
        # Put the prompt into the prompt_queue
        self.prompt_queue.put((0, prompt_id, prompt, {
            "client_id": task_info["client_id"],
            'is_memedeck': task_info['is_memedeck'],
            'websocket_node_id': task_info['websocket_node_id'],
            'ws_id': task_info['ws_id'],
            'context': task_info['context']
        }, task_info['outputs_to_execute']))
        # Acknowledge the message
        self.channel.basic_ack(delivery_tag=task_info["delivery_tag"]) # ack the task
        self.logger.info(f"[memedeck]: Acked task {prompt_id} {ws_id}")

        self.logger.info(f"[memedeck]: Started processing prompt {prompt_id}")
        # Wait until the current task is completed
        await self.wait_for_task_completion(ws_id)
        # Task is done
        self.internal_job_queue.task_done()

    async def wait_for_task_completion(self, ws_id):
        """
        Wait until the task with the given ws_id is completed.
        """
        while ws_id in self.tasks_by_ws_id:
            await asyncio.sleep(0.5)

    # --------------------------------------------------
    # callbacks for the prompt queue
    # --------------------------------------------------
    def queue_updated(self):
        info = self.get_queue_info()
        # self.send_sync("status", { "status": self.get_queue_info() })

    def get_queue_info(self):
        prompt_info = {}
        exec_info = {}
        exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
        prompt_info['exec_info'] = exec_info
        return prompt_info

    def send_sync(self, event, data, sid=None):
        self.loop.call_soon_threadsafe(
            self.messages.put_nowait, (event, data, sid))

    async def publish_loop(self):
        while True:
            msg = await self.messages.get()
            await self.send(*msg)

    async def send(self, event, data, sid=None):
        if sid is None:
            self.logger.warning("Received event without sid")
            return

        # Retrieve the task based on sid
        task = self.tasks_by_ws_id.get(sid)
        if not task:
            self.logger.warning(f"Received event {event} for unknown sid: {sid}")
            return

        if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
            await self.send_preview(
                data,
                sid=sid,
                progress=task['current_progress'],
                context=task['context']
            )
        else:
            # Send JSON data / text data
            if event == "executing":
                task['current_node'] = data['node']
                task["task_status"] = "executing"
            elif event == "progress":
                if task['current_node'] == task['websocket_node_id']:
                    # If the node is the websocket node, then set the progress to 100
                    task['current_progress'] = 100
                else:
                    # If the node is not the websocket node, then set the progress based on the node's progress
                    task['current_progress'] = (data['value'] / data['max']) * 100
                if task['current_progress'] == 100 and task['current_node'] != task['websocket_node_id']:
                    # In case the progress is 100 but the node is not the websocket node, set progress to 95
                    task['current_progress'] = 95 # Allows the full resolution image to be sent on the 100 progress event

                if data['value'] == 1:
                    # If the value is 1, send started to API
                    start_data = {
                        "ws_id": task['ws_id'],
                        "status": "started",
                        "info": None,
                    }
                    task["task_status"] = "executing"
                    await self.send_to_api(start_data)

            elif event == "status":
                self.logger.info(f"[memedeck]: sending status event: {data}")

            # Update the task in tasks_by_ws_id
            self.tasks_by_ws_id[sid] = task

    async def send_preview(self, image_data, sid=None, progress=None, context=None):
        if sid is None:
            self.logger.warning("Received preview without sid")
            return

        task = self.tasks_by_ws_id.get(sid)
        if not task:
            self.logger.warning(f"Received preview for unknown sid: {sid}")
            return

        if progress is None:
            progress = task['current_progress']

        # if progress is odd, then don't send the preview
        if int(progress) % 2 == 1:
            return

        image_type = image_data[0]
        image = image_data[1]
        max_size = image_data[2]
        if max_size is not None:
            if hasattr(Image, 'Resampling'):
                resampling = Image.Resampling.BILINEAR
            else:
                resampling = Image.ANTIALIAS

            image = ImageOps.contain(image, (max_size, max_size), resampling)

        bytesIO = BytesIO()
        image.save(bytesIO, format=image_type, quality=100 if progress == 95 else 75, compress_level=1)
        preview_bytes = bytesIO.getvalue()

        ai_queue_progress = {
            "ws_id": sid,
            "kind": "image_generating" if progress < 100 else "image_generated",
            "data": list(preview_bytes),
            "progress": int(progress),
            "context": context
        }

        await self.send_to_api(ai_queue_progress)

        if progress == 100:
            del self.tasks_by_ws_id[sid] # Remove the task from tasks_by_ws_id
            self.logger.info(f"[memedeck]: Task {sid} completed")

    async def send_to_api(self, data):
        ws_id = data.get('ws_id')
        if not ws_id:
            self.logger.error("[memedeck]: Missing ws_id in data")
            return
        task = self.tasks_by_ws_id.get(ws_id)
        if not task:
            self.logger.error(f"[memedeck]: No task found for ws_id {ws_id}")
            return
        if task['websocket_node_id'] is None:
            self.logger.error(f"[memedeck]: websocket_node_id is None for {ws_id}")
            return
        try:
            post_func = partial(requests.post, f"{self.api_url}/generation/update", json=data)
            await self.loop.run_in_executor(None, post_func)
        except Exception as e:
            self.logger.error(f"[memedeck]: error sending to api: {e}")
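Every update, preview frame or status change, ends up as a JSON POST to `{API_ADDRESS}/generation/update` via send_to_api. An illustrative payload for a mid-generation preview (shape taken from send_preview; the values are made up):

```python
ai_queue_progress = {
    "ws_id": "ws-123",           # ties the update to the originating request
    "kind": "image_generating",  # flips to "image_generated" at progress == 100
    "data": [137, 80, 78, 71],   # raw image bytes serialized as a list of ints
    "progress": 42,
    "context": {},               # the req_ctx supplied with the job
}
```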
    # --------------------------------------------------------------------------
    # MemedeckAzureStorage
    # --------------------------------------------------------------------------
    # from azure.storage.blob.aio import BlobClient, BlobServiceClient
    # from azure.storage.blob import ContentSettings
    # from typing import Optional, Tuple
    # import cairosvg

    # WATERMARK = '<svg width="256" height="256" viewBox="0 0 256 256" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M60.0859 196.8C65.9526 179.067 71.5526 161.667 76.8859 144.6C79.1526 137.4 81.4859 129.867 83.8859 122C86.2859 114.133 88.6859 106.333 91.0859 98.6C93.4859 90.8667 95.6859 83.4 97.6859 76.2C99.8193 69 101.686 62.3333 103.286 56.2C110.619 56.2 117.553 55.8 124.086 55C130.619 54.2 137.686 53.4667 145.286 52.8C144.886 55.7333 144.419 59.0667 143.886 62.8C143.486 66.4 142.953 70.2 142.286 74.2C141.753 78.2 141.153 82.3333 140.486 86.6C139.819 90.8667 139.019 96.3333 138.086 103C137.153 109.667 135.886 118 134.286 128H136.886C140.753 117.867 143.953 109.467 146.486 102.8C149.019 96 151.086 90.4667 152.686 86.2C154.286 81.9333 155.886 77.8 157.486 73.8C159.219 69.6667 160.819 65.8 162.286 62.2C163.886 58.4667 165.353 55.2 166.686 52.4C170.019 52.1333 173.153 51.8 176.086 51.4C179.019 51 181.953 50.6 184.886 50.2C187.819 49.6667 190.753 49.2 193.686 48.8C196.753 48.2667 200.086 47.6667 203.686 47C202.353 54.7333 201.086 62.6667 199.886 70.8C198.686 78.9333 197.619 87.0667 196.686 95.2C195.753 103.2 194.819 111.133 193.886 119C193.086 126.867 192.353 134.333 191.686 141.4C190.086 157.933 188.686 174.067 187.486 189.8L152.686 196C152.686 195.333 152.753 193.533 152.886 190.6C153.153 187.667 153.419 184.067 153.686 179.8C154.086 175.533 154.553 170.8 155.086 165.6C155.753 160.4 156.353 155.2 156.886 150C157.553 144.8 158.219 139.8 158.886 135C159.553 130.067 160.219 125.867 160.886 122.4H159.086C157.219 128 155.153 133.933 152.886 140.2C150.619 146.333 148.286 152.6 145.886 159C143.619 165.4 141.353 171.667 139.086 177.8C136.819 183.933 134.819 189.8 133.086 195.4C128.419 195.533 124.419 195.733 121.086 196C117.753 196.133 113.886 196.333 109.486 196.6L115.886 122.4H112.886C112.619 124.133 111.953 127.067 110.886 131.2C109.819 135.2 108.553 139.867 107.086 145.2C105.753 150.4 104.286 155.867 102.686 161.6C101.086 167.2 99.5526 172.467 98.0859 177.4C96.7526 182.2 95.6193 186.2 94.6859 189.4C93.7526 192.467 93.2193 194.2 93.0859 194.6L60.0859 196.8Z" fill="white"/></svg>'
    # WATERMARK_SIZE = 40

    # class MemedeckAzureStorage:
    #     def __init__(self, connection_string):
    #         # get environment variables
    #         self.storage_account = os.getenv('STORAGE_ACCOUNT')
    #         self.storage_access_key = os.getenv('STORAGE_ACCESS_KEY')
    #         self.storage_container = os.getenv('STORAGE_CONTAINER')
    #         self.logger = logging.getLogger(__name__)

    #         self.blob_service_client = BlobServiceClient.from_connection_string(conn_str=connection_string)

    #     async def upload_image(
    #         self,
    #         by: str,
    #         image_id: str,
    #         source_url: Optional[str],
    #         bytes_data: Optional[bytes],
    #         filetype: Optional[str],
    #     ) -> Tuple[str, Tuple[int, int]]:
    #         """
    #         Uploads an image to Azure Blob Storage.

    #         Args:
    #             by (str): Identifier for the uploader.
    #             image_id (str): Unique identifier for the image.
    #             source_url (Optional[str]): URL to fetch the image from.
    #             bytes_data (Optional[bytes]): Image data in bytes.
    #             filetype (Optional[str]): Desired file type (e.g., 'jpeg', 'png').

    #         Returns:
    #             Tuple[str, Tuple[int, int]]: URL of the uploaded image and its dimensions.
    #         """
    #         # Retrieve image bytes either from the provided bytes_data or by fetching from source_url
    #         if source_url is None:
    #             if bytes_data is None:
    #                 raise ValueError("Could not get image bytes")
    #             image_bytes = bytes_data
    #         else:
    #             self.logger.info(f"Requesting image from URL: {source_url}")
    #             async with aiohttp.ClientSession() as session:
    #                 try:
    #                     async with session.get(source_url) as response:
    #                         if response.status != 200:
    #                             raise Exception(f"Failed to fetch image, status code {response.status}")
    #                         image_bytes = await response.read()
    #                 except Exception as e:
    #                     raise Exception(f"Error fetching image from URL: {e}")

    #         # Open image using Pillow to get dimensions and format
    #         try:
    #             img = Image.open(BytesIO(image_bytes))
    #             width, height = img.size
    #             inferred_filetype = img.format.lower()
    #         except Exception as e:
    #             raise Exception(f"Failed to decode image: {e}")

    #         # Determine the final file type
    #         final_filetype = filetype.lower() if filetype else inferred_filetype

    #         # Construct the blob name
    #         blob_name = f"{by}/{image_id.replace('image:', '')}.{final_filetype}"

    #         # Upload the image to Azure Blob Storage
    #         try:
    #             image_url = await self.save_image(blob_name, img.format, image_bytes)
    #             return image_url, (width, height)
    #         except Exception as e:
    #             self.logger.error(f"Trouble saving image: {e}")
    #             raise Exception(f"Trouble saving image: {e}")

    #     async def save_image(
    #         self,
    #         blob_name: str,
    #         content_type: str,
    #         bytes_data: bytes
    #     ) -> str:
    #         """
    #         Saves image bytes to Azure Blob Storage.

    #         Args:
    #             blob_name (str): Name of the blob in Azure Storage.
    #             content_type (str): MIME type of the content.
    #             bytes_data (bytes): Image data in bytes.

    #         Returns:
    #             str: URL of the uploaded blob.
    #         """
    #         # Retrieve environment variables
    #         account = os.getenv("STORAGE_ACCOUNT")
    #         access_key = os.getenv("STORAGE_ACCESS_KEY")
    #         container = os.getenv("STORAGE_CONTAINER")

    #         if not all([account, access_key, container]):
    #             raise EnvironmentError("Missing STORAGE_ACCOUNT, STORAGE_ACCESS_KEY, or STORAGE_CONTAINER environment variables")

    #         # Initialize BlobServiceClient
    #         blob_service_client = BlobServiceClient(
    #             account_url=f"https://{account}.blob.core.windows.net",
    #             credential=access_key
    #         )
    #         blob_client = blob_service_client.get_blob_client(container=container, blob=blob_name)

    #         # Upload the blob
    #         try:
    #             await blob_client.upload_blob(
    #                 bytes_data,
    #                 overwrite=True,
    #                 content_settings=ContentSettings(content_type=content_type)
    #             )
    #         except Exception as e:
    #             raise Exception(f"Failed to upload blob: {e}")

    #         self.logger.debug(f"Blob uploaded: name={blob_name}, content_type={content_type}")

    #         # Construct and return the blob URL
    #         blob_url = f"https://media.memedeck.xyz//{container}/{blob_name}"
    #         return blob_url

    #     async def add_watermark(
    #         self,
    #         base_blob_name: str,
    #         base_image: bytes
    #     ) -> str:
    #         """
    #         Adds a watermark to the provided image and uploads the watermarked image.

    #         Args:
    #             base_blob_name (str): Original blob name of the image.
    #             base_image (bytes): Image data in bytes.

    #         Returns:
    #             str: URL of the watermarked image.
    #         """
    #         # Load the input image
    #         try:
    #             img = Image.open(BytesIO(base_image)).convert("RGBA")
    #         except Exception as e:
    #             raise Exception(f"Failed to load image: {e}")

    #         # Calculate position for the watermark (bottom right corner with padding)
    #         padding = 12
    #         x = img.width - WATERMARK_SIZE - padding
    #         y = img.height - WATERMARK_SIZE - padding

    #         # Analyze background brightness where the watermark will be placed
    #         background_brightness = self.analyze_background_brightness(img, x, y, WATERMARK_SIZE)
    #         self.logger.info(f"Background brightness: {background_brightness}")

    #         # Render SVG watermark to PNG bytes using cairosvg
    #         try:
    #             watermark_png_bytes = cairosvg.svg2png(bytestring=WATERMARK.encode('utf-8'), output_width=WATERMARK_SIZE, output_height=WATERMARK_SIZE)
    #             watermark = Image.open(BytesIO(watermark_png_bytes)).convert("RGBA")
    #         except Exception as e:
    #             raise Exception(f"Failed to render watermark SVG: {e}")

    #         # Determine watermark color based on background brightness
    #         if background_brightness > 128:
    #             # Dark watermark for light backgrounds
    #             watermark_color = (0, 0, 0, int(255 * 0.65)) # Black with 65% opacity
    #         else:
    #             # Light watermark for dark backgrounds
    #             watermark_color = (255, 255, 255, int(255 * 0.65)) # White with 65% opacity

    #         # Apply the watermark color by blending
    #         solid_color = Image.new("RGBA", watermark.size, watermark_color)
    #         watermark = Image.alpha_composite(watermark, solid_color)

    #         # Overlay the watermark onto the original image
    #         img.paste(watermark, (x, y), watermark)

    #         # Save the watermarked image to bytes
    #         buffer = BytesIO()
    #         img = img.convert("RGB") # Convert back to RGB for JPEG format
    #         img.save(buffer, format="JPEG")
    #         buffer.seek(0)
    #         jpeg_bytes = buffer.read()

    #         # Modify the blob name to include '_watermarked'
    #         try:
    #             if "memes/" in base_blob_name:
    #                 base_blob_name_right = base_blob_name.split("memes/", 1)[1]
    #             else:
    #                 base_blob_name_right = base_blob_name
    #             base_blob_name_split = base_blob_name_right.rsplit(".", 1)
    #             base_blob_name_without_extension = base_blob_name_split[0]
    #             extension = base_blob_name_split[1]
    #         except Exception as e:
    #             raise Exception(f"Failed to process blob name: {e}")

    #         watermarked_blob_name = f"{base_blob_name_without_extension}_watermarked.{extension}"

    #         # Upload the watermarked image
    #         try:
    #             watermarked_blob_url = await self.save_image(
    #                 watermarked_blob_name,
    #                 "image/jpeg",
    #                 jpeg_bytes
    #             )
    #             return watermarked_blob_url
    #         except Exception as e:
    #             raise Exception(f"Failed to upload watermarked image: {e}")

    #     def analyze_background_brightness(
    #         self,
    #         img: Image.Image,
    #         x: int,
    #         y: int,
    #         size: int
    #     ) -> int:
    #         """
    #         Analyzes the brightness of a specific region in the image.

    #         Args:
    #             img (Image.Image): The image to analyze.
    #             x (int): X-coordinate of the top-left corner of the region.
    #             y (int): Y-coordinate of the top-left corner of the region.
    #             size (int): Size of the square region to analyze.
|
||||
|
||||
# Returns:
|
||||
# int: Average brightness (0-255) of the region.
|
||||
# """
|
||||
# # Crop the specified region
|
||||
# sub_image = img.crop((x, y, x + size, y + size)).convert("RGB")
|
||||
|
||||
# # Calculate average brightness using the luminance formula
|
||||
# total_brightness = 0
|
||||
# pixel_count = 0
|
||||
# for pixel in sub_image.getdata():
|
||||
# r, g, b = pixel
|
||||
# brightness = (r * 299 + g * 587 + b * 114) // 1000
|
||||
# total_brightness += brightness
|
||||
# pixel_count += 1
|
||||
|
||||
# if pixel_count == 0:
|
||||
# return 0
|
||||
|
||||
# average_brightness = total_brightness // pixel_count
|
||||
# return average_brightness
|
||||
|
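The commented-out `analyze_background_brightness` above averages luminance with a per-pixel Python loop. Pillow can compute the same ITU-R 601-2 average natively through its grayscale conversion; a minimal sketch (the function name is illustrative, not part of the codebase):

```python
from PIL import Image, ImageStat

def region_brightness(img: Image.Image, x: int, y: int, size: int) -> int:
    # Pillow's "L" mode uses L = R*299/1000 + G*587/1000 + B*114/1000,
    # the same luminance formula as the hand-rolled loop above.
    region = img.crop((x, y, x + size, y + size)).convert("L")
    return int(ImageStat.Stat(region).mean[0])
```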
@ -1,73 +0,0 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
      params:
        layer: "hidden"
        layer_idx: -2
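The deleted files in this and the following hunks are the standard Stable Diffusion v1/v2 LatentDiffusion YAML configs (their file names are not preserved in this view). For reference, such a config is consumed by loading it with OmegaConf and instantiating the `model` node; a minimal sketch, assuming the upstream `ldm` package is importable and using a hypothetical path:

```python
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config  # resolves the dotted `target` path

config = OmegaConf.load("configs/v1-inference.yaml")  # hypothetical file name
model = instantiate_from_config(config.model)  # builds LatentDiffusion from target + params
```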
@ -1,70 +0,0 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
@ -1,73 +0,0 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
      params:
        layer: "hidden"
        layer_idx: -2
@ -1,74 +0,0 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
      params:
        layer: "hidden"
        layer_idx: -2
@ -1,71 +0,0 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
@ -1,71 +0,0 @@
model:
  base_learning_rate: 7.5e-05
  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: hybrid # important
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    finetune_keys: null

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 9 # 4 data + 4 downscaled image + 1 mask
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
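The inpainting config above widens the UNet input from four channels to nine; the comment on `in_channels` spells out the split (4 noised-latent channels, 4 channels of the VAE-encoded masked image, 1 downscaled mask channel). A sketch of that concatenation, with illustrative tensor names and the channel order taken from the config comment:

```python
import torch

latents = torch.randn(1, 4, 64, 64)               # noised image latents ("data")
masked_image_latents = torch.randn(1, 4, 64, 64)  # encoded, downscaled masked image
mask = torch.zeros(1, 1, 64, 64)                  # 1 = region to inpaint
unet_in = torch.cat([latents, masked_image_latents, mask], dim=1)  # shape (1, 9, 64, 64)
```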
@ -1,68 +0,0 @@
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
@ -1,68 +0,0 @@
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: False
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
@ -1,67 +0,0 @@
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
@ -1,67 +0,0 @@
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: False
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
@ -1,158 +0,0 @@
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: hybrid
    scale_factor: 0.18215
    monitor: val/loss_simple_ema
    finetune_keys: null
    use_ema: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 9
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: null # for concat as in LAION-A
    p_unsafe_threshold: 0.1
    filter_word_list: "data/filters.yaml"
    max_pwatermark: 0.45
    batch_size: 8
    num_workers: 6
    multinode: True
    min_size: 512
    train:
      shards:
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"
          p_drop: 0.25
    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards:
        - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"
          p_drop: 0.25


lightning:
  find_unused_parameters: True
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 10000

    image_logger:
      target: main.ImageLogger
      params:
        enable_autocast: False
        disabled: False
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 5.0
          unconditional_guidance_label: [""]
          ddim_steps: 50 # todo check these out for depth2img,
          ddim_eta: 0.0 # todo check these out for depth2img,

  trainer:
    benchmark: True
    val_check_interval: 5000000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
1
pysssss-workflows/training.json
Normal file
File diff suppressed because one or more lines are too long
@ -22,3 +22,10 @@ kornia>=0.7.1
spandrel
soundfile
av

# memedeck dependencies
pika
python-dotenv
pillow
azure-storage-blob
cairosvg
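These additions track what the worker code touches: `pika` is a RabbitMQ client for the job queue, `python-dotenv` loads the `.env` file this commit adds to `.gitignore`, and `pillow`, `azure-storage-blob`, and `cairosvg` back the image handling, blob upload, and SVG watermark paths in the code above. A minimal sketch of the dotenv pattern (variable name taken from the storage code above):

```python
import os
from dotenv import load_dotenv

load_dotenv()  # reads the git-ignored .env file into the environment
account = os.getenv("STORAGE_ACCOUNT")  # consumed by the Azure upload path
```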
3
styles/default.csv
Normal file
@ -0,0 +1,3 @@
name,prompt,negative_prompt
❌Low Token,,"embedding:EasyNegative, NSFW, Cleavage, Pubic Hair, Nudity, Naked, censored"
✅Line Art / Manga,"(Anime Scene, Toonshading, Satoshi Kon, Ken Sugimori, Hiromu Arakawa:1.2), (Anime Style, Manga Style:1.3), Low detail, sketch, concept art, line art, webtoon, manhua, hand drawn, defined lines, simple shades, minimalistic, High contrast, Linear compositions, Scalable artwork, Digital art, High Contrast Shadows, glow effects, humorous illustration, big depth of field, Masterpiece, colors, concept art, trending on artstation, Vivid colors, dramatic",
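The styles file uses a flat `name,prompt,negative_prompt` layout, so an empty field simply leaves that slot blank. A minimal sketch of reading it with the standard library (the loader shape is an assumption, not the actual node code):

```python
import csv

with open("styles/default.csv", newline="", encoding="utf-8") as f:
    styles = {row["name"]: (row["prompt"], row["negative_prompt"])
              for row in csv.DictReader(f)}

prompt, negative = styles["✅Line Art / Manga"]
```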