This commit is contained in:
Akio Nishimura 2025-04-11 09:46:28 -04:00 committed by GitHub
commit 4e7b3dcc54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -10,6 +10,7 @@ import math
import time import time
import random import random
import logging import logging
import inspect
from PIL import Image, ImageOps, ImageSequence from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
@ -2129,6 +2130,81 @@ def get_module_name(module_path: str) -> str:
return base_path return base_path
def extend_input_types(cls):
if not hasattr(cls, 'INPUT_TYPES'):
return cls
input_types = cls.INPUT_TYPES()
new_hidden_inputs = {
'onTrigger': ('*', {'default': None}),
}
if 'hidden' in input_types.keys():
for k, v in new_hidden_inputs.items():
if k not in input_types['hidden'].keys():
input_types['hidden'][k] = v
else:
input_types.update({'hidden': new_hidden_inputs})
def _INPUT_TYPES(s):
return input_types
cls.INPUT_TYPES = classmethod(_INPUT_TYPES)
return cls
def modify_function_params(cls):
if not hasattr(cls, 'FUNCTION'):
return cls
method_name = getattr(cls, 'FUNCTION')
is_static = isinstance(
inspect.getattr_static(cls, method_name),
staticmethod)
if is_static:
original_method = inspect.getattr_static(cls, method_name)
else:
original_method = getattr(cls, method_name)
sig = inspect.signature(original_method)
if 'onTrigger' in sig.parameters.keys():
return cls
# Define a wrapper function that uses the new signature
def _function(*args, **kwargs):
if 'onTrigger' in kwargs.keys():
del kwargs['onTrigger']
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
if is_static:
return original_method(*bound_args.args, **bound_args.kwargs)
else:
# For instance methods, ensure 'self' is correctly passed
self_instance = args[0]
method_args = bound_args.args[1:]
method_kwargs = bound_args.kwargs
return original_method(
self_instance, *method_args, **method_kwargs)
# Assign the new function directly to the class attribute
if is_static:
setattr(cls, method_name, staticmethod(_function))
else:
setattr(cls, method_name, _function)
return cls
for node_name, node_cls in NODE_CLASS_MAPPINGS.items():
node_cls = extend_input_types(node_cls)
node_cls = modify_function_params(node_cls)
NODE_CLASS_MAPPINGS[node_name] = node_cls
def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool: def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
module_name = get_module_name(module_path) module_name = get_module_name(module_path)
if os.path.isfile(module_path): if os.path.isfile(module_path):
@ -2161,6 +2237,8 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
for name, node_cls in module.NODE_CLASS_MAPPINGS.items(): for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
if name not in ignore: if name not in ignore:
node_cls = extend_input_types(node_cls)
node_cls = modify_function_params(node_cls)
NODE_CLASS_MAPPINGS[name] = node_cls NODE_CLASS_MAPPINGS[name] = node_cls
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path)) node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: