mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-03 10:02:09 +08:00
Compare commits
20 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
856448060c | ||
![]() |
312d511630 | ||
![]() |
4f4f1c642a | ||
![]() |
010954d277 | ||
![]() |
6d46bb4b4c | ||
![]() |
67f57c5bcc | ||
![]() |
fd943c928f | ||
![]() |
d3bd983b91 | ||
![]() |
fb4754624d | ||
![]() |
180db6753f | ||
![]() |
d062fcc5c0 | ||
![]() |
456abad834 | ||
![]() |
19e45e9b0e | ||
![]() |
97f23b81f3 | ||
![]() |
08b7cc7506 | ||
![]() |
6c319cbb4e | ||
![]() |
df1aebe52e | ||
![]() |
704fc78854 | ||
![]() |
1d9fee79fd | ||
![]() |
aeba0b3a26 |
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -15,6 +15,14 @@ body:
|
||||
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
||||
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
- type: checkboxes
|
||||
id: custom-nodes-test
|
||||
attributes:
|
||||
label: Custom Node Testing
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
|
8
.github/ISSUE_TEMPLATE/user-support.yml
vendored
8
.github/ISSUE_TEMPLATE/user-support.yml
vendored
@ -11,6 +11,14 @@ body:
|
||||
**2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.
|
||||
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
- type: checkboxes
|
||||
id: custom-nodes-test
|
||||
attributes:
|
||||
label: Custom Node Testing
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Your question
|
||||
|
26
CODEOWNERS
26
CODEOWNERS
@ -5,20 +5,20 @@
|
||||
# Inlined the team members for now.
|
||||
|
||||
# Maintainers
|
||||
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
|
||||
# Python web server
|
||||
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
|
||||
# Node developers
|
||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
|
@ -205,6 +205,19 @@ comfyui-workflow-templates is not installed.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def embedded_docs_path(cls) -> str:
|
||||
"""Get the path to embedded documentation"""
|
||||
try:
|
||||
import comfyui_embedded_docs
|
||||
|
||||
return str(
|
||||
importlib.resources.files(comfyui_embedded_docs) / "docs"
|
||||
)
|
||||
except ImportError:
|
||||
logging.info("comfyui-embedded-docs package not found")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||
"""
|
||||
|
@ -86,3 +86,45 @@ class CONDConstant(CONDRegular):
|
||||
|
||||
def size(self):
|
||||
return [1]
|
||||
|
||||
|
||||
class CONDList(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
out = []
|
||||
for c in self.cond:
|
||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
|
||||
|
||||
return self._copy_with(out)
|
||||
|
||||
def can_concat(self, other):
|
||||
if len(self.cond) != len(other.cond):
|
||||
return False
|
||||
for i in range(len(self.cond)):
|
||||
if self.cond[i].shape != other.cond[i].shape:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
out = []
|
||||
for i in range(len(self.cond)):
|
||||
o = [self.cond[i]]
|
||||
for x in others:
|
||||
o.append(x.cond[i])
|
||||
out.append(torch.cat(o))
|
||||
|
||||
return out
|
||||
|
||||
def size(self): # hackish implementation to make the mem estimation work
|
||||
o = 0
|
||||
c = 1
|
||||
for c in self.cond:
|
||||
size = c.size()
|
||||
o += math.prod(size)
|
||||
if len(size) > 1:
|
||||
c = size[1]
|
||||
|
||||
return [1, c, o // c]
|
||||
|
@ -80,15 +80,13 @@ class DoubleStreamBlock(nn.Module):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
@ -102,12 +100,12 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
|
||||
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@ -152,7 +150,7 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
|
||||
mod = vec
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
@ -162,7 +160,7 @@ class SingleStreamBlock(nn.Module):
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += mod.gate * output
|
||||
x.addcmul_(mod.gate, output)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
@ -178,6 +176,6 @@ class LastLayer(nn.Module):
|
||||
shift, scale = vec
|
||||
shift = shift.squeeze(1)
|
||||
scale = scale.squeeze(1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
@ -102,6 +102,13 @@ def model_sampling(model_config, model_type):
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
def convert_tensor(extra, dtype):
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype)
|
||||
return extra
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
|
||||
super().__init__()
|
||||
@ -165,9 +172,14 @@ class BaseModel(torch.nn.Module):
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra = kwargs[o]
|
||||
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype)
|
||||
extra = convert_tensor(extra, dtype)
|
||||
elif isinstance(extra, list):
|
||||
ex = []
|
||||
for ext in extra:
|
||||
ex.append(convert_tensor(ext, dtype))
|
||||
extra = ex
|
||||
extra_conds[o] = extra
|
||||
|
||||
t = self.process_timestep(t, x=x, **extra_conds)
|
||||
|
@ -297,8 +297,13 @@ except:
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
try:
|
||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||
except:
|
||||
rocm_version = (6, -1)
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
|
||||
|
@ -125,22 +125,6 @@ class BFLFluxKontextProGenerateRequest(BaseModel):
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
|
||||
class BFLFluxKontextMaxGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
|
||||
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
|
||||
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
|
||||
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
|
||||
class BFLFluxProUltraGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
|
@ -327,7 +327,9 @@ class ApiClient:
|
||||
ApiServerError: If the API server is unreachable but internet is working
|
||||
Exception: For other request failures
|
||||
"""
|
||||
url = urljoin(self.base_url, path)
|
||||
# Use urljoin but ensure path is relative to avoid absolute path behavior
|
||||
relative_path = path.lstrip('/')
|
||||
url = urljoin(self.base_url, relative_path)
|
||||
self.check_auth(self.auth_token, self.comfy_api_key)
|
||||
# Combine default headers with any provided headers
|
||||
request_headers = self.get_headers()
|
||||
|
@ -272,7 +272,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
||||
|
||||
class FluxKontextProImageNode(ComfyNodeABC):
|
||||
"""
|
||||
Edits images using Flux.1 Kontext Pro via api based on prompt and resolution.
|
||||
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
MINIMUM_RATIO = 1 / 4
|
||||
@ -321,7 +321,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"default": 1234,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
@ -366,6 +366,8 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
@ -382,7 +384,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-kontext-pro/generate",
|
||||
path=self.BFL_PATH,
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxKontextProGenerateRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
@ -411,146 +413,15 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
class FluxKontextMaxImageNode(ComfyNodeABC):
|
||||
|
||||
class FluxKontextMaxImageNode(FluxKontextProImageNode):
|
||||
"""
|
||||
Edits images using Flux.1 Kontext Max via api based on prompt and resolution.
|
||||
Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
MINIMUM_RATIO = 1 / 4
|
||||
MAXIMUM_RATIO = 4 / 1
|
||||
MINIMUM_RATIO_STR = "1:4"
|
||||
MAXIMUM_RATIO_STR = "4:1"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation - specify what and how to edit.",
|
||||
},
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "16:9",
|
||||
"tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
|
||||
},
|
||||
),
|
||||
"guidance": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 3.0,
|
||||
"min": 0.1,
|
||||
"max": 99.0,
|
||||
"step": 0.1,
|
||||
"tooltip": "Guidance strength for the image generation process"
|
||||
},
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 50,
|
||||
"min": 1,
|
||||
"max": 150,
|
||||
"tooltip": "Number of steps for the image generation process"
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"input_image": (IO.IMAGE,),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, aspect_ratio: str):
|
||||
try:
|
||||
validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=cls.MINIMUM_RATIO,
|
||||
maximum_ratio=cls.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
|
||||
)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
return True
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
guidance: float,
|
||||
steps: int,
|
||||
input_image: Optional[torch.Tensor]=None,
|
||||
seed=0,
|
||||
prompt_upsampling=False,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if input_image is None:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-kontext-max/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxKontextProGenerateRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxKontextProGenerateRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
guidance=round(guidance, 1),
|
||||
steps=steps,
|
||||
seed=seed,
|
||||
aspect_ratio=validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=self.MINIMUM_RATIO,
|
||||
maximum_ratio=self.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=self.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
|
||||
),
|
||||
input_image=(
|
||||
input_image
|
||||
if input_image is None
|
||||
else convert_image_to_base64(input_image)
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
class FluxProImageNode(ComfyNodeABC):
|
||||
"""
|
||||
@ -1208,8 +1079,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
|
||||
# "FluxProImageNode": "Flux 1.1 [pro] Image",
|
||||
"FluxKontextProImageNode": "Flux.1 Kontext Pro Image",
|
||||
"FluxKontextMaxImageNode": "Flux.1 Kontext Max Image",
|
||||
"FluxKontextProImageNode": "Flux.1 Kontext [pro] Image",
|
||||
"FluxKontextMaxImageNode": "Flux.1 Kontext [max] Image",
|
||||
"FluxProExpandNode": "Flux.1 Expand Image",
|
||||
"FluxProFillNode": "Flux.1 Fill Image",
|
||||
"FluxProCannyNode": "Flux.1 Canny Control Image",
|
||||
|
@ -14,8 +14,10 @@ import re
|
||||
from io import BytesIO
|
||||
from inspect import cleandoc
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
from comfy.comfy_types import FileLocator
|
||||
from comfy.comfy_types import FileLocator, IO
|
||||
from server import PromptServer
|
||||
|
||||
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
@ -229,6 +231,186 @@ class SVG:
|
||||
all_svgs_list.extend(svg_item.data)
|
||||
return SVG(all_svgs_list)
|
||||
|
||||
|
||||
class ImageStitch:
|
||||
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE",),
|
||||
"direction": (["right", "down", "left", "up"], {"default": "right"}),
|
||||
"match_image_size": ("BOOLEAN", {"default": True}),
|
||||
"spacing_width": (
|
||||
"INT",
|
||||
{"default": 0, "min": 0, "max": 1024, "step": 2},
|
||||
),
|
||||
"spacing_color": (
|
||||
["white", "black", "red", "green", "blue"],
|
||||
{"default": "white"},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image2": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stitch"
|
||||
CATEGORY = "image/transform"
|
||||
DESCRIPTION = """
|
||||
Stitches image2 to image1 in the specified direction.
|
||||
If image2 is not provided, returns image1 unchanged.
|
||||
Optional spacing can be added between images.
|
||||
"""
|
||||
|
||||
def stitch(
|
||||
self,
|
||||
image1,
|
||||
direction,
|
||||
match_image_size,
|
||||
spacing_width,
|
||||
spacing_color,
|
||||
image2=None,
|
||||
):
|
||||
if image2 is None:
|
||||
return (image1,)
|
||||
|
||||
# Handle batch size differences
|
||||
if image1.shape[0] != image2.shape[0]:
|
||||
max_batch = max(image1.shape[0], image2.shape[0])
|
||||
if image1.shape[0] < max_batch:
|
||||
image1 = torch.cat(
|
||||
[image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
|
||||
)
|
||||
if image2.shape[0] < max_batch:
|
||||
image2 = torch.cat(
|
||||
[image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
|
||||
)
|
||||
|
||||
# Match image sizes if requested
|
||||
if match_image_size:
|
||||
h1, w1 = image1.shape[1:3]
|
||||
h2, w2 = image2.shape[1:3]
|
||||
aspect_ratio = w2 / h2
|
||||
|
||||
if direction in ["left", "right"]:
|
||||
target_h, target_w = h1, int(h1 * aspect_ratio)
|
||||
else: # up, down
|
||||
target_w, target_h = w1, int(w1 / aspect_ratio)
|
||||
|
||||
image2 = comfy.utils.common_upscale(
|
||||
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
|
||||
).movedim(1, -1)
|
||||
|
||||
# When not matching sizes, pad to align non-concat dimensions
|
||||
if not match_image_size:
|
||||
h1, w1 = image1.shape[1:3]
|
||||
h2, w2 = image2.shape[1:3]
|
||||
|
||||
if direction in ["left", "right"]:
|
||||
# For horizontal concat, pad heights to match
|
||||
if h1 != h2:
|
||||
target_h = max(h1, h2)
|
||||
if h1 < target_h:
|
||||
pad_h = target_h - h1
|
||||
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
||||
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
|
||||
if h2 < target_h:
|
||||
pad_h = target_h - h2
|
||||
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
||||
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
|
||||
else: # up, down
|
||||
# For vertical concat, pad widths to match
|
||||
if w1 != w2:
|
||||
target_w = max(w1, w2)
|
||||
if w1 < target_w:
|
||||
pad_w = target_w - w1
|
||||
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
||||
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
|
||||
if w2 < target_w:
|
||||
pad_w = target_w - w2
|
||||
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
||||
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
|
||||
|
||||
# Ensure same number of channels
|
||||
if image1.shape[-1] != image2.shape[-1]:
|
||||
max_channels = max(image1.shape[-1], image2.shape[-1])
|
||||
if image1.shape[-1] < max_channels:
|
||||
image1 = torch.cat(
|
||||
[
|
||||
image1,
|
||||
torch.ones(
|
||||
*image1.shape[:-1],
|
||||
max_channels - image1.shape[-1],
|
||||
device=image1.device,
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
if image2.shape[-1] < max_channels:
|
||||
image2 = torch.cat(
|
||||
[
|
||||
image2,
|
||||
torch.ones(
|
||||
*image2.shape[:-1],
|
||||
max_channels - image2.shape[-1],
|
||||
device=image2.device,
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Add spacing if specified
|
||||
if spacing_width > 0:
|
||||
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
|
||||
|
||||
color_map = {
|
||||
"white": 1.0,
|
||||
"black": 0.0,
|
||||
"red": (1.0, 0.0, 0.0),
|
||||
"green": (0.0, 1.0, 0.0),
|
||||
"blue": (0.0, 0.0, 1.0),
|
||||
}
|
||||
color_val = color_map[spacing_color]
|
||||
|
||||
if direction in ["left", "right"]:
|
||||
spacing_shape = (
|
||||
image1.shape[0],
|
||||
max(image1.shape[1], image2.shape[1]),
|
||||
spacing_width,
|
||||
image1.shape[-1],
|
||||
)
|
||||
else:
|
||||
spacing_shape = (
|
||||
image1.shape[0],
|
||||
spacing_width,
|
||||
max(image1.shape[2], image2.shape[2]),
|
||||
image1.shape[-1],
|
||||
)
|
||||
|
||||
spacing = torch.full(spacing_shape, 0.0, device=image1.device)
|
||||
if isinstance(color_val, tuple):
|
||||
for i, c in enumerate(color_val):
|
||||
if i < spacing.shape[-1]:
|
||||
spacing[..., i] = c
|
||||
if spacing.shape[-1] == 4: # Add alpha
|
||||
spacing[..., 3] = 1.0
|
||||
else:
|
||||
spacing[..., : min(3, spacing.shape[-1])] = color_val
|
||||
if spacing.shape[-1] == 4:
|
||||
spacing[..., 3] = 1.0
|
||||
|
||||
# Concatenate images
|
||||
images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
|
||||
if spacing_width > 0:
|
||||
images.insert(1, spacing)
|
||||
|
||||
concat_dim = 2 if direction in ["left", "right"] else 1
|
||||
return (torch.cat(images, dim=concat_dim),)
|
||||
|
||||
|
||||
class SaveSVGNode:
|
||||
"""
|
||||
Save SVG files on disk.
|
||||
@ -310,6 +492,36 @@ class SaveSVGNode:
|
||||
counter += 1
|
||||
return { "ui": { "images": results } }
|
||||
|
||||
class GetImageSize:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.INT, IO.INT)
|
||||
RETURN_NAMES = ("width", "height")
|
||||
FUNCTION = "get_size"
|
||||
|
||||
CATEGORY = "image"
|
||||
DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
|
||||
|
||||
def get_size(self, image, unique_id=None) -> tuple[int, int]:
|
||||
height = image.shape[1]
|
||||
width = image.shape[2]
|
||||
|
||||
# Send progress text to display size on the node
|
||||
if unique_id:
|
||||
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}", unique_id)
|
||||
|
||||
return width, height
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
@ -318,4 +530,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||
"SaveSVGNode": SaveSVGNode,
|
||||
"ImageStitch": ImageStitch,
|
||||
"GetImageSize": GetImageSize,
|
||||
}
|
||||
|
@ -296,6 +296,41 @@ class RegexExtract():
|
||||
|
||||
return result,
|
||||
|
||||
|
||||
class RegexReplace():
|
||||
DESCRIPTION = "Find and replace text using regex patterns."
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||
"replace": (IO.STRING, {"multiline": True}),
|
||||
},
|
||||
"optional": {
|
||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||
"dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
|
||||
"count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
|
||||
flags = 0
|
||||
|
||||
if case_insensitive:
|
||||
flags |= re.IGNORECASE
|
||||
if multiline:
|
||||
flags |= re.MULTILINE
|
||||
if dotall:
|
||||
flags |= re.DOTALL
|
||||
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
|
||||
return result,
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StringConcatenate": StringConcatenate,
|
||||
"StringSubstring": StringSubstring,
|
||||
@ -306,7 +341,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"StringContains": StringContains,
|
||||
"StringCompare": StringCompare,
|
||||
"RegexMatch": RegexMatch,
|
||||
"RegexExtract": RegexExtract
|
||||
"RegexExtract": RegexExtract,
|
||||
"RegexReplace": RegexReplace,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@ -319,5 +355,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StringContains": "Contains",
|
||||
"StringCompare": "Compare",
|
||||
"RegexMatch": "Regex Match",
|
||||
"RegexExtract": "Regex Extract"
|
||||
"RegexExtract": "Regex Extract",
|
||||
"RegexReplace": "Regex Replace",
|
||||
}
|
||||
|
2
nodes.py
2
nodes.py
@ -2061,11 +2061,13 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
||||
"ImageBatch": "Batch Images",
|
||||
"ImageCrop": "Image Crop",
|
||||
"ImageStitch": "Image Stitch",
|
||||
"ImageBlend": "Image Blend",
|
||||
"ImageBlur": "Image Blur",
|
||||
"ImageQuantize": "Image Quantize",
|
||||
"ImageSharpen": "Image Sharpen",
|
||||
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
||||
"GetImageSize": "Get Image Size",
|
||||
# _for_testing
|
||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||
|
@ -1,5 +1,6 @@
|
||||
comfyui-frontend-package==1.20.7
|
||||
comfyui-workflow-templates==0.1.22
|
||||
comfyui-frontend-package==1.21.6
|
||||
comfyui-workflow-templates==0.1.25
|
||||
comfyui-embedded-docs==0.2.0
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
14
server.py
14
server.py
@ -390,7 +390,7 @@ class PromptServer():
|
||||
async def view_image(request):
|
||||
if "filename" in request.rel_url.query:
|
||||
filename = request.rel_url.query["filename"]
|
||||
filename,output_dir = folder_paths.annotated_filepath(filename)
|
||||
filename, output_dir = folder_paths.annotated_filepath(filename)
|
||||
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
@ -476,9 +476,8 @@ class PromptServer():
|
||||
# Get content type from mimetype, defaulting to 'application/octet-stream'
|
||||
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
|
||||
|
||||
# For security, force certain extensions to download instead of display
|
||||
file_extension = os.path.splitext(filename)[1].lower()
|
||||
if file_extension in {'.html', '.htm', '.js', '.css'}:
|
||||
# For security, force certain mimetypes to download instead of display
|
||||
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
|
||||
content_type = 'application/octet-stream' # Forces download
|
||||
|
||||
return web.FileResponse(
|
||||
@ -746,6 +745,13 @@ class PromptServer():
|
||||
web.static('/templates', workflow_templates_path)
|
||||
])
|
||||
|
||||
# Serve embedded documentation from the package
|
||||
embedded_docs_path = FrontendManager.embedded_docs_path()
|
||||
if embedded_docs_path:
|
||||
self.app.add_routes([
|
||||
web.static('/docs', embedded_docs_path)
|
||||
])
|
||||
|
||||
self.app.add_routes([
|
||||
web.static('/', self.web_root),
|
||||
])
|
||||
|
0
tests-unit/comfy_extras_test/__init__.py
Normal file
0
tests-unit/comfy_extras_test/__init__.py
Normal file
243
tests-unit/comfy_extras_test/image_stitch_test.py
Normal file
243
tests-unit/comfy_extras_test/image_stitch_test.py
Normal file
@ -0,0 +1,243 @@
|
||||
import torch
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# Mock nodes module to prevent CUDA initialization during import
|
||||
mock_nodes = MagicMock()
|
||||
mock_nodes.MAX_RESOLUTION = 16384
|
||||
|
||||
# Mock server module for PromptServer
|
||||
mock_server = MagicMock()
|
||||
|
||||
with patch.dict('sys.modules', {'nodes': mock_nodes, 'server': mock_server}):
|
||||
from comfy_extras.nodes_images import ImageStitch
|
||||
|
||||
|
||||
class TestImageStitch:
|
||||
|
||||
def create_test_image(self, batch_size=1, height=64, width=64, channels=3):
|
||||
"""Helper to create test images with specific dimensions"""
|
||||
return torch.rand(batch_size, height, width, channels)
|
||||
|
||||
def test_no_image2_passthrough(self):
|
||||
"""Test that when image2 is None, image1 is returned unchanged"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image()
|
||||
|
||||
result = node.stitch(image1, "right", True, 0, "white", image2=None)
|
||||
|
||||
assert len(result) == 1
|
||||
assert torch.equal(result[0], image1)
|
||||
|
||||
def test_basic_horizontal_stitch_right(self):
|
||||
"""Test basic horizontal stitching to the right"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=24)
|
||||
|
||||
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||
|
||||
assert result[0].shape == (1, 32, 56, 3) # 32 + 24 width
|
||||
|
||||
def test_basic_horizontal_stitch_left(self):
|
||||
"""Test basic horizontal stitching to the left"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=24)
|
||||
|
||||
result = node.stitch(image1, "left", False, 0, "white", image2)
|
||||
|
||||
assert result[0].shape == (1, 32, 56, 3) # 24 + 32 width
|
||||
|
||||
def test_basic_vertical_stitch_down(self):
|
||||
"""Test basic vertical stitching downward"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=24, width=32)
|
||||
|
||||
result = node.stitch(image1, "down", False, 0, "white", image2)
|
||||
|
||||
assert result[0].shape == (1, 56, 32, 3) # 32 + 24 height
|
||||
|
||||
def test_basic_vertical_stitch_up(self):
|
||||
"""Test basic vertical stitching upward"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=24, width=32)
|
||||
|
||||
result = node.stitch(image1, "up", False, 0, "white", image2)
|
||||
|
||||
assert result[0].shape == (1, 56, 32, 3) # 24 + 32 height
|
||||
|
||||
def test_size_matching_horizontal(self):
|
||||
"""Test size matching for horizontal concatenation"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=64, width=64)
|
||||
image2 = self.create_test_image(height=32, width=32) # Different aspect ratio
|
||||
|
||||
result = node.stitch(image1, "right", True, 0, "white", image2)
|
||||
|
||||
# image2 should be resized to match image1's height (64) with preserved aspect ratio
|
||||
expected_width = 64 + 64 # original + resized (32*64/32 = 64)
|
||||
assert result[0].shape == (1, 64, expected_width, 3)
|
||||
|
||||
def test_size_matching_vertical(self):
|
||||
"""Test size matching for vertical concatenation"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=64, width=64)
|
||||
image2 = self.create_test_image(height=32, width=32)
|
||||
|
||||
result = node.stitch(image1, "down", True, 0, "white", image2)
|
||||
|
||||
# image2 should be resized to match image1's width (64) with preserved aspect ratio
|
||||
expected_height = 64 + 64 # original + resized (32*64/32 = 64)
|
||||
assert result[0].shape == (1, expected_height, 64, 3)
|
||||
|
||||
def test_padding_for_mismatched_heights_horizontal(self):
|
||||
"""Test padding when heights don't match in horizontal concatenation"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=64, width=32)
|
||||
image2 = self.create_test_image(height=48, width=24) # Shorter height
|
||||
|
||||
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||
|
||||
# Both images should be padded to height 64
|
||||
assert result[0].shape == (1, 64, 56, 3) # 32 + 24 width, max(64,48) height
|
||||
|
||||
def test_padding_for_mismatched_widths_vertical(self):
|
||||
"""Test padding when widths don't match in vertical concatenation"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=64)
|
||||
image2 = self.create_test_image(height=24, width=48) # Narrower width
|
||||
|
||||
result = node.stitch(image1, "down", False, 0, "white", image2)
|
||||
|
||||
# Both images should be padded to width 64
|
||||
assert result[0].shape == (1, 56, 64, 3) # 32 + 24 height, max(64,48) width
|
||||
|
||||
def test_spacing_horizontal(self):
|
||||
"""Test spacing addition in horizontal concatenation"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=24)
|
||||
spacing_width = 16
|
||||
|
||||
result = node.stitch(image1, "right", False, spacing_width, "white", image2)
|
||||
|
||||
# Expected width: 32 + 16 (spacing) + 24 = 72
|
||||
assert result[0].shape == (1, 32, 72, 3)
|
||||
|
||||
def test_spacing_vertical(self):
|
||||
"""Test spacing addition in vertical concatenation"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=24, width=32)
|
||||
spacing_width = 16
|
||||
|
||||
result = node.stitch(image1, "down", False, spacing_width, "white", image2)
|
||||
|
||||
# Expected height: 32 + 16 (spacing) + 24 = 72
|
||||
assert result[0].shape == (1, 72, 32, 3)
|
||||
|
||||
def test_spacing_color_values(self):
|
||||
"""Test that spacing colors are applied correctly"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=32)
|
||||
|
||||
# Test white spacing
|
||||
result_white = node.stitch(image1, "right", False, 16, "white", image2)
|
||||
# Check that spacing region contains white values (close to 1.0)
|
||||
spacing_region = result_white[0][:, :, 32:48, :] # Middle 16 pixels
|
||||
assert torch.all(spacing_region >= 0.9) # Should be close to white
|
||||
|
||||
# Test black spacing
|
||||
result_black = node.stitch(image1, "right", False, 16, "black", image2)
|
||||
spacing_region = result_black[0][:, :, 32:48, :]
|
||||
assert torch.all(spacing_region <= 0.1) # Should be close to black
|
||||
|
||||
def test_odd_spacing_width_made_even(self):
|
||||
"""Test that odd spacing widths are made even"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=32)
|
||||
|
||||
# Use odd spacing width
|
||||
result = node.stitch(image1, "right", False, 15, "white", image2)
|
||||
|
||||
# Should be made even (16), so total width = 32 + 16 + 32 = 80
|
||||
assert result[0].shape == (1, 32, 80, 3)
|
||||
|
||||
def test_batch_size_matching(self):
|
||||
"""Test that different batch sizes are handled correctly"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(batch_size=2, height=32, width=32)
|
||||
image2 = self.create_test_image(batch_size=1, height=32, width=32)
|
||||
|
||||
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||
|
||||
# Should match larger batch size
|
||||
assert result[0].shape == (2, 32, 64, 3)
|
||||
|
||||
def test_channel_matching_rgb_to_rgba(self):
|
||||
"""Test that channel differences are handled (RGB + alpha)"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(channels=3) # RGB
|
||||
image2 = self.create_test_image(channels=4) # RGBA
|
||||
|
||||
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||
|
||||
# Should have 4 channels (RGBA)
|
||||
assert result[0].shape[-1] == 4
|
||||
|
||||
def test_channel_matching_rgba_to_rgb(self):
|
||||
"""Test that channel differences are handled (RGBA + RGB)"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(channels=4) # RGBA
|
||||
image2 = self.create_test_image(channels=3) # RGB
|
||||
|
||||
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||
|
||||
# Should have 4 channels (RGBA)
|
||||
assert result[0].shape[-1] == 4
|
||||
|
||||
def test_all_color_options(self):
|
||||
"""Test all available color options"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=32)
|
||||
|
||||
colors = ["white", "black", "red", "green", "blue"]
|
||||
|
||||
for color in colors:
|
||||
result = node.stitch(image1, "right", False, 16, color, image2)
|
||||
assert result[0].shape == (1, 32, 80, 3) # Basic shape check
|
||||
|
||||
def test_all_directions(self):
|
||||
"""Test all direction options"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=32)
|
||||
|
||||
directions = ["right", "left", "up", "down"]
|
||||
|
||||
for direction in directions:
|
||||
result = node.stitch(image1, direction, False, 0, "white", image2)
|
||||
assert result[0].shape == (1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3)
|
||||
|
||||
def test_batch_size_channel_spacing_integration(self):
|
||||
"""Test integration of batch matching, channel matching, size matching, and spacings"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3)
|
||||
image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4)
|
||||
|
||||
result = node.stitch(image1, "right", True, 8, "red", image2)
|
||||
|
||||
# Should handle: batch matching, size matching, channel matching, spacing
|
||||
assert result[0].shape[0] == 2 # Batch size matched
|
||||
assert result[0].shape[-1] == 4 # Channels matched to max
|
||||
assert result[0].shape[1] == 64 # Height from image1 (size matching)
|
||||
# Width should be: 48 + 8 (spacing) + resized_image2_width
|
||||
expected_image2_width = int(64 * (32/32)) # Resized to height 64
|
||||
expected_total_width = 48 + 8 + expected_image2_width
|
||||
assert result[0].shape[2] == expected_total_width
|
||||
|
Loading…
x
Reference in New Issue
Block a user