mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-17 09:33:29 +00:00
Merge remote-tracking branch 'origin/master' into multigpu_support
This commit is contained in:
commit
0b3233b4e2
@ -15,6 +15,7 @@
|
|||||||
# Python web server
|
# Python web server
|
||||||
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||||
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||||
|
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||||
|
|
||||||
# Frontend assets
|
# Frontend assets
|
||||||
/web/ @huchenlei @webfiltered @pythongosssss @yoland68 @robinjhuang
|
/web/ @huchenlei @webfiltered @pythongosssss @yoland68 @robinjhuang
|
||||||
|
@ -154,9 +154,9 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
|
|||||||
|
|
||||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
|
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
|
||||||
|
|
||||||
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
|
This is the command to install the nightly with ROCm 6.3 which might have some performance improvements:
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2.4```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3```
|
||||||
|
|
||||||
### Intel GPUs (Windows and Linux)
|
### Intel GPUs (Windows and Linux)
|
||||||
|
|
||||||
|
@ -4,12 +4,93 @@ import os
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import glob
|
import glob
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from utils.json_util import merge_json_recursive
|
||||||
|
|
||||||
|
|
||||||
|
# Extra locale files to load into main.json
|
||||||
|
EXTRA_LOCALE_FILES = [
|
||||||
|
"nodeDefs.json",
|
||||||
|
"commands.json",
|
||||||
|
"settings.json",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def safe_load_json_file(file_path: str) -> dict:
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logging.error(f"Error loading {file_path}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
class CustomNodeManager:
|
class CustomNodeManager:
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def build_translations(self):
|
||||||
|
"""Load all custom nodes translations during initialization. Translations are
|
||||||
|
expected to be loaded from `locales/` folder.
|
||||||
|
|
||||||
|
The folder structure is expected to be the following:
|
||||||
|
- custom_nodes/
|
||||||
|
- custom_node_1/
|
||||||
|
- locales/
|
||||||
|
- en/
|
||||||
|
- main.json
|
||||||
|
- commands.json
|
||||||
|
- settings.json
|
||||||
|
|
||||||
|
returned translations are expected to be in the following format:
|
||||||
|
{
|
||||||
|
"en": {
|
||||||
|
"nodeDefs": {...},
|
||||||
|
"commands": {...},
|
||||||
|
"settings": {...},
|
||||||
|
...{other main.json keys}
|
||||||
|
}
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
Placeholder to refactor the custom node management features from ComfyUI-Manager.
|
|
||||||
Currently it only contains the custom workflow templates feature.
|
translations = {}
|
||||||
"""
|
|
||||||
|
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||||
|
# Sort glob results for deterministic ordering
|
||||||
|
for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
|
||||||
|
locales_dir = os.path.join(custom_node_dir, "locales")
|
||||||
|
if not os.path.exists(locales_dir):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
|
||||||
|
lang_code = os.path.basename(os.path.dirname(lang_dir))
|
||||||
|
|
||||||
|
if lang_code not in translations:
|
||||||
|
translations[lang_code] = {}
|
||||||
|
|
||||||
|
# Load main.json
|
||||||
|
main_file = os.path.join(lang_dir, "main.json")
|
||||||
|
node_translations = safe_load_json_file(main_file)
|
||||||
|
|
||||||
|
# Load extra locale files
|
||||||
|
for extra_file in EXTRA_LOCALE_FILES:
|
||||||
|
extra_file_path = os.path.join(lang_dir, extra_file)
|
||||||
|
key = extra_file.split(".")[0]
|
||||||
|
json_data = safe_load_json_file(extra_file_path)
|
||||||
|
if json_data:
|
||||||
|
node_translations[key] = json_data
|
||||||
|
|
||||||
|
if node_translations:
|
||||||
|
translations[lang_code] = merge_json_recursive(
|
||||||
|
translations[lang_code], node_translations
|
||||||
|
)
|
||||||
|
|
||||||
|
return translations
|
||||||
|
|
||||||
def add_routes(self, routes, webapp, loadedModules):
|
def add_routes(self, routes, webapp, loadedModules):
|
||||||
|
|
||||||
@routes.get("/workflow_templates")
|
@routes.get("/workflow_templates")
|
||||||
@ -18,17 +99,36 @@ class CustomNodeManager:
|
|||||||
files = [
|
files = [
|
||||||
file
|
file
|
||||||
for folder in folder_paths.get_folder_paths("custom_nodes")
|
for folder in folder_paths.get_folder_paths("custom_nodes")
|
||||||
for file in glob.glob(os.path.join(folder, '*/example_workflows/*.json'))
|
for file in glob.glob(
|
||||||
|
os.path.join(folder, "*/example_workflows/*.json")
|
||||||
|
)
|
||||||
]
|
]
|
||||||
workflow_templates_dict = {} # custom_nodes folder name -> example workflow names
|
workflow_templates_dict = (
|
||||||
|
{}
|
||||||
|
) # custom_nodes folder name -> example workflow names
|
||||||
for file in files:
|
for file in files:
|
||||||
custom_nodes_name = os.path.basename(os.path.dirname(os.path.dirname(file)))
|
custom_nodes_name = os.path.basename(
|
||||||
|
os.path.dirname(os.path.dirname(file))
|
||||||
|
)
|
||||||
workflow_name = os.path.splitext(os.path.basename(file))[0]
|
workflow_name = os.path.splitext(os.path.basename(file))[0]
|
||||||
workflow_templates_dict.setdefault(custom_nodes_name, []).append(workflow_name)
|
workflow_templates_dict.setdefault(custom_nodes_name, []).append(
|
||||||
|
workflow_name
|
||||||
|
)
|
||||||
return web.json_response(workflow_templates_dict)
|
return web.json_response(workflow_templates_dict)
|
||||||
|
|
||||||
# Serve workflow templates from custom nodes.
|
# Serve workflow templates from custom nodes.
|
||||||
for module_name, module_dir in loadedModules:
|
for module_name, module_dir in loadedModules:
|
||||||
workflows_dir = os.path.join(module_dir, 'example_workflows')
|
workflows_dir = os.path.join(module_dir, "example_workflows")
|
||||||
if os.path.exists(workflows_dir):
|
if os.path.exists(workflows_dir):
|
||||||
webapp.add_routes([web.static('/api/workflow_templates/' + module_name, workflows_dir)])
|
webapp.add_routes(
|
||||||
|
[
|
||||||
|
web.static(
|
||||||
|
"/api/workflow_templates/" + module_name, workflows_dir
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@routes.get("/i18n")
|
||||||
|
async def get_i18n(request):
|
||||||
|
"""Returns translations from all custom nodes' locales folders."""
|
||||||
|
return web.json_response(self.build_translations())
|
||||||
|
@ -3,9 +3,6 @@ import math
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
|
||||||
|
|
||||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
|
||||||
return abs(a*b) // math.gcd(a, b)
|
|
||||||
|
|
||||||
class CONDRegular:
|
class CONDRegular:
|
||||||
def __init__(self, cond):
|
def __init__(self, cond):
|
||||||
self.cond = cond
|
self.cond = cond
|
||||||
@ -46,7 +43,7 @@ class CONDCrossAttn(CONDRegular):
|
|||||||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
||||||
return False
|
return False
|
||||||
|
|
||||||
mult_min = lcm(s1[1], s2[1])
|
mult_min = math.lcm(s1[1], s2[1])
|
||||||
diff = mult_min // min(s1[1], s2[1])
|
diff = mult_min // min(s1[1], s2[1])
|
||||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||||
return False
|
return False
|
||||||
@ -57,7 +54,7 @@ class CONDCrossAttn(CONDRegular):
|
|||||||
crossattn_max_len = self.cond.shape[1]
|
crossattn_max_len = self.cond.shape[1]
|
||||||
for x in others:
|
for x in others:
|
||||||
c = x.cond
|
c = x.cond
|
||||||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
crossattn_max_len = math.lcm(crossattn_max_len, c.shape[1])
|
||||||
conds.append(c)
|
conds.append(c)
|
||||||
|
|
||||||
out = []
|
out = []
|
||||||
|
@ -4,105 +4,6 @@ import logging
|
|||||||
|
|
||||||
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
||||||
|
|
||||||
# =================#
|
|
||||||
# UNet Conversion #
|
|
||||||
# =================#
|
|
||||||
|
|
||||||
unet_conversion_map = [
|
|
||||||
# (stable-diffusion, HF Diffusers)
|
|
||||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
|
||||||
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
|
||||||
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
|
||||||
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
|
||||||
("input_blocks.0.0.weight", "conv_in.weight"),
|
|
||||||
("input_blocks.0.0.bias", "conv_in.bias"),
|
|
||||||
("out.0.weight", "conv_norm_out.weight"),
|
|
||||||
("out.0.bias", "conv_norm_out.bias"),
|
|
||||||
("out.2.weight", "conv_out.weight"),
|
|
||||||
("out.2.bias", "conv_out.bias"),
|
|
||||||
]
|
|
||||||
|
|
||||||
unet_conversion_map_resnet = [
|
|
||||||
# (stable-diffusion, HF Diffusers)
|
|
||||||
("in_layers.0", "norm1"),
|
|
||||||
("in_layers.2", "conv1"),
|
|
||||||
("out_layers.0", "norm2"),
|
|
||||||
("out_layers.3", "conv2"),
|
|
||||||
("emb_layers.1", "time_emb_proj"),
|
|
||||||
("skip_connection", "conv_shortcut"),
|
|
||||||
]
|
|
||||||
|
|
||||||
unet_conversion_map_layer = []
|
|
||||||
# hardcoded number of downblocks and resnets/attentions...
|
|
||||||
# would need smarter logic for other networks.
|
|
||||||
for i in range(4):
|
|
||||||
# loop over downblocks/upblocks
|
|
||||||
|
|
||||||
for j in range(2):
|
|
||||||
# loop over resnets/attentions for downblocks
|
|
||||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
|
||||||
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
|
||||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
|
||||||
|
|
||||||
if i < 3:
|
|
||||||
# no attention layers in down_blocks.3
|
|
||||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
|
||||||
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
|
||||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
|
||||||
|
|
||||||
for j in range(3):
|
|
||||||
# loop over resnets/attentions for upblocks
|
|
||||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
|
||||||
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
|
||||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
|
||||||
|
|
||||||
if i > 0:
|
|
||||||
# no attention layers in up_blocks.0
|
|
||||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
|
||||||
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
|
||||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
|
||||||
|
|
||||||
if i < 3:
|
|
||||||
# no downsample in down_blocks.3
|
|
||||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
|
||||||
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
|
||||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
|
||||||
|
|
||||||
# no upsample in up_blocks.3
|
|
||||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
|
||||||
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
|
||||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
|
||||||
|
|
||||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
|
||||||
sd_mid_atn_prefix = "middle_block.1."
|
|
||||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
|
||||||
|
|
||||||
for j in range(2):
|
|
||||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
|
||||||
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
|
||||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
|
||||||
|
|
||||||
|
|
||||||
def convert_unet_state_dict(unet_state_dict):
|
|
||||||
# buyer beware: this is a *brittle* function,
|
|
||||||
# and correct output requires that all of these pieces interact in
|
|
||||||
# the exact order in which I have arranged them.
|
|
||||||
mapping = {k: k for k in unet_state_dict.keys()}
|
|
||||||
for sd_name, hf_name in unet_conversion_map:
|
|
||||||
mapping[hf_name] = sd_name
|
|
||||||
for k, v in mapping.items():
|
|
||||||
if "resnets" in k:
|
|
||||||
for sd_part, hf_part in unet_conversion_map_resnet:
|
|
||||||
v = v.replace(hf_part, sd_part)
|
|
||||||
mapping[k] = v
|
|
||||||
for k, v in mapping.items():
|
|
||||||
for sd_part, hf_part in unet_conversion_map_layer:
|
|
||||||
v = v.replace(hf_part, sd_part)
|
|
||||||
mapping[k] = v
|
|
||||||
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
|
||||||
return new_state_dict
|
|
||||||
|
|
||||||
|
|
||||||
# ================#
|
# ================#
|
||||||
# VAE Conversion #
|
# VAE Conversion #
|
||||||
# ================#
|
# ================#
|
||||||
@ -213,6 +114,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
|||||||
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
||||||
code2idx = {"q": 0, "k": 1, "v": 2}
|
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||||
|
|
||||||
|
|
||||||
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
|
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
|
||||||
def cat_tensors(tensors):
|
def cat_tensors(tensors):
|
||||||
x = 0
|
x = 0
|
||||||
@ -229,6 +131,7 @@ def cat_tensors(tensors):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
capture_qkv_weight = {}
|
capture_qkv_weight = {}
|
||||||
@ -284,5 +187,3 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
|||||||
|
|
||||||
def convert_text_enc_state_dict(text_enc_dict):
|
def convert_text_enc_state_dict(text_enc_dict):
|
||||||
return text_enc_dict
|
return text_enc_dict
|
||||||
|
|
||||||
|
|
||||||
|
@ -1336,3 +1336,26 @@ def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disab
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
|
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
|
||||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True)
|
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
||||||
|
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
old_d = None
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
d = to_d(x, sigmas[i], denoised)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
dt = sigmas[i + 1] - sigmas[i]
|
||||||
|
if i == 0:
|
||||||
|
# Euler method
|
||||||
|
x = x + d * dt
|
||||||
|
else:
|
||||||
|
# Gradient estimation
|
||||||
|
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
|
||||||
|
x = x + d_bar * dt
|
||||||
|
old_d = d
|
||||||
|
return x
|
||||||
|
@ -109,8 +109,7 @@ class Flux(nn.Module):
|
|||||||
img = self.img_in(img)
|
img = self.img_in(img)
|
||||||
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
|
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
|
||||||
if self.params.guidance_embed:
|
if self.params.guidance_embed:
|
||||||
if guidance is None:
|
if guidance is not None:
|
||||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||||
|
|
||||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||||
@ -186,7 +185,7 @@ class Flux(nn.Module):
|
|||||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
|
@ -240,8 +240,7 @@ class HunyuanVideo(nn.Module):
|
|||||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||||
|
|
||||||
if self.params.guidance_embed:
|
if self.params.guidance_embed:
|
||||||
if guidance is None:
|
if guidance is not None:
|
||||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||||
|
|
||||||
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
||||||
@ -314,7 +313,7 @@ class HunyuanVideo(nn.Module):
|
|||||||
img = img.reshape(initial_shape)
|
img = img.reshape(initial_shape)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
|
||||||
bs, c, t, h, w = x.shape
|
bs, c, t, h, w = x.shape
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||||
|
@ -702,9 +702,6 @@ class Decoder(nn.Module):
|
|||||||
padding=1)
|
padding=1)
|
||||||
|
|
||||||
def forward(self, z, **kwargs):
|
def forward(self, z, **kwargs):
|
||||||
#assert z.shape[1:] == self.z_shape[1:]
|
|
||||||
self.last_z_shape = z.shape
|
|
||||||
|
|
||||||
# timestep embedding
|
# timestep embedding
|
||||||
temb = None
|
temb = None
|
||||||
|
|
||||||
|
@ -148,7 +148,9 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
xc = xc.to(dtype)
|
xc = xc.to(dtype)
|
||||||
t = self.model_sampling.timestep(t).float()
|
t = self.model_sampling.timestep(t).float()
|
||||||
|
if context is not None:
|
||||||
context = context.to(dtype)
|
context = context.to(dtype)
|
||||||
|
|
||||||
extra_conds = {}
|
extra_conds = {}
|
||||||
for o in kwargs:
|
for o in kwargs:
|
||||||
extra = kwargs[o]
|
extra = kwargs[o]
|
||||||
@ -549,6 +551,10 @@ class SD_X4Upscaler(BaseModel):
|
|||||||
|
|
||||||
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
|
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
|
||||||
out['y'] = comfy.conds.CONDRegular(noise_level)
|
out['y'] = comfy.conds.CONDRegular(noise_level)
|
||||||
|
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class IP2P:
|
class IP2P:
|
||||||
@ -806,7 +812,10 @@ class Flux(BaseModel):
|
|||||||
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
|
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
|
||||||
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
|
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
|
||||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
|
|
||||||
|
guidance = kwargs.get("guidance", 3.5)
|
||||||
|
if guidance is not None:
|
||||||
|
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class GenmoMochi(BaseModel):
|
class GenmoMochi(BaseModel):
|
||||||
@ -863,7 +872,10 @@ class HunyuanVideo(BaseModel):
|
|||||||
cross_attn = kwargs.get("cross_attn", None)
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
|
|
||||||
|
guidance = kwargs.get("guidance", 6.0)
|
||||||
|
if guidance is not None:
|
||||||
|
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class CosmosVideo(BaseModel):
|
class CosmosVideo(BaseModel):
|
||||||
|
@ -237,7 +237,7 @@ def is_amd():
|
|||||||
|
|
||||||
MIN_WEIGHT_MEMORY_RATIO = 0.4
|
MIN_WEIGHT_MEMORY_RATIO = 0.4
|
||||||
if is_nvidia():
|
if is_nvidia():
|
||||||
MIN_WEIGHT_MEMORY_RATIO = 0.2
|
MIN_WEIGHT_MEMORY_RATIO = 0.1
|
||||||
|
|
||||||
ENABLE_PYTORCH_ATTENTION = False
|
ENABLE_PYTORCH_ATTENTION = False
|
||||||
if args.use_pytorch_cross_attention:
|
if args.use_pytorch_cross_attention:
|
||||||
@ -554,14 +554,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
vram_set_state = vram_state
|
vram_set_state = vram_state
|
||||||
lowvram_model_memory = 0
|
lowvram_model_memory = 0
|
||||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
|
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
|
||||||
model_size = loaded_model.model_memory_required(torch_dev)
|
|
||||||
loaded_memory = loaded_model.model_loaded_memory()
|
loaded_memory = loaded_model.model_loaded_memory()
|
||||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||||
|
|
||||||
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||||
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
||||||
if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
|
|
||||||
lowvram_model_memory = 0
|
|
||||||
|
|
||||||
if vram_set_state == VRAMState.NO_VRAM:
|
if vram_set_state == VRAMState.NO_VRAM:
|
||||||
lowvram_model_memory = 0.1
|
lowvram_model_memory = 0.1
|
||||||
|
@ -60,7 +60,6 @@ def convert_cond(cond):
|
|||||||
temp = c[1].copy()
|
temp = c[1].copy()
|
||||||
model_conds = temp.get("model_conds", {})
|
model_conds = temp.get("model_conds", {})
|
||||||
if c[0] is not None:
|
if c[0] is not None:
|
||||||
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
|
|
||||||
temp["cross_attn"] = c[0]
|
temp["cross_attn"] = c[0]
|
||||||
temp["model_conds"] = model_conds
|
temp["model_conds"] = model_conds
|
||||||
temp["uuid"] = uuid.uuid4()
|
temp["uuid"] = uuid.uuid4()
|
||||||
|
@ -879,7 +879,7 @@ class Sampler:
|
|||||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp"]
|
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "gradient_estimation"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||||
|
@ -38,7 +38,26 @@ class FluxGuidance:
|
|||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
|
|
||||||
|
class FluxDisableGuidance:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"conditioning": ("CONDITIONING", ),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
FUNCTION = "append"
|
||||||
|
|
||||||
|
CATEGORY = "advanced/conditioning/flux"
|
||||||
|
DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
|
||||||
|
|
||||||
|
def append(self, conditioning):
|
||||||
|
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
|
||||||
|
return (c, )
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
|
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
|
||||||
"FluxGuidance": FluxGuidance,
|
"FluxGuidance": FluxGuidance,
|
||||||
|
"FluxDisableGuidance": FluxDisableGuidance,
|
||||||
}
|
}
|
||||||
|
@ -2,10 +2,14 @@ import comfy.utils
|
|||||||
import comfy_extras.nodes_post_processing
|
import comfy_extras.nodes_post_processing
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
def reshape_latent_to(target_shape, latent):
|
|
||||||
|
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
||||||
if latent.shape[1:] != target_shape[1:]:
|
if latent.shape[1:] != target_shape[1:]:
|
||||||
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
|
latent = comfy.utils.common_upscale(latent, target_shape[-1], target_shape[-2], "bilinear", "center")
|
||||||
|
if repeat_batch:
|
||||||
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
|
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
|
||||||
|
else:
|
||||||
|
return latent
|
||||||
|
|
||||||
|
|
||||||
class LatentAdd:
|
class LatentAdd:
|
||||||
@ -116,8 +120,7 @@ class LatentBatch:
|
|||||||
s1 = samples1["samples"]
|
s1 = samples1["samples"]
|
||||||
s2 = samples2["samples"]
|
s2 = samples2["samples"]
|
||||||
|
|
||||||
if s1.shape[1:] != s2.shape[1:]:
|
s2 = reshape_latent_to(s1.shape, s2, repeat_batch=False)
|
||||||
s2 = comfy.utils.common_upscale(s2, s1.shape[-1], s1.shape[-2], "bilinear", "center")
|
|
||||||
s = torch.cat((s1, s2), dim=0)
|
s = torch.cat((s1, s2), dim=0)
|
||||||
samples_out["samples"] = s
|
samples_out["samples"] = s
|
||||||
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
|
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
|
||||||
|
@ -19,9 +19,6 @@ class Load3D():
|
|||||||
"image": ("LOAD_3D", {}),
|
"image": ("LOAD_3D", {}),
|
||||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
"show_grid": ([True, False],),
|
|
||||||
"camera_type": (["perspective", "orthographic"],),
|
|
||||||
"view": (["front", "right", "top", "isometric"],),
|
|
||||||
"material": (["original", "normal", "wireframe", "depth"],),
|
"material": (["original", "normal", "wireframe", "depth"],),
|
||||||
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
|
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
|
||||||
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
|
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
|
||||||
@ -69,9 +66,6 @@ class Load3DAnimation():
|
|||||||
"image": ("LOAD_3D_ANIMATION", {}),
|
"image": ("LOAD_3D_ANIMATION", {}),
|
||||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
"show_grid": ([True, False],),
|
|
||||||
"camera_type": (["perspective", "orthographic"],),
|
|
||||||
"view": (["front", "right", "top", "isometric"],),
|
|
||||||
"material": (["original", "normal", "wireframe", "depth"],),
|
"material": (["original", "normal", "wireframe", "depth"],),
|
||||||
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
|
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
|
||||||
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
|
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
|
||||||
@ -109,9 +103,6 @@ class Preview3D():
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||||
"show_grid": ([True, False],),
|
|
||||||
"camera_type": (["perspective", "orthographic"],),
|
|
||||||
"view": (["front", "right", "top", "isometric"],),
|
|
||||||
"material": (["original", "normal", "wireframe", "depth"],),
|
"material": (["original", "normal", "wireframe", "depth"],),
|
||||||
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
|
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
|
||||||
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
|
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
|
||||||
|
@ -39,10 +39,10 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")]
|
|||||||
|
|
||||||
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
|
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
|
||||||
|
|
||||||
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
|
output_directory = os.path.join(base_path, "output")
|
||||||
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
temp_directory = os.path.join(base_path, "temp")
|
||||||
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
|
input_directory = os.path.join(base_path, "input")
|
||||||
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")
|
user_directory = os.path.join(base_path, "user")
|
||||||
|
|
||||||
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||||
|
|
||||||
|
3
main.py
3
main.py
@ -138,6 +138,8 @@ import server
|
|||||||
from server import BinaryEventTypes
|
from server import BinaryEventTypes
|
||||||
import nodes
|
import nodes
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfyui_version
|
||||||
|
|
||||||
|
|
||||||
def cuda_malloc_warning():
|
def cuda_malloc_warning():
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
@ -292,6 +294,7 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Running directly, just start ComfyUI.
|
# Running directly, just start ComfyUI.
|
||||||
|
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||||
event_loop, _, start_all_func = start_comfyui()
|
event_loop, _, start_all_func = start_comfyui()
|
||||||
try:
|
try:
|
||||||
event_loop.run_until_complete(start_all_func())
|
event_loop.run_until_complete(start_all_func())
|
||||||
|
4
nodes.py
4
nodes.py
@ -63,6 +63,8 @@ class CLIPTextEncode(ComfyNodeABC):
|
|||||||
DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
|
DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
|
||||||
|
|
||||||
def encode(self, clip, text):
|
def encode(self, clip, text):
|
||||||
|
if clip is None:
|
||||||
|
raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
|
||||||
tokens = clip.tokenize(text)
|
tokens = clip.tokenize(text)
|
||||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||||
|
|
||||||
@ -937,6 +939,8 @@ class CLIPLoader:
|
|||||||
clip_type = comfy.sd.CLIPType.LTXV
|
clip_type = comfy.sd.CLIPType.LTXV
|
||||||
elif type == "pixart":
|
elif type == "pixart":
|
||||||
clip_type = comfy.sd.CLIPType.PIXART
|
clip_type = comfy.sd.CLIPType.PIXART
|
||||||
|
elif type == "cosmos":
|
||||||
|
clip_type = comfy.sd.CLIPType.COSMOS
|
||||||
else:
|
else:
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||||
|
|
||||||
|
@ -2,39 +2,146 @@ import pytest
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from app.custom_node_manager import CustomNodeManager
|
from app.custom_node_manager import CustomNodeManager
|
||||||
|
import json
|
||||||
|
|
||||||
pytestmark = (
|
pytestmark = (
|
||||||
pytest.mark.asyncio
|
pytest.mark.asyncio
|
||||||
) # This applies the asyncio mark to all test functions in the module
|
) # This applies the asyncio mark to all test functions in the module
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def custom_node_manager():
|
def custom_node_manager():
|
||||||
return CustomNodeManager()
|
return CustomNodeManager()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def app(custom_node_manager):
|
def app(custom_node_manager):
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
routes = web.RouteTableDef()
|
routes = web.RouteTableDef()
|
||||||
custom_node_manager.add_routes(routes, app, [("ComfyUI-TestExtension1", "ComfyUI-TestExtension1")])
|
custom_node_manager.add_routes(
|
||||||
|
routes, app, [("ComfyUI-TestExtension1", "ComfyUI-TestExtension1")]
|
||||||
|
)
|
||||||
app.add_routes(routes)
|
app.add_routes(routes)
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
async def test_get_workflow_templates(aiohttp_client, app, tmp_path):
|
async def test_get_workflow_templates(aiohttp_client, app, tmp_path):
|
||||||
client = await aiohttp_client(app)
|
client = await aiohttp_client(app)
|
||||||
# Setup temporary custom nodes file structure with 1 workflow file
|
# Setup temporary custom nodes file structure with 1 workflow file
|
||||||
custom_nodes_dir = tmp_path / "custom_nodes"
|
custom_nodes_dir = tmp_path / "custom_nodes"
|
||||||
example_workflows_dir = custom_nodes_dir / "ComfyUI-TestExtension1" / "example_workflows"
|
example_workflows_dir = (
|
||||||
|
custom_nodes_dir / "ComfyUI-TestExtension1" / "example_workflows"
|
||||||
|
)
|
||||||
example_workflows_dir.mkdir(parents=True)
|
example_workflows_dir.mkdir(parents=True)
|
||||||
template_file = example_workflows_dir / "workflow1.json"
|
template_file = example_workflows_dir / "workflow1.json"
|
||||||
template_file.write_text('')
|
template_file.write_text("")
|
||||||
|
|
||||||
with patch('folder_paths.folder_names_and_paths', {
|
with patch(
|
||||||
'custom_nodes': ([str(custom_nodes_dir)], None)
|
"folder_paths.folder_names_and_paths",
|
||||||
}):
|
{"custom_nodes": ([str(custom_nodes_dir)], None)},
|
||||||
response = await client.get('/workflow_templates')
|
):
|
||||||
|
response = await client.get("/workflow_templates")
|
||||||
assert response.status == 200
|
assert response.status == 200
|
||||||
workflows_dict = await response.json()
|
workflows_dict = await response.json()
|
||||||
assert isinstance(workflows_dict, dict)
|
assert isinstance(workflows_dict, dict)
|
||||||
assert "ComfyUI-TestExtension1" in workflows_dict
|
assert "ComfyUI-TestExtension1" in workflows_dict
|
||||||
assert isinstance(workflows_dict["ComfyUI-TestExtension1"], list)
|
assert isinstance(workflows_dict["ComfyUI-TestExtension1"], list)
|
||||||
assert workflows_dict["ComfyUI-TestExtension1"][0] == "workflow1"
|
assert workflows_dict["ComfyUI-TestExtension1"][0] == "workflow1"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_build_translations_empty_when_no_locales(custom_node_manager, tmp_path):
|
||||||
|
custom_nodes_dir = tmp_path / "custom_nodes"
|
||||||
|
custom_nodes_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
with patch("folder_paths.get_folder_paths", return_value=[str(custom_nodes_dir)]):
|
||||||
|
translations = custom_node_manager.build_translations()
|
||||||
|
assert translations == {}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_build_translations_loads_all_files(custom_node_manager, tmp_path):
|
||||||
|
# Setup test directory structure
|
||||||
|
custom_nodes_dir = tmp_path / "custom_nodes" / "test-extension"
|
||||||
|
locales_dir = custom_nodes_dir / "locales" / "en"
|
||||||
|
locales_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
# Create test translation files
|
||||||
|
main_content = {"title": "Test Extension"}
|
||||||
|
(locales_dir / "main.json").write_text(json.dumps(main_content))
|
||||||
|
|
||||||
|
node_defs = {"node1": "Node 1"}
|
||||||
|
(locales_dir / "nodeDefs.json").write_text(json.dumps(node_defs))
|
||||||
|
|
||||||
|
commands = {"cmd1": "Command 1"}
|
||||||
|
(locales_dir / "commands.json").write_text(json.dumps(commands))
|
||||||
|
|
||||||
|
settings = {"setting1": "Setting 1"}
|
||||||
|
(locales_dir / "settings.json").write_text(json.dumps(settings))
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"folder_paths.get_folder_paths", return_value=[tmp_path / "custom_nodes"]
|
||||||
|
):
|
||||||
|
translations = custom_node_manager.build_translations()
|
||||||
|
|
||||||
|
assert translations == {
|
||||||
|
"en": {
|
||||||
|
"title": "Test Extension",
|
||||||
|
"nodeDefs": {"node1": "Node 1"},
|
||||||
|
"commands": {"cmd1": "Command 1"},
|
||||||
|
"settings": {"setting1": "Setting 1"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_build_translations_handles_invalid_json(custom_node_manager, tmp_path):
|
||||||
|
# Setup test directory structure
|
||||||
|
custom_nodes_dir = tmp_path / "custom_nodes" / "test-extension"
|
||||||
|
locales_dir = custom_nodes_dir / "locales" / "en"
|
||||||
|
locales_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
# Create valid main.json
|
||||||
|
main_content = {"title": "Test Extension"}
|
||||||
|
(locales_dir / "main.json").write_text(json.dumps(main_content))
|
||||||
|
|
||||||
|
# Create invalid JSON file
|
||||||
|
(locales_dir / "nodeDefs.json").write_text("invalid json{")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"folder_paths.get_folder_paths", return_value=[tmp_path / "custom_nodes"]
|
||||||
|
):
|
||||||
|
translations = custom_node_manager.build_translations()
|
||||||
|
|
||||||
|
assert translations == {
|
||||||
|
"en": {
|
||||||
|
"title": "Test Extension",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_build_translations_merges_multiple_extensions(
|
||||||
|
custom_node_manager, tmp_path
|
||||||
|
):
|
||||||
|
# Setup test directory structure for two extensions
|
||||||
|
custom_nodes_dir = tmp_path / "custom_nodes"
|
||||||
|
ext1_dir = custom_nodes_dir / "extension1" / "locales" / "en"
|
||||||
|
ext2_dir = custom_nodes_dir / "extension2" / "locales" / "en"
|
||||||
|
ext1_dir.mkdir(parents=True)
|
||||||
|
ext2_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
# Create translation files for extension 1
|
||||||
|
ext1_main = {"title": "Extension 1", "shared": "Original"}
|
||||||
|
(ext1_dir / "main.json").write_text(json.dumps(ext1_main))
|
||||||
|
|
||||||
|
# Create translation files for extension 2
|
||||||
|
ext2_main = {"description": "Extension 2", "shared": "Override"}
|
||||||
|
(ext2_dir / "main.json").write_text(json.dumps(ext2_main))
|
||||||
|
|
||||||
|
with patch("folder_paths.get_folder_paths", return_value=[str(custom_nodes_dir)]):
|
||||||
|
translations = custom_node_manager.build_translations()
|
||||||
|
|
||||||
|
assert translations == {
|
||||||
|
"en": {
|
||||||
|
"title": "Extension 1",
|
||||||
|
"description": "Extension 2",
|
||||||
|
"shared": "Override", # Second extension should override first
|
||||||
|
}
|
||||||
|
}
|
||||||
|
71
tests-unit/utils/json_util_test.py
Normal file
71
tests-unit/utils/json_util_test.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
from utils.json_util import merge_json_recursive
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_simple_dicts():
|
||||||
|
base = {"a": 1, "b": 2}
|
||||||
|
update = {"b": 3, "c": 4}
|
||||||
|
expected = {"a": 1, "b": 3, "c": 4}
|
||||||
|
assert merge_json_recursive(base, update) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_nested_dicts():
|
||||||
|
base = {"a": {"x": 1, "y": 2}, "b": 3}
|
||||||
|
update = {"a": {"y": 4, "z": 5}}
|
||||||
|
expected = {"a": {"x": 1, "y": 4, "z": 5}, "b": 3}
|
||||||
|
assert merge_json_recursive(base, update) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_lists():
|
||||||
|
base = {"a": [1, 2], "b": 3}
|
||||||
|
update = {"a": [3, 4]}
|
||||||
|
expected = {"a": [1, 2, 3, 4], "b": 3}
|
||||||
|
assert merge_json_recursive(base, update) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_nested_lists():
|
||||||
|
base = {"a": {"x": [1, 2]}}
|
||||||
|
update = {"a": {"x": [3, 4]}}
|
||||||
|
expected = {"a": {"x": [1, 2, 3, 4]}}
|
||||||
|
assert merge_json_recursive(base, update) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_mixed_types():
|
||||||
|
base = {"a": [1, 2], "b": {"x": 1}}
|
||||||
|
update = {"a": [3], "b": {"y": 2}}
|
||||||
|
expected = {"a": [1, 2, 3], "b": {"x": 1, "y": 2}}
|
||||||
|
assert merge_json_recursive(base, update) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_overwrite_non_dict():
|
||||||
|
base = {"a": 1}
|
||||||
|
update = {"a": {"x": 2}}
|
||||||
|
expected = {"a": {"x": 2}}
|
||||||
|
assert merge_json_recursive(base, update) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_empty_dicts():
|
||||||
|
base = {}
|
||||||
|
update = {"a": 1}
|
||||||
|
expected = {"a": 1}
|
||||||
|
assert merge_json_recursive(base, update) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_none_values():
|
||||||
|
base = {"a": None}
|
||||||
|
update = {"a": {"x": 1}}
|
||||||
|
expected = {"a": {"x": 1}}
|
||||||
|
assert merge_json_recursive(base, update) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_different_types():
|
||||||
|
base = {"a": [1, 2]}
|
||||||
|
update = {"a": "string"}
|
||||||
|
expected = {"a": "string"}
|
||||||
|
assert merge_json_recursive(base, update) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_complex_nested():
|
||||||
|
base = {"a": [1, 2], "b": {"x": [3, 4], "y": {"p": 1}}}
|
||||||
|
update = {"a": [5], "b": {"x": [6], "y": {"q": 2}}}
|
||||||
|
expected = {"a": [1, 2, 5], "b": {"x": [3, 4, 6], "y": {"p": 1, "q": 2}}}
|
||||||
|
assert merge_json_recursive(base, update) == expected
|
26
utils/json_util.py
Normal file
26
utils/json_util.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
def merge_json_recursive(base, update):
|
||||||
|
"""Recursively merge two JSON-like objects.
|
||||||
|
- Dictionaries are merged recursively
|
||||||
|
- Lists are concatenated
|
||||||
|
- Other types are overwritten by the update value
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base: Base JSON-like object
|
||||||
|
update: Update JSON-like object to merge into base
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Merged JSON-like object
|
||||||
|
"""
|
||||||
|
if not isinstance(base, dict) or not isinstance(update, dict):
|
||||||
|
if isinstance(base, list) and isinstance(update, list):
|
||||||
|
return base + update
|
||||||
|
return update
|
||||||
|
|
||||||
|
merged = base.copy()
|
||||||
|
for key, value in update.items():
|
||||||
|
if key in merged:
|
||||||
|
merged[key] = merge_json_recursive(merged[key], value)
|
||||||
|
else:
|
||||||
|
merged[key] = value
|
||||||
|
|
||||||
|
return merged
|
Loading…
Reference in New Issue
Block a user