mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Add tests for hacks and non-string types
This commit is contained in:
parent
41c2dfb441
commit
ad43d4c729
@ -1,14 +1,13 @@
|
||||
import pytest
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
def test_exact_match():
|
||||
"""Test cases where types match exactly"""
|
||||
assert validate_node_input("STRING", "STRING")
|
||||
assert validate_node_input("STRING,INT", "STRING,INT")
|
||||
assert (
|
||||
validate_node_input("INT,STRING", "STRING,INT")
|
||||
) # Order shouldn't matter
|
||||
assert validate_node_input("INT,STRING", "STRING,INT") # Order shouldn't matter
|
||||
|
||||
|
||||
def test_strict_mode():
|
||||
@ -67,6 +66,47 @@ def test_non_string():
|
||||
assert not validate_node_input(obj1, obj2)
|
||||
|
||||
|
||||
class NotEqualsOverrideTest(StrEnum):
|
||||
"""Test class for ``__ne__`` override."""
|
||||
|
||||
ANY = "*"
|
||||
LONGER_THAN_2 = "LONGER_THAN_2"
|
||||
|
||||
def __ne__(self, value: object) -> bool:
|
||||
if self == "*" or value == "*":
|
||||
return False
|
||||
if self == "LONGER_THAN_2":
|
||||
return not len(value) > 2
|
||||
raise TypeError("This is a class for unit tests only.")
|
||||
|
||||
|
||||
def test_ne_override():
|
||||
"""Test ``__ne__`` any override"""
|
||||
any = NotEqualsOverrideTest.ANY
|
||||
invalid_type = "INVALID_TYPE"
|
||||
obj = object()
|
||||
assert validate_node_input(any, any)
|
||||
assert validate_node_input(any, invalid_type)
|
||||
assert validate_node_input(any, obj)
|
||||
assert validate_node_input(any, {})
|
||||
assert validate_node_input(any, [])
|
||||
assert validate_node_input(any, [1, 2, 3])
|
||||
|
||||
|
||||
def test_ne_custom_override():
|
||||
"""Test ``__ne__`` custom override"""
|
||||
special = NotEqualsOverrideTest.LONGER_THAN_2
|
||||
|
||||
assert validate_node_input(special, special)
|
||||
assert validate_node_input(special, "*")
|
||||
assert validate_node_input(special, "INVALID_TYPE")
|
||||
assert validate_node_input(special, [1, 2, 3])
|
||||
|
||||
# Should fail
|
||||
assert not validate_node_input(special, [1, 2])
|
||||
assert not validate_node_input(special, "TY")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"received,input_type,strict,expected",
|
||||
[
|
||||
|
Loading…
Reference in New Issue
Block a user