Compare commits

...

6 Commits

Author SHA1 Message Date
Akio Nishimura
45e5050e7d
Merge b367f21774 into 98bdca4cb2 2025-04-10 16:28:35 -04:00
Akio Nishimura
b367f21774 Correctly modify staticmethods 2025-04-03 16:43:18 +09:00
Akio Nishimura
8cc5326beb Modify builtin nodes 2025-04-02 08:46:10 +09:00
Akio Nishimura
01087a97b7 Merge branch 'master' of https://github.com/comfyanonymous/ComfyUI 2025-03-31 17:18:29 +09:00
Akio Nishimura
311e110b1f Merge branch 'master' of https://github.com/comfyanonymous/ComfyUI 2025-03-28 23:48:13 +09:00
Akio Nishimura
3e70e821c5 Modify node classes to make onTrigger usable 2025-03-28 23:45:29 +09:00

View File

@ -10,6 +10,7 @@ import math
import time import time
import random import random
import logging import logging
import inspect
from PIL import Image, ImageOps, ImageSequence from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
@ -2129,6 +2130,81 @@ def get_module_name(module_path: str) -> str:
return base_path return base_path
def extend_input_types(cls):
if not hasattr(cls, 'INPUT_TYPES'):
return cls
input_types = cls.INPUT_TYPES()
new_hidden_inputs = {
'onTrigger': ('*', {'default': None}),
}
if 'hidden' in input_types.keys():
for k, v in new_hidden_inputs.items():
if k not in input_types['hidden'].keys():
input_types['hidden'][k] = v
else:
input_types.update({'hidden': new_hidden_inputs})
def _INPUT_TYPES(s):
return input_types
cls.INPUT_TYPES = classmethod(_INPUT_TYPES)
return cls
def modify_function_params(cls):
if not hasattr(cls, 'FUNCTION'):
return cls
method_name = getattr(cls, 'FUNCTION')
is_static = isinstance(
inspect.getattr_static(cls, method_name),
staticmethod)
if is_static:
original_method = inspect.getattr_static(cls, method_name)
else:
original_method = getattr(cls, method_name)
sig = inspect.signature(original_method)
if 'onTrigger' in sig.parameters.keys():
return cls
# Define a wrapper function that uses the new signature
def _function(*args, **kwargs):
if 'onTrigger' in kwargs.keys():
del kwargs['onTrigger']
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
if is_static:
return original_method(*bound_args.args, **bound_args.kwargs)
else:
# For instance methods, ensure 'self' is correctly passed
self_instance = args[0]
method_args = bound_args.args[1:]
method_kwargs = bound_args.kwargs
return original_method(
self_instance, *method_args, **method_kwargs)
# Assign the new function directly to the class attribute
if is_static:
setattr(cls, method_name, staticmethod(_function))
else:
setattr(cls, method_name, _function)
return cls
for node_name, node_cls in NODE_CLASS_MAPPINGS.items():
node_cls = extend_input_types(node_cls)
node_cls = modify_function_params(node_cls)
NODE_CLASS_MAPPINGS[node_name] = node_cls
def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool: def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
module_name = get_module_name(module_path) module_name = get_module_name(module_path)
if os.path.isfile(module_path): if os.path.isfile(module_path):
@ -2161,6 +2237,8 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
for name, node_cls in module.NODE_CLASS_MAPPINGS.items(): for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
if name not in ignore: if name not in ignore:
node_cls = extend_input_types(node_cls)
node_cls = modify_function_params(node_cls)
NODE_CLASS_MAPPINGS[name] = node_cls NODE_CLASS_MAPPINGS[name] = node_cls
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path)) node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: