mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-12 18:33:35 +00:00
Merge b367f21774
into 22ad513c72
This commit is contained in:
commit
4e7b3dcc54
78
nodes.py
78
nodes.py
@ -10,6 +10,7 @@ import math
|
||||
import time
|
||||
import random
|
||||
import logging
|
||||
import inspect
|
||||
|
||||
from PIL import Image, ImageOps, ImageSequence
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
@ -2129,6 +2130,81 @@ def get_module_name(module_path: str) -> str:
|
||||
return base_path
|
||||
|
||||
|
||||
def extend_input_types(cls):
|
||||
if not hasattr(cls, 'INPUT_TYPES'):
|
||||
return cls
|
||||
|
||||
input_types = cls.INPUT_TYPES()
|
||||
|
||||
new_hidden_inputs = {
|
||||
'onTrigger': ('*', {'default': None}),
|
||||
}
|
||||
|
||||
if 'hidden' in input_types.keys():
|
||||
for k, v in new_hidden_inputs.items():
|
||||
if k not in input_types['hidden'].keys():
|
||||
input_types['hidden'][k] = v
|
||||
else:
|
||||
input_types.update({'hidden': new_hidden_inputs})
|
||||
|
||||
def _INPUT_TYPES(s):
|
||||
return input_types
|
||||
|
||||
cls.INPUT_TYPES = classmethod(_INPUT_TYPES)
|
||||
return cls
|
||||
|
||||
|
||||
def modify_function_params(cls):
|
||||
if not hasattr(cls, 'FUNCTION'):
|
||||
return cls
|
||||
|
||||
method_name = getattr(cls, 'FUNCTION')
|
||||
is_static = isinstance(
|
||||
inspect.getattr_static(cls, method_name),
|
||||
staticmethod)
|
||||
|
||||
if is_static:
|
||||
original_method = inspect.getattr_static(cls, method_name)
|
||||
else:
|
||||
original_method = getattr(cls, method_name)
|
||||
|
||||
sig = inspect.signature(original_method)
|
||||
|
||||
if 'onTrigger' in sig.parameters.keys():
|
||||
return cls
|
||||
|
||||
# Define a wrapper function that uses the new signature
|
||||
def _function(*args, **kwargs):
|
||||
if 'onTrigger' in kwargs.keys():
|
||||
del kwargs['onTrigger']
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
if is_static:
|
||||
return original_method(*bound_args.args, **bound_args.kwargs)
|
||||
else:
|
||||
# For instance methods, ensure 'self' is correctly passed
|
||||
self_instance = args[0]
|
||||
method_args = bound_args.args[1:]
|
||||
method_kwargs = bound_args.kwargs
|
||||
return original_method(
|
||||
self_instance, *method_args, **method_kwargs)
|
||||
|
||||
# Assign the new function directly to the class attribute
|
||||
if is_static:
|
||||
setattr(cls, method_name, staticmethod(_function))
|
||||
else:
|
||||
setattr(cls, method_name, _function)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
for node_name, node_cls in NODE_CLASS_MAPPINGS.items():
|
||||
node_cls = extend_input_types(node_cls)
|
||||
node_cls = modify_function_params(node_cls)
|
||||
NODE_CLASS_MAPPINGS[node_name] = node_cls
|
||||
|
||||
|
||||
def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
|
||||
module_name = get_module_name(module_path)
|
||||
if os.path.isfile(module_path):
|
||||
@ -2161,6 +2237,8 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
|
||||
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
|
||||
for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
|
||||
if name not in ignore:
|
||||
node_cls = extend_input_types(node_cls)
|
||||
node_cls = modify_function_params(node_cls)
|
||||
NODE_CLASS_MAPPINGS[name] = node_cls
|
||||
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
|
||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user