# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# Synced 2025-01-25 15:55:18 +00:00.
#
# Latest commit: 4e402b11c6
# * Reapply "Add union link connection type support (#5806)" (#5889)
#   This reverts commit bf9a90a145.
# * Fix union type breaks existing type workarounds
# * Add non-string test
# * Add tests for hacks and non-string types
# * Support python versions lower than 3.11
#
# File stats: 120 lines, 4.2 KiB, Python.
import pytest
from comfy_execution.validation import validate_node_input


def test_exact_match():
    """Identical type lists validate, and member order is irrelevant."""
    same_types = [
        ("STRING", "STRING"),
        ("STRING,INT", "STRING,INT"),
        # Ordering of union members must not matter.
        ("INT,STRING", "STRING,INT"),
    ]
    for received, accepted in same_types:
        assert validate_node_input(received, accepted)


def test_strict_mode():
    """Strict validation accepts only received types fully covered by the input."""
    covered = [
        ("STRING", "STRING,INT"),
        ("INT", "STRING,INT"),
        ("STRING,INT", "STRING,INT,BOOLEAN"),
    ]
    not_covered = [
        ("STRING,INT", "STRING"),
        ("STRING,BOOLEAN", "STRING"),
        ("INT,BOOLEAN", "STRING,INT"),
    ]

    # Every received type appears in the input type, so these must pass.
    for received, accepted in covered:
        assert validate_node_input(received, accepted, strict=True)

    # At least one received type is missing from the input type: reject.
    for received, accepted in not_covered:
        assert not validate_node_input(received, accepted, strict=True)


def test_non_strict_mode():
    """Default (non-strict) validation needs only one type in common."""
    overlapping = [
        ("STRING,BOOLEAN", "STRING,INT"),
        ("STRING,INT", "INT,BOOLEAN"),
        ("STRING", "STRING,INT"),
    ]
    disjoint = [
        ("BOOLEAN", "STRING,INT"),
        ("FLOAT", "STRING,INT"),
        ("FLOAT,BOOLEAN", "STRING,INT"),
    ]

    # Sharing at least one type is sufficient.
    for received, accepted in overlapping:
        assert validate_node_input(received, accepted)

    # No type in common: reject.
    for received, accepted in disjoint:
        assert not validate_node_input(received, accepted)


def test_whitespace_handling():
    """Spaces around commas or at either end of a type list are ignored."""
    canonical = "STRING,INT"
    for padded in ("STRING, INT", " STRING , INT "):
        # Padding must be tolerated on either side of the comparison.
        assert validate_node_input(padded, canonical)
        assert validate_node_input(canonical, padded)


def test_empty_strings():
    """An empty type string matches another empty string but not a real type."""
    assert validate_node_input("", "")
    # Empty on one side only is a mismatch, in either direction.
    for received, accepted in (("STRING", ""), ("", "STRING")):
        assert not validate_node_input(received, accepted)


def test_single_vs_multiple():
    """Compare a lone type with a three-member union in both directions."""
    single = "STRING"
    many = "STRING,INT,BOOLEAN"

    # A lone type is accepted by an input that lists it among several.
    assert validate_node_input(single, many)
    # Loose matching only needs one shared type...
    assert validate_node_input(many, single, strict=False)
    # ...whereas strict matching needs every received type to be accepted.
    assert not validate_node_input(many, single, strict=True)


def test_non_string():
    """Non-string markers: the same object validates, a distinct one does not."""
    marker, other_marker = object(), object()
    assert validate_node_input(marker, marker)
    assert not validate_node_input(marker, other_marker)


class NotEqualsOverrideTest(str):
    """``str`` subclass with a custom ``!=``, used by the ``__ne__`` tests.

    The behaviour of ``!=`` is selected by the string's own value:

    * If either operand equals ``"*"``, ``!=`` always returns ``False``.
    * ``"LONGER_THAN_2"`` is "not equal" to values whose ``len()`` is two or
      less (so the compared value must support ``len()``).
    * Any other value raises ``TypeError`` unless the other operand is ``"*"``.
    """

    def __ne__(self, value: object) -> bool:
        # Wildcard on either side: never report "not equal".
        wildcard = self == "*" or value == "*"
        if wildcard:
            return False

        if self == "LONGER_THAN_2":
            return len(value) <= 2

        raise TypeError("This is a class for unit tests only.")


def test_ne_override():
    """An ``__ne__`` override that always returns False acts as an "any" type.

    Validation must keep honouring this workaround, so the wildcard is
    accepted against itself, a plain type string, and arbitrary non-string
    objects (object, dict, lists).
    """
    # Named ``any_type`` rather than ``any`` to avoid shadowing the builtin.
    any_type = NotEqualsOverrideTest("*")
    invalid_type = "INVALID_TYPE"
    obj = object()

    assert validate_node_input(any_type, any_type)
    assert validate_node_input(any_type, invalid_type)
    assert validate_node_input(any_type, obj)
    assert validate_node_input(any_type, {})
    assert validate_node_input(any_type, [])
    assert validate_node_input(any_type, [1, 2, 3])


def test_ne_custom_override():
    """A value-dependent ``__ne__`` override decides compatibility itself."""
    special = NotEqualsOverrideTest("LONGER_THAN_2")

    accepted = [special, "*", "INVALID_TYPE", [1, 2, 3]]
    for candidate in accepted:
        assert validate_node_input(special, candidate)

    # Length two or less: the override reports "not equal".
    rejected = [[1, 2], "TY"]
    for candidate in rejected:
        assert not validate_node_input(special, candidate)


@pytest.mark.parametrize(
    ("received", "input_type", "strict", "expected"),
    [
        # Identical type lists.
        ("STRING", "STRING", False, True),
        ("STRING,INT", "STRING,INT", False, True),
        # Strict: every received type must be in the input type.
        ("STRING", "STRING,INT", True, True),
        ("STRING,INT", "STRING", True, False),
        # Non-strict: one shared type is enough.
        ("BOOLEAN", "STRING,INT", False, False),
        ("STRING,BOOLEAN", "STRING,INT", False, True),
    ],
)
def test_parametrized_cases(received, input_type, strict, expected):
    """Table-driven check of exact, strict and non-strict matching."""
    result = validate_node_input(received, input_type, strict=strict)
    assert result == expected