ComfyUI/comfy_extras/nodes_images.py
Christian Byrne d062fcc5c0
[feat] Add ImageStitch node for concatenating images (#8369)
* [feat] Add ImageStitch node for concatenating images with borders

Add ImageStitch node that concatenates images in four directions with optional borders and intelligent size handling. Features include optional second image input, configurable borders with color selection, automatic batch size matching, and dimension alignment via padding or resizing.

Upstreamed from https://github.com/kijai/ComfyUI-KJNodes with enhancements for better error handling and comprehensive test coverage.

* [fix] Fix CI issues with CUDA dependencies and linting

- Mock CUDA-dependent modules in tests to avoid CI failures on CPU-only runners
- Fix ruff linting issues for code style compliance

* [fix] Improve CI compatibility by mocking nodes module import

Prevent CUDA initialization chain by mocking the nodes module at import time,
which is cleaner than deep mocking of CUDA-specific functions.

* [refactor] Clean up ImageStitch tests

- Remove unnecessary sys.path manipulation (pythonpath set in pytest.ini)
- Remove metadata tests that test framework internals rather than functionality
- Rename complex scenario test to be more descriptive of what it tests

* [refactor] Rename 'border' to 'spacing' for semantic accuracy

- Change border_width/border_color to spacing_width/spacing_color in API
- Update all tests to use spacing terminology
- Update comments and variable names throughout
- More accurately describes the gap/separator between images
2025-06-01 04:28:52 -04:00

504 lines
18 KiB
Python

from __future__ import annotations
import nodes
import folder_paths
from comfy.cli_args import args
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
import json
import os
import re
from io import BytesIO
from inspect import cleandoc
import torch
import comfy.utils
from comfy.comfy_types import FileLocator
# Upper bound reused as the "max" of the INT size/offset inputs declared below;
# aliased from the core `nodes` module so both stay in sync.
MAX_RESOLUTION = nodes.MAX_RESOLUTION
class ImageCrop:
    """Crop the same rectangular region out of every image in an IMAGE batch."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "crop"
    CATEGORY = "image/transform"

    @classmethod
    def INPUT_TYPES(s):
        size_opts = {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}
        offset_opts = {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}
        return {
            "required": {
                "image": ("IMAGE",),
                "width": ("INT", dict(size_opts)),
                "height": ("INT", dict(size_opts)),
                "x": ("INT", dict(offset_opts)),
                "y": ("INT", dict(offset_opts)),
            }
        }

    def crop(self, image, width, height, x, y):
        """Return ``image[:, y:y+height, x:x+width, :]`` as a 1-tuple.

        The start offsets are clamped to the last valid row/column (dim 1 and
        dim 2 respectively), so the crop always begins inside the image; a
        width/height that runs past the edge is simply truncated by slicing.
        """
        left = min(x, image.shape[2] - 1)
        top = min(y, image.shape[1] - 1)
        cropped = image[:, top:top + height, left:left + width, :]
        return (cropped,)
class RepeatImageBatch:
    """Tile an IMAGE batch ``amount`` times along the batch dimension."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "repeat"
    CATEGORY = "image/batch"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
            }
        }

    def repeat(self, image, amount):
        """Return the whole batch repeated ``amount`` times along dim 0 (b0, b1, b0, b1, ...)."""
        tiled = image.repeat(amount, 1, 1, 1)
        return (tiled,)
class ImageFromBatch:
    """Extract a contiguous run of images from an IMAGE batch."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "frombatch"
    CATEGORY = "image/batch"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
                "length": ("INT", {"default": 1, "min": 1, "max": 4096}),
            }
        }

    def frombatch(self, image, batch_index, length):
        """Return a copy of ``image[batch_index:batch_index + length]``.

        ``batch_index`` is clamped to the last image and ``length`` to the
        number of images remaining, so at least one image is always returned.
        The slice is cloned so the result does not share storage with the input.
        """
        total = image.shape[0]
        start = min(batch_index, total - 1)
        stop = start + min(length, total - start)
        return (image[start:stop].clone(),)
class ImageAddNoise:
    """Add seeded standard-normal noise to an IMAGE batch, clamped to [0, 1]."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("IMAGE",),
                             "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
                             "strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                             }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "repeat"
    CATEGORY = "image"

    def repeat(self, image, seed, strength):
        """Return ``clip(image + strength * N(0, 1), 0, 1)`` as a 1-tuple.

        Args:
            image: input tensor; the noise has the same size and is cast to
                its dtype/device via ``.to(image)``.
            seed: seed for the noise; the same seed always yields the same noise.
            strength: multiplier applied to the noise before adding it.
        """
        # A private CPU generator produces the same sequence as
        # torch.manual_seed(seed), but without reseeding the process-wide
        # default RNG that every other unseeded random op in the process shares.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        # Noise is always drawn on the CPU so results do not depend on the device.
        noise = torch.randn(image.size(), generator=generator, device="cpu").to(image)
        return (torch.clip(image + strength * noise, min=0.0, max=1.0),)
class SaveAnimatedWEBP:
    """Output node that writes an IMAGE batch as one or more animated WebP files."""

    # User-facing "method" choice -> value passed to the WebP encoder's `method`.
    methods = {"default": 4, "fastest": 0, "slowest": 6}

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE", ),
                "filename_prefix": ("STRING", {"default": "ComfyUI"}),
                "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
                "lossless": ("BOOLEAN", {"default": True}),
                "quality": ("INT", {"default": 80, "min": 0, "max": 100}),
                "method": (list(s.methods.keys()),),
                # NOTE: "num_frames" is intentionally not exposed; save_images defaults it to 0.
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "save_images"
    OUTPUT_NODE = True
    CATEGORY = "image/animation"

    def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
        """Encode ``images`` to WebP files and return the UI result dict.

        The batch is split into files of ``num_frames`` frames each; 0 means
        all frames go into a single file. Unless metadata is disabled, the
        prompt is stored in EXIF tag 0x0110 and each extra_pnginfo entry in
        tags counting down from 0x010F.
        """
        webp_method = self.methods.get(method)
        filename_prefix += self.prefix_append
        first = images[0]
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
            filename_prefix, self.output_dir, first.shape[1], first.shape[0]
        )
        results: list[FileLocator] = []

        # Float [0, 1] frames -> 8-bit PIL images.
        frames = [
            Image.fromarray(np.clip(255. * frame.cpu().numpy(), 0, 255).astype(np.uint8))
            for frame in images
        ]

        exif = frames[0].getexif()
        if not args.disable_metadata:
            if prompt is not None:
                exif[0x0110] = "prompt:{}".format(json.dumps(prompt))
            if extra_pnginfo is not None:
                for offset, key in enumerate(extra_pnginfo):
                    exif[0x010f - offset] = "{}:{}".format(key, json.dumps(extra_pnginfo[key]))

        if num_frames == 0:
            num_frames = len(frames)

        frame_duration = int(1000.0/fps)
        for start in range(0, len(frames), num_frames):
            file = f"{filename}_{counter:05}_.webp"
            frames[start].save(
                os.path.join(full_output_folder, file),
                save_all=True,
                duration=frame_duration,
                append_images=frames[start + 1:start + num_frames],
                exif=exif,
                lossless=lossless,
                quality=quality,
                method=webp_method,
            )
            results.append({"filename": file, "subfolder": subfolder, "type": self.type})
            counter += 1

        return {"ui": {"images": results, "animated": (num_frames != 1,)}}
class SaveAnimatedPNG:
    """Output node that writes an IMAGE batch as a single animated PNG file."""

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE", ),
                "filename_prefix": ("STRING", {"default": "ComfyUI"}),
                "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
                "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "save_images"
    OUTPUT_NODE = True
    CATEGORY = "image/animation"

    def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
        """Encode all of ``images`` into one PNG file and return the UI result dict.

        Unless metadata is disabled, the prompt and each extra_pnginfo entry
        are written as ``comf`` chunks ("<key>\\0<json>", latin-1) after IDAT.
        """
        filename_prefix += self.prefix_append
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
            filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
        )

        # Float [0, 1] frames -> 8-bit PIL images.
        frames = [
            Image.fromarray(np.clip(255. * frame.cpu().numpy(), 0, 255).astype(np.uint8))
            for frame in images
        ]

        png_info = None
        if not args.disable_metadata:
            png_info = PngInfo()
            entries = []
            if prompt is not None:
                entries.append(("prompt", prompt))
            if extra_pnginfo is not None:
                entries.extend((key, extra_pnginfo[key]) for key in extra_pnginfo)
            for key, value in entries:
                chunk = key.encode("latin-1", "strict") + b"\0" + json.dumps(value).encode("latin-1", "strict")
                png_info.add(b"comf", chunk, after_idat=True)

        file = f"{filename}_{counter:05}_.png"
        frames[0].save(
            os.path.join(full_output_folder, file),
            pnginfo=png_info,
            compress_level=compress_level,
            save_all=True,
            duration=int(1000.0/fps),
            append_images=frames[1:],
        )
        saved = [{"filename": file, "subfolder": subfolder, "type": self.type}]
        return {"ui": {"images": saved, "animated": (True,)}}
class SVG:
    """
    Holds one or more SVG documents, each stored in its own BytesIO buffer (``data``).
    """

    def __init__(self, data: list[BytesIO]):
        self.data = data

    def combine(self, other: 'SVG') -> 'SVG':
        """Return a new SVG with this object's buffers followed by ``other``'s."""
        merged = self.data + other.data
        return SVG(merged)

    @staticmethod
    def combine_all(svgs: list['SVG']) -> 'SVG':
        """Return one SVG holding every buffer of ``svgs``, in order."""
        return SVG([buffer for item in svgs for buffer in item.data])
class ImageStitch:
    """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""

    # Fill value per spacing color; a scalar applies to every color channel.
    _SPACING_COLORS = {
        "white": 1.0,
        "black": 0.0,
        "red": (1.0, 0.0, 0.0),
        "green": (0.0, 1.0, 0.0),
        "blue": (0.0, 0.0, 1.0),
    }

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image1": ("IMAGE",),
                "direction": (["right", "down", "left", "up"], {"default": "right"}),
                "match_image_size": ("BOOLEAN", {"default": True}),
                "spacing_width": (
                    "INT",
                    {"default": 0, "min": 0, "max": 1024, "step": 2},
                ),
                "spacing_color": (
                    ["white", "black", "red", "green", "blue"],
                    {"default": "white"},
                ),
            },
            "optional": {
                "image2": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "stitch"
    CATEGORY = "image/transform"
    DESCRIPTION = """
Stitches image2 to image1 in the specified direction.
If image2 is not provided, returns image1 unchanged.
Optional spacing can be added between images.
"""

    @staticmethod
    def _extend_batch(image, batch_size):
        """Repeat the last frame of ``image`` until it has ``batch_size`` frames."""
        missing = batch_size - image.shape[0]
        if missing <= 0:
            return image
        return torch.cat([image, image[-1:].repeat(missing, 1, 1, 1)])

    @staticmethod
    def _center_pad(image, dim, size):
        """Zero-pad ``image`` evenly on both sides of ``dim`` (1=height, 2=width) up to ``size``."""
        missing = size - image.shape[dim]
        if missing <= 0:
            return image
        before, after = missing // 2, missing - missing // 2
        # F.pad takes (before, after) pairs starting from the LAST dimension:
        # channels first, then width, then height.
        if dim == 2:
            padding = (0, 0, before, after)
        else:
            padding = (0, 0, 0, 0, before, after)
        return torch.nn.functional.pad(image, padding, mode='constant', value=0.0)

    @staticmethod
    def _pad_channels(image, channels):
        """Append channels filled with 1.0 until ``image`` has ``channels`` channels."""
        missing = channels - image.shape[-1]
        if missing <= 0:
            return image
        # Match the image dtype so fp16/bf16 inputs are not upcast by torch.cat.
        filler = torch.ones(
            *image.shape[:-1], missing, device=image.device, dtype=image.dtype
        )
        return torch.cat([image, filler], dim=-1)

    @classmethod
    def _make_spacing(cls, shape, spacing_color, like):
        """Build a spacing strip of ``shape`` filled with ``spacing_color`` on ``like``'s device/dtype."""
        spacing = torch.zeros(shape, device=like.device, dtype=like.dtype)
        color = cls._SPACING_COLORS[spacing_color]
        if not isinstance(color, tuple):
            color = (color, color, color)
        channels = spacing.shape[-1]
        # Only the first min(3, channels) channels receive the color.
        for index, value in enumerate(color[:channels]):
            spacing[..., index] = value
        if channels == 4:  # Add alpha
            spacing[..., 3] = 1.0
        return spacing

    def stitch(
        self,
        image1,
        direction,
        match_image_size,
        spacing_width,
        spacing_color,
        image2=None,
    ):
        """Concatenate ``image2`` onto ``image1`` and return the result as a 1-tuple.

        Images are indexed as (batch, height, width, channels).

        Args:
            image1: base image batch.
            direction: "right"/"left" concatenate along width, anything else
                along height; "left" and "up" place image2 first.
            match_image_size: resize image2 (keeping its aspect ratio) so its
                height (horizontal) or width (vertical) equals image1's;
                otherwise the smaller image is zero-padded, centered.
            spacing_width: width of the separator strip; rounded up to even.
            spacing_color: key of ``_SPACING_COLORS``.
            image2: optional second batch; if None, image1 is returned as-is.
        """
        if image2 is None:
            return (image1,)

        horizontal = direction in ["left", "right"]

        # Equalize batch sizes by repeating the last frame of the shorter batch.
        batch_size = max(image1.shape[0], image2.shape[0])
        image1 = self._extend_batch(image1, batch_size)
        image2 = self._extend_batch(image2, batch_size)

        if match_image_size:
            h1, w1 = image1.shape[1:3]
            h2, w2 = image2.shape[1:3]
            # Exact integer floor division: the float form int(h1 * (w2 / h2))
            # can round down (e.g. 49 * (1/49) == 0.999...), producing an
            # off-by-one size or a zero-sized resize target.
            if horizontal:
                target_h, target_w = h1, max(1, h1 * w2 // h2)
            else:  # up, down
                target_w, target_h = w1, max(1, w1 * h2 // w2)
            # Skip the resample when image2 already has the target size; a
            # no-op resize can still alter pixel values depending on the resampler.
            if (target_h, target_w) != (h2, w2):
                image2 = comfy.utils.common_upscale(
                    image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
                ).movedim(1, -1)
        else:
            # Pad the non-concatenated dimension so both images align.
            align_dim = 1 if horizontal else 2
            size = max(image1.shape[align_dim], image2.shape[align_dim])
            image1 = self._center_pad(image1, align_dim, size)
            image2 = self._center_pad(image2, align_dim, size)

        # Ensure same number of channels (extra channels are filled with 1.0).
        channels = max(image1.shape[-1], image2.shape[-1])
        image1 = self._pad_channels(image1, channels)
        image2 = self._pad_channels(image2, channels)

        parts = [image2, image1] if direction in ["left", "up"] else [image1, image2]

        if spacing_width > 0:
            spacing_width = spacing_width + (spacing_width % 2)  # Ensure even
            if horizontal:
                spacing_shape = (
                    image1.shape[0],
                    max(image1.shape[1], image2.shape[1]),
                    spacing_width,
                    image1.shape[-1],
                )
            else:
                spacing_shape = (
                    image1.shape[0],
                    spacing_width,
                    max(image1.shape[2], image2.shape[2]),
                    image1.shape[-1],
                )
            parts.insert(1, self._make_spacing(spacing_shape, spacing_color, image1))

        concat_dim = 2 if horizontal else 1
        return (torch.cat(parts, dim=concat_dim),)
class SaveSVGNode:
    """
    Save SVG files on disk.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""

    RETURN_TYPES = ()
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    FUNCTION = "save_svg"
    CATEGORY = "image/save"
    OUTPUT_NODE = True

    # Matches an opening <svg ...> tag (captured as group 1).
    _SVG_OPEN_TAG = re.compile(r'(<svg[^>]*>)')

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "svg": ("SVG",),
                "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
            },
            "hidden": {
                "prompt": "PROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO"
            }
        }

    @classmethod
    def _inject_metadata(cls, svg_content, metadata_json):
        """Return ``svg_content`` with a CDATA <metadata> element inserted after the first <svg ...> tag.

        Content without an <svg> tag is returned unchanged.
        """
        # A literal "]]>" in the payload would close the CDATA section early;
        # split it across two CDATA sections so the JSON text survives intact.
        payload = metadata_json.replace("]]>", "]]]]><![CDATA[>")
        metadata_element = (
            " <metadata>\n"
            "<![CDATA[\n"
            f"{payload}\n"
            "]]>\n"
            "</metadata>\n"
        )
        # A replacement function keeps backslashes in the payload from being
        # read as regex escapes; count=1 leaves nested <svg> elements alone.
        return cls._SVG_OPEN_TAG.sub(
            lambda match: match.group(1) + '\n' + metadata_element,
            svg_content,
            count=1,
        )

    def save_svg(self, svg: "SVG", filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
        """Write each buffer in ``svg.data`` to its own .svg file and return the UI result dict.

        Unless metadata is disabled, the prompt and extra_pnginfo are embedded
        as indented JSON inside a <metadata> element of every file.
        """
        filename_prefix += self.prefix_append
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
        results = list()

        # Honor --disable-metadata like the other save nodes in this module.
        metadata_json = None
        if not args.disable_metadata:
            metadata_dict = {}
            if prompt is not None:
                metadata_dict["prompt"] = prompt
            if extra_pnginfo is not None:
                metadata_dict.update(extra_pnginfo)
            if metadata_dict:
                metadata_json = json.dumps(metadata_dict, indent=2)

        for batch_number, svg_bytes in enumerate(svg.data):
            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
            file = f"{filename_with_batch_num}_{counter:05}_.svg"

            # Rewind first: the buffer may already have been read elsewhere.
            svg_bytes.seek(0)
            svg_content = svg_bytes.read().decode('utf-8')

            if metadata_json:
                svg_content = self._inject_metadata(svg_content, metadata_json)

            with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
                svg_file.write(svg_content.encode('utf-8'))

            results.append({
                "filename": file,
                "subfolder": subfolder,
                "type": self.type
            })
            counter += 1
        return { "ui": { "images": results } }
# Registry of the node classes defined in this module, keyed by node type name
# (each key matches its class name here).
# NOTE(review): presumably read by ComfyUI's extension loader — confirm against the loader.
NODE_CLASS_MAPPINGS = {
    "ImageCrop": ImageCrop,
    "RepeatImageBatch": RepeatImageBatch,
    "ImageFromBatch": ImageFromBatch,
    "ImageAddNoise": ImageAddNoise,
    "SaveAnimatedWEBP": SaveAnimatedWEBP,
    "SaveAnimatedPNG": SaveAnimatedPNG,
    "SaveSVGNode": SaveSVGNode,
    "ImageStitch": ImageStitch,
}