[feat] Add ImageStitch node for concatenating images (#8369)

* [feat] Add ImageStitch node for concatenating images with borders

Add ImageStitch node that concatenates images in four directions with optional borders and intelligent size handling. Features include optional second image input, configurable borders with color selection, automatic batch size matching, and dimension alignment via padding or resizing.

Upstreamed from https://github.com/kijai/ComfyUI-KJNodes with enhancements for better error handling and comprehensive test coverage.

* [fix] Fix CI issues with CUDA dependencies and linting

- Mock CUDA-dependent modules in tests to avoid CI failures on CPU-only runners
- Fix ruff linting issues for code style compliance

* [fix] Improve CI compatibility by mocking nodes module import

Prevent CUDA initialization chain by mocking the nodes module at import time,
which is cleaner than deep mocking of CUDA-specific functions.

* [refactor] Clean up ImageStitch tests

- Remove unnecessary sys.path manipulation (pythonpath set in pytest.ini)
- Remove metadata tests that test framework internals rather than functionality
- Rename complex scenario test to be more descriptive of what it tests

* [refactor] Rename 'border' to 'spacing' for semantic accuracy

- Change border_width/border_color to spacing_width/spacing_color in API
- Update all tests to use spacing terminology
- Update comments and variable names throughout
- More accurately describes the gap/separator between images
This commit is contained in:
Christian Byrne 2025-06-01 01:28:52 -07:00 committed by GitHub
parent 456abad834
commit d062fcc5c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 423 additions and 0 deletions

View File

@ -14,6 +14,7 @@ import re
from io import BytesIO
from inspect import cleandoc
import torch
import comfy.utils
from comfy.comfy_types import FileLocator
@ -229,6 +230,186 @@ class SVG:
all_svgs_list.extend(svg_item.data)
return SVG(all_svgs_list)
class ImageStitch:
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image1": ("IMAGE",),
"direction": (["right", "down", "left", "up"], {"default": "right"}),
"match_image_size": ("BOOLEAN", {"default": True}),
"spacing_width": (
"INT",
{"default": 0, "min": 0, "max": 1024, "step": 2},
),
"spacing_color": (
["white", "black", "red", "green", "blue"],
{"default": "white"},
),
},
"optional": {
"image2": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stitch"
CATEGORY = "image/transform"
DESCRIPTION = """
Stitches image2 to image1 in the specified direction.
If image2 is not provided, returns image1 unchanged.
Optional spacing can be added between images.
"""
def stitch(
self,
image1,
direction,
match_image_size,
spacing_width,
spacing_color,
image2=None,
):
if image2 is None:
return (image1,)
# Handle batch size differences
if image1.shape[0] != image2.shape[0]:
max_batch = max(image1.shape[0], image2.shape[0])
if image1.shape[0] < max_batch:
image1 = torch.cat(
[image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
)
if image2.shape[0] < max_batch:
image2 = torch.cat(
[image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
)
# Match image sizes if requested
if match_image_size:
h1, w1 = image1.shape[1:3]
h2, w2 = image2.shape[1:3]
aspect_ratio = w2 / h2
if direction in ["left", "right"]:
target_h, target_w = h1, int(h1 * aspect_ratio)
else: # up, down
target_w, target_h = w1, int(w1 / aspect_ratio)
image2 = comfy.utils.common_upscale(
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
).movedim(1, -1)
# When not matching sizes, pad to align non-concat dimensions
if not match_image_size:
h1, w1 = image1.shape[1:3]
h2, w2 = image2.shape[1:3]
if direction in ["left", "right"]:
# For horizontal concat, pad heights to match
if h1 != h2:
target_h = max(h1, h2)
if h1 < target_h:
pad_h = target_h - h1
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
if h2 < target_h:
pad_h = target_h - h2
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
else: # up, down
# For vertical concat, pad widths to match
if w1 != w2:
target_w = max(w1, w2)
if w1 < target_w:
pad_w = target_w - w1
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
if w2 < target_w:
pad_w = target_w - w2
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
# Ensure same number of channels
if image1.shape[-1] != image2.shape[-1]:
max_channels = max(image1.shape[-1], image2.shape[-1])
if image1.shape[-1] < max_channels:
image1 = torch.cat(
[
image1,
torch.ones(
*image1.shape[:-1],
max_channels - image1.shape[-1],
device=image1.device,
),
],
dim=-1,
)
if image2.shape[-1] < max_channels:
image2 = torch.cat(
[
image2,
torch.ones(
*image2.shape[:-1],
max_channels - image2.shape[-1],
device=image2.device,
),
],
dim=-1,
)
# Add spacing if specified
if spacing_width > 0:
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
color_map = {
"white": 1.0,
"black": 0.0,
"red": (1.0, 0.0, 0.0),
"green": (0.0, 1.0, 0.0),
"blue": (0.0, 0.0, 1.0),
}
color_val = color_map[spacing_color]
if direction in ["left", "right"]:
spacing_shape = (
image1.shape[0],
max(image1.shape[1], image2.shape[1]),
spacing_width,
image1.shape[-1],
)
else:
spacing_shape = (
image1.shape[0],
spacing_width,
max(image1.shape[2], image2.shape[2]),
image1.shape[-1],
)
spacing = torch.full(spacing_shape, 0.0, device=image1.device)
if isinstance(color_val, tuple):
for i, c in enumerate(color_val):
if i < spacing.shape[-1]:
spacing[..., i] = c
if spacing.shape[-1] == 4: # Add alpha
spacing[..., 3] = 1.0
else:
spacing[..., : min(3, spacing.shape[-1])] = color_val
if spacing.shape[-1] == 4:
spacing[..., 3] = 1.0
# Concatenate images
images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
if spacing_width > 0:
images.insert(1, spacing)
concat_dim = 2 if direction in ["left", "right"] else 1
return (torch.cat(images, dim=concat_dim),)
class SaveSVGNode:
"""
Save SVG files on disk.
@ -318,4 +499,5 @@ NODE_CLASS_MAPPINGS = {
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
"SaveSVGNode": SaveSVGNode,
"ImageStitch": ImageStitch,
}

View File

@ -2061,6 +2061,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImagePadForOutpaint": "Pad Image for Outpainting",
"ImageBatch": "Batch Images",
"ImageCrop": "Image Crop",
"ImageStitch": "Image Stitch",
"ImageBlend": "Image Blend",
"ImageBlur": "Image Blur",
"ImageQuantize": "Image Quantize",

View File

View File

@ -0,0 +1,240 @@
import torch
from unittest.mock import patch, MagicMock
# Mock nodes module to prevent CUDA initialization during import
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384
with patch.dict('sys.modules', {'nodes': mock_nodes}):
from comfy_extras.nodes_images import ImageStitch
class TestImageStitch:
def create_test_image(self, batch_size=1, height=64, width=64, channels=3):
"""Helper to create test images with specific dimensions"""
return torch.rand(batch_size, height, width, channels)
def test_no_image2_passthrough(self):
"""Test that when image2 is None, image1 is returned unchanged"""
node = ImageStitch()
image1 = self.create_test_image()
result = node.stitch(image1, "right", True, 0, "white", image2=None)
assert len(result) == 1
assert torch.equal(result[0], image1)
def test_basic_horizontal_stitch_right(self):
"""Test basic horizontal stitching to the right"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=24)
result = node.stitch(image1, "right", False, 0, "white", image2)
assert result[0].shape == (1, 32, 56, 3) # 32 + 24 width
def test_basic_horizontal_stitch_left(self):
"""Test basic horizontal stitching to the left"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=24)
result = node.stitch(image1, "left", False, 0, "white", image2)
assert result[0].shape == (1, 32, 56, 3) # 24 + 32 width
def test_basic_vertical_stitch_down(self):
"""Test basic vertical stitching downward"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=24, width=32)
result = node.stitch(image1, "down", False, 0, "white", image2)
assert result[0].shape == (1, 56, 32, 3) # 32 + 24 height
def test_basic_vertical_stitch_up(self):
"""Test basic vertical stitching upward"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=24, width=32)
result = node.stitch(image1, "up", False, 0, "white", image2)
assert result[0].shape == (1, 56, 32, 3) # 24 + 32 height
def test_size_matching_horizontal(self):
"""Test size matching for horizontal concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=64, width=64)
image2 = self.create_test_image(height=32, width=32) # Different aspect ratio
result = node.stitch(image1, "right", True, 0, "white", image2)
# image2 should be resized to match image1's height (64) with preserved aspect ratio
expected_width = 64 + 64 # original + resized (32*64/32 = 64)
assert result[0].shape == (1, 64, expected_width, 3)
def test_size_matching_vertical(self):
"""Test size matching for vertical concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=64, width=64)
image2 = self.create_test_image(height=32, width=32)
result = node.stitch(image1, "down", True, 0, "white", image2)
# image2 should be resized to match image1's width (64) with preserved aspect ratio
expected_height = 64 + 64 # original + resized (32*64/32 = 64)
assert result[0].shape == (1, expected_height, 64, 3)
def test_padding_for_mismatched_heights_horizontal(self):
"""Test padding when heights don't match in horizontal concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=64, width=32)
image2 = self.create_test_image(height=48, width=24) # Shorter height
result = node.stitch(image1, "right", False, 0, "white", image2)
# Both images should be padded to height 64
assert result[0].shape == (1, 64, 56, 3) # 32 + 24 width, max(64,48) height
def test_padding_for_mismatched_widths_vertical(self):
"""Test padding when widths don't match in vertical concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=64)
image2 = self.create_test_image(height=24, width=48) # Narrower width
result = node.stitch(image1, "down", False, 0, "white", image2)
# Both images should be padded to width 64
assert result[0].shape == (1, 56, 64, 3) # 32 + 24 height, max(64,48) width
def test_spacing_horizontal(self):
"""Test spacing addition in horizontal concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=24)
spacing_width = 16
result = node.stitch(image1, "right", False, spacing_width, "white", image2)
# Expected width: 32 + 16 (spacing) + 24 = 72
assert result[0].shape == (1, 32, 72, 3)
def test_spacing_vertical(self):
"""Test spacing addition in vertical concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=24, width=32)
spacing_width = 16
result = node.stitch(image1, "down", False, spacing_width, "white", image2)
# Expected height: 32 + 16 (spacing) + 24 = 72
assert result[0].shape == (1, 72, 32, 3)
def test_spacing_color_values(self):
"""Test that spacing colors are applied correctly"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
# Test white spacing
result_white = node.stitch(image1, "right", False, 16, "white", image2)
# Check that spacing region contains white values (close to 1.0)
spacing_region = result_white[0][:, :, 32:48, :] # Middle 16 pixels
assert torch.all(spacing_region >= 0.9) # Should be close to white
# Test black spacing
result_black = node.stitch(image1, "right", False, 16, "black", image2)
spacing_region = result_black[0][:, :, 32:48, :]
assert torch.all(spacing_region <= 0.1) # Should be close to black
def test_odd_spacing_width_made_even(self):
"""Test that odd spacing widths are made even"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
# Use odd spacing width
result = node.stitch(image1, "right", False, 15, "white", image2)
# Should be made even (16), so total width = 32 + 16 + 32 = 80
assert result[0].shape == (1, 32, 80, 3)
def test_batch_size_matching(self):
"""Test that different batch sizes are handled correctly"""
node = ImageStitch()
image1 = self.create_test_image(batch_size=2, height=32, width=32)
image2 = self.create_test_image(batch_size=1, height=32, width=32)
result = node.stitch(image1, "right", False, 0, "white", image2)
# Should match larger batch size
assert result[0].shape == (2, 32, 64, 3)
def test_channel_matching_rgb_to_rgba(self):
"""Test that channel differences are handled (RGB + alpha)"""
node = ImageStitch()
image1 = self.create_test_image(channels=3) # RGB
image2 = self.create_test_image(channels=4) # RGBA
result = node.stitch(image1, "right", False, 0, "white", image2)
# Should have 4 channels (RGBA)
assert result[0].shape[-1] == 4
def test_channel_matching_rgba_to_rgb(self):
"""Test that channel differences are handled (RGBA + RGB)"""
node = ImageStitch()
image1 = self.create_test_image(channels=4) # RGBA
image2 = self.create_test_image(channels=3) # RGB
result = node.stitch(image1, "right", False, 0, "white", image2)
# Should have 4 channels (RGBA)
assert result[0].shape[-1] == 4
def test_all_color_options(self):
"""Test all available color options"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
colors = ["white", "black", "red", "green", "blue"]
for color in colors:
result = node.stitch(image1, "right", False, 16, color, image2)
assert result[0].shape == (1, 32, 80, 3) # Basic shape check
def test_all_directions(self):
"""Test all direction options"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
directions = ["right", "left", "up", "down"]
for direction in directions:
result = node.stitch(image1, direction, False, 0, "white", image2)
assert result[0].shape == (1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3)
def test_batch_size_channel_spacing_integration(self):
"""Test integration of batch matching, channel matching, size matching, and spacings"""
node = ImageStitch()
image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3)
image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4)
result = node.stitch(image1, "right", True, 8, "red", image2)
# Should handle: batch matching, size matching, channel matching, spacing
assert result[0].shape[0] == 2 # Batch size matched
assert result[0].shape[-1] == 4 # Channels matched to max
assert result[0].shape[1] == 64 # Height from image1 (size matching)
# Width should be: 48 + 8 (spacing) + resized_image2_width
expected_image2_width = int(64 * (32/32)) # Resized to height 64
expected_total_width = 48 + 8 + expected_image2_width
assert result[0].shape[2] == expected_total_width