Modify node classes to make onTrigger usable

This commit is contained in:
Akio Nishimura 2025-03-28 23:45:29 +09:00
parent 0956107170
commit 3e70e821c5

View File

@ -10,6 +10,7 @@ import math
import time import time
import random import random
import logging import logging
from inspect import signature
from PIL import Image, ImageOps, ImageSequence from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
@ -2122,6 +2123,49 @@ def get_module_name(module_path: str) -> str:
return base_path return base_path
def extend_input_types(cls):
if not hasattr(cls, 'INPUT_TYPES'):
return cls
input_types = cls.INPUT_TYPES()
new_hidden_inputs = {
'onTrigger': ('*', {'default': None}),
}
if 'hidden' in input_types.keys():
for k, v in new_hidden_inputs.items():
if k not in input_types['hidden'].keys():
input_types['hidden'][k] = v
else:
input_types.update({'hidden': new_hidden_inputs})
def _INPUT_TYPES(s):
return input_types
cls.INPUT_TYPES = classmethod(_INPUT_TYPES)
return cls
def modify_function_params(cls):
if not hasattr(cls, 'FUNCTION'):
return cls
function = getattr(cls, cls.FUNCTION)
sig = signature(function)
if 'onTrigger' not in sig.parameters.keys():
def _function(self, *args, **kwargs):
if 'onTrigger' in kwargs.keys():
del kwargs['onTrigger']
bound_args = sig.bind(self, *args, **kwargs)
bound_args.apply_defaults()
return function(*bound_args.args, **bound_args.kwargs)
setattr(cls, cls.FUNCTION, _function)
return cls
def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool: def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
module_name = os.path.basename(module_path) module_name = os.path.basename(module_path)
if os.path.isfile(module_path): if os.path.isfile(module_path):
@ -2150,6 +2194,8 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
for name, node_cls in module.NODE_CLASS_MAPPINGS.items(): for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
if name not in ignore: if name not in ignore:
node_cls = extend_input_types(node_cls)
node_cls = modify_function_params(node_cls)
NODE_CLASS_MAPPINGS[name] = node_cls NODE_CLASS_MAPPINGS[name] = node_cls
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path)) node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: