mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
This reverts commit 8d4e06324f
.
This commit is contained in:
parent
c1b92b719d
commit
bf9a90a145
@ -1,32 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
|
|
||||||
def validate_node_input(
|
|
||||||
received_type: str, input_type: str, strict: bool = False
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
received_type and input_type are both strings of the form "T1,T2,...".
|
|
||||||
|
|
||||||
If strict is True, the input_type must contain the received_type.
|
|
||||||
For example, if received_type is "STRING" and input_type is "STRING,INT",
|
|
||||||
this will return True. But if received_type is "STRING,INT" and input_type is
|
|
||||||
"INT", this will return False.
|
|
||||||
|
|
||||||
If strict is False, the input_type must have overlap with the received_type.
|
|
||||||
For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT",
|
|
||||||
this will return True.
|
|
||||||
"""
|
|
||||||
# If the types are exactly the same, we can return immediately
|
|
||||||
if received_type == input_type:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Split the type strings into sets for comparison
|
|
||||||
received_types = set(t.strip() for t in received_type.split(","))
|
|
||||||
input_types = set(t.strip() for t in input_type.split(","))
|
|
||||||
|
|
||||||
if strict:
|
|
||||||
# In strict mode, all received types must be in the input types
|
|
||||||
return received_types.issubset(input_types)
|
|
||||||
else:
|
|
||||||
# In non-strict mode, there must be at least one type in common
|
|
||||||
return len(received_types.intersection(input_types)) > 0
|
|
@ -16,7 +16,6 @@ import comfy.model_management
|
|||||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||||
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
||||||
from comfy_execution.validation import validate_node_input
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
|
||||||
class ExecutionResult(Enum):
|
class ExecutionResult(Enum):
|
||||||
@ -528,6 +527,7 @@ class PromptExecutor:
|
|||||||
comfy.model_management.unload_all_models()
|
comfy.model_management.unload_all_models()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def validate_inputs(prompt, item, validated):
|
def validate_inputs(prompt, item, validated):
|
||||||
unique_id = item
|
unique_id = item
|
||||||
if unique_id in validated:
|
if unique_id in validated:
|
||||||
@ -589,8 +589,8 @@ def validate_inputs(prompt, item, validated):
|
|||||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||||
received_type = r[val[1]]
|
received_type = r[val[1]]
|
||||||
received_types[x] = received_type
|
received_types[x] = received_type
|
||||||
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
|
if 'input_types' not in validate_function_inputs and received_type != type_input:
|
||||||
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
|
details = f"{x}, {received_type} != {type_input}"
|
||||||
error = {
|
error = {
|
||||||
"type": "return_type_mismatch",
|
"type": "return_type_mismatch",
|
||||||
"message": "Return type mismatch between linked nodes",
|
"message": "Return type mismatch between linked nodes",
|
||||||
|
@ -1,75 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from comfy_execution.validation import validate_node_input
|
|
||||||
|
|
||||||
|
|
||||||
def test_exact_match():
|
|
||||||
"""Test cases where types match exactly"""
|
|
||||||
assert validate_node_input("STRING", "STRING")
|
|
||||||
assert validate_node_input("STRING,INT", "STRING,INT")
|
|
||||||
assert (
|
|
||||||
validate_node_input("INT,STRING", "STRING,INT")
|
|
||||||
) # Order shouldn't matter
|
|
||||||
|
|
||||||
|
|
||||||
def test_strict_mode():
|
|
||||||
"""Test strict mode validation"""
|
|
||||||
# Should pass - received type is subset of input type
|
|
||||||
assert validate_node_input("STRING", "STRING,INT", strict=True)
|
|
||||||
assert validate_node_input("INT", "STRING,INT", strict=True)
|
|
||||||
assert validate_node_input("STRING,INT", "STRING,INT,BOOLEAN", strict=True)
|
|
||||||
|
|
||||||
# Should fail - received type is not subset of input type
|
|
||||||
assert not validate_node_input("STRING,INT", "STRING", strict=True)
|
|
||||||
assert not validate_node_input("STRING,BOOLEAN", "STRING", strict=True)
|
|
||||||
assert not validate_node_input("INT,BOOLEAN", "STRING,INT", strict=True)
|
|
||||||
|
|
||||||
|
|
||||||
def test_non_strict_mode():
|
|
||||||
"""Test non-strict mode validation (default behavior)"""
|
|
||||||
# Should pass - types have overlap
|
|
||||||
assert validate_node_input("STRING,BOOLEAN", "STRING,INT")
|
|
||||||
assert validate_node_input("STRING,INT", "INT,BOOLEAN")
|
|
||||||
assert validate_node_input("STRING", "STRING,INT")
|
|
||||||
|
|
||||||
# Should fail - no overlap in types
|
|
||||||
assert not validate_node_input("BOOLEAN", "STRING,INT")
|
|
||||||
assert not validate_node_input("FLOAT", "STRING,INT")
|
|
||||||
assert not validate_node_input("FLOAT,BOOLEAN", "STRING,INT")
|
|
||||||
|
|
||||||
|
|
||||||
def test_whitespace_handling():
|
|
||||||
"""Test that whitespace is handled correctly"""
|
|
||||||
assert validate_node_input("STRING, INT", "STRING,INT")
|
|
||||||
assert validate_node_input("STRING,INT", "STRING, INT")
|
|
||||||
assert validate_node_input(" STRING , INT ", "STRING,INT")
|
|
||||||
assert validate_node_input("STRING,INT", " STRING , INT ")
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty_strings():
|
|
||||||
"""Test behavior with empty strings"""
|
|
||||||
assert validate_node_input("", "")
|
|
||||||
assert not validate_node_input("STRING", "")
|
|
||||||
assert not validate_node_input("", "STRING")
|
|
||||||
|
|
||||||
|
|
||||||
def test_single_vs_multiple():
|
|
||||||
"""Test single type against multiple types"""
|
|
||||||
assert validate_node_input("STRING", "STRING,INT,BOOLEAN")
|
|
||||||
assert validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=False)
|
|
||||||
assert not validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=True)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"received,input_type,strict,expected",
|
|
||||||
[
|
|
||||||
("STRING", "STRING", False, True),
|
|
||||||
("STRING,INT", "STRING,INT", False, True),
|
|
||||||
("STRING", "STRING,INT", True, True),
|
|
||||||
("STRING,INT", "STRING", True, False),
|
|
||||||
("BOOLEAN", "STRING,INT", False, False),
|
|
||||||
("STRING,BOOLEAN", "STRING,INT", False, True),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_parametrized_cases(received, input_type, strict, expected):
|
|
||||||
"""Parametrized test cases for various scenarios"""
|
|
||||||
assert validate_node_input(received, input_type, strict) == expected
|
|
Loading…
Reference in New Issue
Block a user