mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-05 02:52:09 +08:00
Merge branch 'master' into v3-definition
This commit is contained in:
commit
50603859ab
26
CODEOWNERS
26
CODEOWNERS
@ -5,20 +5,20 @@
|
|||||||
# Inlined the team members for now.
|
# Inlined the team members for now.
|
||||||
|
|
||||||
# Maintainers
|
# Maintainers
|
||||||
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||||
|
|
||||||
# Python web server
|
# Python web server
|
||||||
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||||
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||||
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||||
|
|
||||||
# Node developers
|
# Node developers
|
||||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||||
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||||
|
@ -205,6 +205,19 @@ comfyui-workflow-templates is not installed.
|
|||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def embedded_docs_path(cls) -> str:
|
||||||
|
"""Get the path to embedded documentation"""
|
||||||
|
try:
|
||||||
|
import comfyui_embedded_docs
|
||||||
|
|
||||||
|
return str(
|
||||||
|
importlib.resources.files(comfyui_embedded_docs) / "docs"
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
logging.info("comfyui-embedded-docs package not found")
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||||
"""
|
"""
|
||||||
|
@ -86,3 +86,45 @@ class CONDConstant(CONDRegular):
|
|||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
return [1]
|
return [1]
|
||||||
|
|
||||||
|
|
||||||
|
class CONDList(CONDRegular):
|
||||||
|
def __init__(self, cond):
|
||||||
|
self.cond = cond
|
||||||
|
|
||||||
|
def process_cond(self, batch_size, device, **kwargs):
|
||||||
|
out = []
|
||||||
|
for c in self.cond:
|
||||||
|
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
|
||||||
|
|
||||||
|
return self._copy_with(out)
|
||||||
|
|
||||||
|
def can_concat(self, other):
|
||||||
|
if len(self.cond) != len(other.cond):
|
||||||
|
return False
|
||||||
|
for i in range(len(self.cond)):
|
||||||
|
if self.cond[i].shape != other.cond[i].shape:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def concat(self, others):
|
||||||
|
out = []
|
||||||
|
for i in range(len(self.cond)):
|
||||||
|
o = [self.cond[i]]
|
||||||
|
for x in others:
|
||||||
|
o.append(x.cond[i])
|
||||||
|
out.append(torch.cat(o))
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def size(self): # hackish implementation to make the mem estimation work
|
||||||
|
o = 0
|
||||||
|
c = 1
|
||||||
|
for c in self.cond:
|
||||||
|
size = c.size()
|
||||||
|
o += math.prod(size)
|
||||||
|
if len(size) > 1:
|
||||||
|
c = size[1]
|
||||||
|
|
||||||
|
return [1, c, o // c]
|
||||||
|
@ -80,15 +80,13 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||||
|
|
||||||
# prepare image for attention
|
# prepare image for attention
|
||||||
img_modulated = self.img_norm1(img)
|
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
|
||||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
|
||||||
img_qkv = self.img_attn.qkv(img_modulated)
|
img_qkv = self.img_attn.qkv(img_modulated)
|
||||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||||
|
|
||||||
# prepare txt for attention
|
# prepare txt for attention
|
||||||
txt_modulated = self.txt_norm1(txt)
|
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
|
||||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
|
||||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||||
@ -102,12 +100,12 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||||
|
|
||||||
# calculate the img bloks
|
# calculate the img bloks
|
||||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
|
||||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
|
||||||
|
|
||||||
# calculate the txt bloks
|
# calculate the txt bloks
|
||||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
|
||||||
|
|
||||||
if txt.dtype == torch.float16:
|
if txt.dtype == torch.float16:
|
||||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
@ -152,7 +150,7 @@ class SingleStreamBlock(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
|
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
|
||||||
mod = vec
|
mod = vec
|
||||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
|
||||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||||
|
|
||||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
@ -162,7 +160,7 @@ class SingleStreamBlock(nn.Module):
|
|||||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||||
# compute activation in mlp stream, cat again and run second linear layer
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||||
x += mod.gate * output
|
x.addcmul_(mod.gate, output)
|
||||||
if x.dtype == torch.float16:
|
if x.dtype == torch.float16:
|
||||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
return x
|
return x
|
||||||
@ -178,6 +176,6 @@ class LastLayer(nn.Module):
|
|||||||
shift, scale = vec
|
shift, scale = vec
|
||||||
shift = shift.squeeze(1)
|
shift = shift.squeeze(1)
|
||||||
scale = scale.squeeze(1)
|
scale = scale.squeeze(1)
|
||||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
|
||||||
x = self.linear(x)
|
x = self.linear(x)
|
||||||
return x
|
return x
|
||||||
|
@ -168,6 +168,11 @@ class BaseModel(torch.nn.Module):
|
|||||||
if hasattr(extra, "dtype"):
|
if hasattr(extra, "dtype"):
|
||||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||||
extra = extra.to(dtype)
|
extra = extra.to(dtype)
|
||||||
|
if isinstance(extra, list):
|
||||||
|
ex = []
|
||||||
|
for ext in extra:
|
||||||
|
ex.append(ext.to(dtype))
|
||||||
|
extra = ex
|
||||||
extra_conds[o] = extra
|
extra_conds[o] = extra
|
||||||
|
|
||||||
t = self.process_timestep(t, x=x, **extra_conds)
|
t = self.process_timestep(t, x=x, **extra_conds)
|
||||||
|
@ -297,8 +297,13 @@ except:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if is_amd():
|
if is_amd():
|
||||||
|
try:
|
||||||
|
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||||
|
except:
|
||||||
|
rocm_version = (6, -1)
|
||||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||||
logging.info("AMD arch: {}".format(arch))
|
logging.info("AMD arch: {}".format(arch))
|
||||||
|
logging.info("ROCm version: {}".format(rocm_version))
|
||||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
|
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
|
||||||
if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
|
if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
|
||||||
|
@ -327,7 +327,9 @@ class ApiClient:
|
|||||||
ApiServerError: If the API server is unreachable but internet is working
|
ApiServerError: If the API server is unreachable but internet is working
|
||||||
Exception: For other request failures
|
Exception: For other request failures
|
||||||
"""
|
"""
|
||||||
url = urljoin(self.base_url, path)
|
# Use urljoin but ensure path is relative to avoid absolute path behavior
|
||||||
|
relative_path = path.lstrip('/')
|
||||||
|
url = urljoin(self.base_url, relative_path)
|
||||||
self.check_auth(self.auth_token, self.comfy_api_key)
|
self.check_auth(self.auth_token, self.comfy_api_key)
|
||||||
# Combine default headers with any provided headers
|
# Combine default headers with any provided headers
|
||||||
request_headers = self.get_headers()
|
request_headers = self.get_headers()
|
||||||
|
@ -14,6 +14,7 @@ import re
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from inspect import cleandoc
|
from inspect import cleandoc
|
||||||
import torch
|
import torch
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
from comfy.comfy_types import FileLocator
|
from comfy.comfy_types import FileLocator
|
||||||
|
|
||||||
@ -229,6 +230,186 @@ class SVG:
|
|||||||
all_svgs_list.extend(svg_item.data)
|
all_svgs_list.extend(svg_item.data)
|
||||||
return SVG(all_svgs_list)
|
return SVG(all_svgs_list)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageStitch:
|
||||||
|
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image1": ("IMAGE",),
|
||||||
|
"direction": (["right", "down", "left", "up"], {"default": "right"}),
|
||||||
|
"match_image_size": ("BOOLEAN", {"default": True}),
|
||||||
|
"spacing_width": (
|
||||||
|
"INT",
|
||||||
|
{"default": 0, "min": 0, "max": 1024, "step": 2},
|
||||||
|
),
|
||||||
|
"spacing_color": (
|
||||||
|
["white", "black", "red", "green", "blue"],
|
||||||
|
{"default": "white"},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"image2": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "stitch"
|
||||||
|
CATEGORY = "image/transform"
|
||||||
|
DESCRIPTION = """
|
||||||
|
Stitches image2 to image1 in the specified direction.
|
||||||
|
If image2 is not provided, returns image1 unchanged.
|
||||||
|
Optional spacing can be added between images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def stitch(
|
||||||
|
self,
|
||||||
|
image1,
|
||||||
|
direction,
|
||||||
|
match_image_size,
|
||||||
|
spacing_width,
|
||||||
|
spacing_color,
|
||||||
|
image2=None,
|
||||||
|
):
|
||||||
|
if image2 is None:
|
||||||
|
return (image1,)
|
||||||
|
|
||||||
|
# Handle batch size differences
|
||||||
|
if image1.shape[0] != image2.shape[0]:
|
||||||
|
max_batch = max(image1.shape[0], image2.shape[0])
|
||||||
|
if image1.shape[0] < max_batch:
|
||||||
|
image1 = torch.cat(
|
||||||
|
[image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
|
||||||
|
)
|
||||||
|
if image2.shape[0] < max_batch:
|
||||||
|
image2 = torch.cat(
|
||||||
|
[image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Match image sizes if requested
|
||||||
|
if match_image_size:
|
||||||
|
h1, w1 = image1.shape[1:3]
|
||||||
|
h2, w2 = image2.shape[1:3]
|
||||||
|
aspect_ratio = w2 / h2
|
||||||
|
|
||||||
|
if direction in ["left", "right"]:
|
||||||
|
target_h, target_w = h1, int(h1 * aspect_ratio)
|
||||||
|
else: # up, down
|
||||||
|
target_w, target_h = w1, int(w1 / aspect_ratio)
|
||||||
|
|
||||||
|
image2 = comfy.utils.common_upscale(
|
||||||
|
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
|
||||||
|
).movedim(1, -1)
|
||||||
|
|
||||||
|
# When not matching sizes, pad to align non-concat dimensions
|
||||||
|
if not match_image_size:
|
||||||
|
h1, w1 = image1.shape[1:3]
|
||||||
|
h2, w2 = image2.shape[1:3]
|
||||||
|
|
||||||
|
if direction in ["left", "right"]:
|
||||||
|
# For horizontal concat, pad heights to match
|
||||||
|
if h1 != h2:
|
||||||
|
target_h = max(h1, h2)
|
||||||
|
if h1 < target_h:
|
||||||
|
pad_h = target_h - h1
|
||||||
|
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
||||||
|
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
|
||||||
|
if h2 < target_h:
|
||||||
|
pad_h = target_h - h2
|
||||||
|
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
||||||
|
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
|
||||||
|
else: # up, down
|
||||||
|
# For vertical concat, pad widths to match
|
||||||
|
if w1 != w2:
|
||||||
|
target_w = max(w1, w2)
|
||||||
|
if w1 < target_w:
|
||||||
|
pad_w = target_w - w1
|
||||||
|
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
||||||
|
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
|
||||||
|
if w2 < target_w:
|
||||||
|
pad_w = target_w - w2
|
||||||
|
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
||||||
|
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
|
||||||
|
|
||||||
|
# Ensure same number of channels
|
||||||
|
if image1.shape[-1] != image2.shape[-1]:
|
||||||
|
max_channels = max(image1.shape[-1], image2.shape[-1])
|
||||||
|
if image1.shape[-1] < max_channels:
|
||||||
|
image1 = torch.cat(
|
||||||
|
[
|
||||||
|
image1,
|
||||||
|
torch.ones(
|
||||||
|
*image1.shape[:-1],
|
||||||
|
max_channels - image1.shape[-1],
|
||||||
|
device=image1.device,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
if image2.shape[-1] < max_channels:
|
||||||
|
image2 = torch.cat(
|
||||||
|
[
|
||||||
|
image2,
|
||||||
|
torch.ones(
|
||||||
|
*image2.shape[:-1],
|
||||||
|
max_channels - image2.shape[-1],
|
||||||
|
device=image2.device,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add spacing if specified
|
||||||
|
if spacing_width > 0:
|
||||||
|
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
|
||||||
|
|
||||||
|
color_map = {
|
||||||
|
"white": 1.0,
|
||||||
|
"black": 0.0,
|
||||||
|
"red": (1.0, 0.0, 0.0),
|
||||||
|
"green": (0.0, 1.0, 0.0),
|
||||||
|
"blue": (0.0, 0.0, 1.0),
|
||||||
|
}
|
||||||
|
color_val = color_map[spacing_color]
|
||||||
|
|
||||||
|
if direction in ["left", "right"]:
|
||||||
|
spacing_shape = (
|
||||||
|
image1.shape[0],
|
||||||
|
max(image1.shape[1], image2.shape[1]),
|
||||||
|
spacing_width,
|
||||||
|
image1.shape[-1],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
spacing_shape = (
|
||||||
|
image1.shape[0],
|
||||||
|
spacing_width,
|
||||||
|
max(image1.shape[2], image2.shape[2]),
|
||||||
|
image1.shape[-1],
|
||||||
|
)
|
||||||
|
|
||||||
|
spacing = torch.full(spacing_shape, 0.0, device=image1.device)
|
||||||
|
if isinstance(color_val, tuple):
|
||||||
|
for i, c in enumerate(color_val):
|
||||||
|
if i < spacing.shape[-1]:
|
||||||
|
spacing[..., i] = c
|
||||||
|
if spacing.shape[-1] == 4: # Add alpha
|
||||||
|
spacing[..., 3] = 1.0
|
||||||
|
else:
|
||||||
|
spacing[..., : min(3, spacing.shape[-1])] = color_val
|
||||||
|
if spacing.shape[-1] == 4:
|
||||||
|
spacing[..., 3] = 1.0
|
||||||
|
|
||||||
|
# Concatenate images
|
||||||
|
images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
|
||||||
|
if spacing_width > 0:
|
||||||
|
images.insert(1, spacing)
|
||||||
|
|
||||||
|
concat_dim = 2 if direction in ["left", "right"] else 1
|
||||||
|
return (torch.cat(images, dim=concat_dim),)
|
||||||
|
|
||||||
|
|
||||||
class SaveSVGNode:
|
class SaveSVGNode:
|
||||||
"""
|
"""
|
||||||
Save SVG files on disk.
|
Save SVG files on disk.
|
||||||
@ -318,4 +499,5 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
||||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||||
"SaveSVGNode": SaveSVGNode,
|
"SaveSVGNode": SaveSVGNode,
|
||||||
|
"ImageStitch": ImageStitch,
|
||||||
}
|
}
|
||||||
|
@ -296,6 +296,41 @@ class RegexExtract():
|
|||||||
|
|
||||||
return result,
|
return result,
|
||||||
|
|
||||||
|
|
||||||
|
class RegexReplace():
|
||||||
|
DESCRIPTION = "Find and replace text using regex patterns."
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"string": (IO.STRING, {"multiline": True}),
|
||||||
|
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||||
|
"replace": (IO.STRING, {"multiline": True}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||||
|
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||||
|
"dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
|
||||||
|
"count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.STRING,)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "utils/string"
|
||||||
|
|
||||||
|
def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
|
||||||
|
flags = 0
|
||||||
|
|
||||||
|
if case_insensitive:
|
||||||
|
flags |= re.IGNORECASE
|
||||||
|
if multiline:
|
||||||
|
flags |= re.MULTILINE
|
||||||
|
if dotall:
|
||||||
|
flags |= re.DOTALL
|
||||||
|
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
|
||||||
|
return result,
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"StringConcatenate": StringConcatenate,
|
"StringConcatenate": StringConcatenate,
|
||||||
"StringSubstring": StringSubstring,
|
"StringSubstring": StringSubstring,
|
||||||
@ -306,7 +341,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"StringContains": StringContains,
|
"StringContains": StringContains,
|
||||||
"StringCompare": StringCompare,
|
"StringCompare": StringCompare,
|
||||||
"RegexMatch": RegexMatch,
|
"RegexMatch": RegexMatch,
|
||||||
"RegexExtract": RegexExtract
|
"RegexExtract": RegexExtract,
|
||||||
|
"RegexReplace": RegexReplace,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@ -319,5 +355,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"StringContains": "Contains",
|
"StringContains": "Contains",
|
||||||
"StringCompare": "Compare",
|
"StringCompare": "Compare",
|
||||||
"RegexMatch": "Regex Match",
|
"RegexMatch": "Regex Match",
|
||||||
"RegexExtract": "Regex Extract"
|
"RegexExtract": "Regex Extract",
|
||||||
|
"RegexReplace": "Regex Replace",
|
||||||
}
|
}
|
||||||
|
1
nodes.py
1
nodes.py
@ -2062,6 +2062,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
||||||
"ImageBatch": "Batch Images",
|
"ImageBatch": "Batch Images",
|
||||||
"ImageCrop": "Image Crop",
|
"ImageCrop": "Image Crop",
|
||||||
|
"ImageStitch": "Image Stitch",
|
||||||
"ImageBlend": "Image Blend",
|
"ImageBlend": "Image Blend",
|
||||||
"ImageBlur": "Image Blur",
|
"ImageBlur": "Image Blur",
|
||||||
"ImageQuantize": "Image Quantize",
|
"ImageQuantize": "Image Quantize",
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
comfyui-frontend-package==1.20.7
|
comfyui-frontend-package==1.21.3
|
||||||
comfyui-workflow-templates==0.1.22
|
comfyui-workflow-templates==0.1.23
|
||||||
|
comfyui-embedded-docs==0.2.0
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
torchvision
|
torchvision
|
||||||
|
@ -749,6 +749,13 @@ class PromptServer():
|
|||||||
web.static('/templates', workflow_templates_path)
|
web.static('/templates', workflow_templates_path)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
# Serve embedded documentation from the package
|
||||||
|
embedded_docs_path = FrontendManager.embedded_docs_path()
|
||||||
|
if embedded_docs_path:
|
||||||
|
self.app.add_routes([
|
||||||
|
web.static('/docs', embedded_docs_path)
|
||||||
|
])
|
||||||
|
|
||||||
self.app.add_routes([
|
self.app.add_routes([
|
||||||
web.static('/', self.web_root),
|
web.static('/', self.web_root),
|
||||||
])
|
])
|
||||||
|
0
tests-unit/comfy_extras_test/__init__.py
Normal file
0
tests-unit/comfy_extras_test/__init__.py
Normal file
240
tests-unit/comfy_extras_test/image_stitch_test.py
Normal file
240
tests-unit/comfy_extras_test/image_stitch_test.py
Normal file
@ -0,0 +1,240 @@
|
|||||||
|
import torch
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
# Mock nodes module to prevent CUDA initialization during import
|
||||||
|
mock_nodes = MagicMock()
|
||||||
|
mock_nodes.MAX_RESOLUTION = 16384
|
||||||
|
|
||||||
|
with patch.dict('sys.modules', {'nodes': mock_nodes}):
|
||||||
|
from comfy_extras.nodes_images import ImageStitch
|
||||||
|
|
||||||
|
|
||||||
|
class TestImageStitch:
|
||||||
|
|
||||||
|
def create_test_image(self, batch_size=1, height=64, width=64, channels=3):
|
||||||
|
"""Helper to create test images with specific dimensions"""
|
||||||
|
return torch.rand(batch_size, height, width, channels)
|
||||||
|
|
||||||
|
def test_no_image2_passthrough(self):
|
||||||
|
"""Test that when image2 is None, image1 is returned unchanged"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image()
|
||||||
|
|
||||||
|
result = node.stitch(image1, "right", True, 0, "white", image2=None)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert torch.equal(result[0], image1)
|
||||||
|
|
||||||
|
def test_basic_horizontal_stitch_right(self):
|
||||||
|
"""Test basic horizontal stitching to the right"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=32, width=32)
|
||||||
|
image2 = self.create_test_image(height=32, width=24)
|
||||||
|
|
||||||
|
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||||
|
|
||||||
|
assert result[0].shape == (1, 32, 56, 3) # 32 + 24 width
|
||||||
|
|
||||||
|
def test_basic_horizontal_stitch_left(self):
|
||||||
|
"""Test basic horizontal stitching to the left"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=32, width=32)
|
||||||
|
image2 = self.create_test_image(height=32, width=24)
|
||||||
|
|
||||||
|
result = node.stitch(image1, "left", False, 0, "white", image2)
|
||||||
|
|
||||||
|
assert result[0].shape == (1, 32, 56, 3) # 24 + 32 width
|
||||||
|
|
||||||
|
def test_basic_vertical_stitch_down(self):
|
||||||
|
"""Test basic vertical stitching downward"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=32, width=32)
|
||||||
|
image2 = self.create_test_image(height=24, width=32)
|
||||||
|
|
||||||
|
result = node.stitch(image1, "down", False, 0, "white", image2)
|
||||||
|
|
||||||
|
assert result[0].shape == (1, 56, 32, 3) # 32 + 24 height
|
||||||
|
|
||||||
|
def test_basic_vertical_stitch_up(self):
|
||||||
|
"""Test basic vertical stitching upward"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=32, width=32)
|
||||||
|
image2 = self.create_test_image(height=24, width=32)
|
||||||
|
|
||||||
|
result = node.stitch(image1, "up", False, 0, "white", image2)
|
||||||
|
|
||||||
|
assert result[0].shape == (1, 56, 32, 3) # 24 + 32 height
|
||||||
|
|
||||||
|
def test_size_matching_horizontal(self):
|
||||||
|
"""Test size matching for horizontal concatenation"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=64, width=64)
|
||||||
|
image2 = self.create_test_image(height=32, width=32) # Different aspect ratio
|
||||||
|
|
||||||
|
result = node.stitch(image1, "right", True, 0, "white", image2)
|
||||||
|
|
||||||
|
# image2 should be resized to match image1's height (64) with preserved aspect ratio
|
||||||
|
expected_width = 64 + 64 # original + resized (32*64/32 = 64)
|
||||||
|
assert result[0].shape == (1, 64, expected_width, 3)
|
||||||
|
|
||||||
|
def test_size_matching_vertical(self):
|
||||||
|
"""Test size matching for vertical concatenation"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=64, width=64)
|
||||||
|
image2 = self.create_test_image(height=32, width=32)
|
||||||
|
|
||||||
|
result = node.stitch(image1, "down", True, 0, "white", image2)
|
||||||
|
|
||||||
|
# image2 should be resized to match image1's width (64) with preserved aspect ratio
|
||||||
|
expected_height = 64 + 64 # original + resized (32*64/32 = 64)
|
||||||
|
assert result[0].shape == (1, expected_height, 64, 3)
|
||||||
|
|
||||||
|
def test_padding_for_mismatched_heights_horizontal(self):
|
||||||
|
"""Test padding when heights don't match in horizontal concatenation"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=64, width=32)
|
||||||
|
image2 = self.create_test_image(height=48, width=24) # Shorter height
|
||||||
|
|
||||||
|
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||||
|
|
||||||
|
# Both images should be padded to height 64
|
||||||
|
assert result[0].shape == (1, 64, 56, 3) # 32 + 24 width, max(64,48) height
|
||||||
|
|
||||||
|
def test_padding_for_mismatched_widths_vertical(self):
|
||||||
|
"""Test padding when widths don't match in vertical concatenation"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=32, width=64)
|
||||||
|
image2 = self.create_test_image(height=24, width=48) # Narrower width
|
||||||
|
|
||||||
|
result = node.stitch(image1, "down", False, 0, "white", image2)
|
||||||
|
|
||||||
|
# Both images should be padded to width 64
|
||||||
|
assert result[0].shape == (1, 56, 64, 3) # 32 + 24 height, max(64,48) width
|
||||||
|
|
||||||
|
def test_spacing_horizontal(self):
|
||||||
|
"""Test spacing addition in horizontal concatenation"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=32, width=32)
|
||||||
|
image2 = self.create_test_image(height=32, width=24)
|
||||||
|
spacing_width = 16
|
||||||
|
|
||||||
|
result = node.stitch(image1, "right", False, spacing_width, "white", image2)
|
||||||
|
|
||||||
|
# Expected width: 32 + 16 (spacing) + 24 = 72
|
||||||
|
assert result[0].shape == (1, 32, 72, 3)
|
||||||
|
|
||||||
|
def test_spacing_vertical(self):
|
||||||
|
"""Test spacing addition in vertical concatenation"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=32, width=32)
|
||||||
|
image2 = self.create_test_image(height=24, width=32)
|
||||||
|
spacing_width = 16
|
||||||
|
|
||||||
|
result = node.stitch(image1, "down", False, spacing_width, "white", image2)
|
||||||
|
|
||||||
|
# Expected height: 32 + 16 (spacing) + 24 = 72
|
||||||
|
assert result[0].shape == (1, 72, 32, 3)
|
||||||
|
|
||||||
|
def test_spacing_color_values(self):
|
||||||
|
"""Test that spacing colors are applied correctly"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=32, width=32)
|
||||||
|
image2 = self.create_test_image(height=32, width=32)
|
||||||
|
|
||||||
|
# Test white spacing
|
||||||
|
result_white = node.stitch(image1, "right", False, 16, "white", image2)
|
||||||
|
# Check that spacing region contains white values (close to 1.0)
|
||||||
|
spacing_region = result_white[0][:, :, 32:48, :] # Middle 16 pixels
|
||||||
|
assert torch.all(spacing_region >= 0.9) # Should be close to white
|
||||||
|
|
||||||
|
# Test black spacing
|
||||||
|
result_black = node.stitch(image1, "right", False, 16, "black", image2)
|
||||||
|
spacing_region = result_black[0][:, :, 32:48, :]
|
||||||
|
assert torch.all(spacing_region <= 0.1) # Should be close to black
|
||||||
|
|
||||||
|
def test_odd_spacing_width_made_even(self):
|
||||||
|
"""Test that odd spacing widths are made even"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=32, width=32)
|
||||||
|
image2 = self.create_test_image(height=32, width=32)
|
||||||
|
|
||||||
|
# Use odd spacing width
|
||||||
|
result = node.stitch(image1, "right", False, 15, "white", image2)
|
||||||
|
|
||||||
|
# Should be made even (16), so total width = 32 + 16 + 32 = 80
|
||||||
|
assert result[0].shape == (1, 32, 80, 3)
|
||||||
|
|
||||||
|
def test_batch_size_matching(self):
|
||||||
|
"""Test that different batch sizes are handled correctly"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(batch_size=2, height=32, width=32)
|
||||||
|
image2 = self.create_test_image(batch_size=1, height=32, width=32)
|
||||||
|
|
||||||
|
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||||
|
|
||||||
|
# Should match larger batch size
|
||||||
|
assert result[0].shape == (2, 32, 64, 3)
|
||||||
|
|
||||||
|
def test_channel_matching_rgb_to_rgba(self):
|
||||||
|
"""Test that channel differences are handled (RGB + alpha)"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(channels=3) # RGB
|
||||||
|
image2 = self.create_test_image(channels=4) # RGBA
|
||||||
|
|
||||||
|
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||||
|
|
||||||
|
# Should have 4 channels (RGBA)
|
||||||
|
assert result[0].shape[-1] == 4
|
||||||
|
|
||||||
|
def test_channel_matching_rgba_to_rgb(self):
|
||||||
|
"""Test that channel differences are handled (RGBA + RGB)"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(channels=4) # RGBA
|
||||||
|
image2 = self.create_test_image(channels=3) # RGB
|
||||||
|
|
||||||
|
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||||
|
|
||||||
|
# Should have 4 channels (RGBA)
|
||||||
|
assert result[0].shape[-1] == 4
|
||||||
|
|
||||||
|
def test_all_color_options(self):
|
||||||
|
"""Test all available color options"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=32, width=32)
|
||||||
|
image2 = self.create_test_image(height=32, width=32)
|
||||||
|
|
||||||
|
colors = ["white", "black", "red", "green", "blue"]
|
||||||
|
|
||||||
|
for color in colors:
|
||||||
|
result = node.stitch(image1, "right", False, 16, color, image2)
|
||||||
|
assert result[0].shape == (1, 32, 80, 3) # Basic shape check
|
||||||
|
|
||||||
|
def test_all_directions(self):
|
||||||
|
"""Test all direction options"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(height=32, width=32)
|
||||||
|
image2 = self.create_test_image(height=32, width=32)
|
||||||
|
|
||||||
|
directions = ["right", "left", "up", "down"]
|
||||||
|
|
||||||
|
for direction in directions:
|
||||||
|
result = node.stitch(image1, direction, False, 0, "white", image2)
|
||||||
|
assert result[0].shape == (1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3)
|
||||||
|
|
||||||
|
def test_batch_size_channel_spacing_integration(self):
|
||||||
|
"""Test integration of batch matching, channel matching, size matching, and spacings"""
|
||||||
|
node = ImageStitch()
|
||||||
|
image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3)
|
||||||
|
image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4)
|
||||||
|
|
||||||
|
result = node.stitch(image1, "right", True, 8, "red", image2)
|
||||||
|
|
||||||
|
# Should handle: batch matching, size matching, channel matching, spacing
|
||||||
|
assert result[0].shape[0] == 2 # Batch size matched
|
||||||
|
assert result[0].shape[-1] == 4 # Channels matched to max
|
||||||
|
assert result[0].shape[1] == 64 # Height from image1 (size matching)
|
||||||
|
# Width should be: 48 + 8 (spacing) + resized_image2_width
|
||||||
|
expected_image2_width = int(64 * (32/32)) # Resized to height 64
|
||||||
|
expected_total_width = 48 + 8 + expected_image2_width
|
||||||
|
assert result[0].shape[2] == expected_total_width
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user