This commit is contained in:
Chenlei Hu 2025-04-11 08:18:20 -04:00 committed by GitHub
commit 11b23d5da6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 18 deletions

View File

@ -1,6 +1,9 @@
import nodes from __future__ import annotations
from typing import Type, Literal
import nodes
from comfy_execution.graph_utils import is_link from comfy_execution.graph_utils import is_link
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
class DependencyCycleError(Exception): class DependencyCycleError(Exception):
pass pass
@ -54,7 +57,22 @@ class DynamicPrompt:
def get_original_prompt(self): def get_original_prompt(self):
return self.original_prompt return self.original_prompt
def get_input_info(class_def, input_name, valid_inputs=None): def get_input_info(
class_def: Type[ComfyNodeABC],
input_name: str,
valid_inputs: InputTypeDict | None = None
) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
"""Get the input type, category, and extra info for a given input name.
Arguments:
class_def: The class definition of the node.
input_name: The name of the input to get info for.
valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES().
Returns:
tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name.
"""
valid_inputs = valid_inputs or class_def.INPUT_TYPES() valid_inputs = valid_inputs or class_def.INPUT_TYPES()
input_info = None input_info = None
input_category = None input_category = None
@ -126,7 +144,7 @@ class TopologicalSort:
from_node_id, from_socket = value from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes: if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name) _, _, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id) node_ids.append(from_node_id)

View File

@ -111,7 +111,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
missing_keys = {} missing_keys = {}
for x in inputs: for x in inputs:
input_data = inputs[x] input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs) _, input_category, input_info = get_input_info(class_def, x, valid_inputs)
def mark_missing(): def mark_missing():
missing_keys[x] = True missing_keys[x] = True
input_data_all[x] = (None,) input_data_all[x] = (None,)
@ -574,7 +574,7 @@ def validate_inputs(prompt, item, validated):
received_types = {} received_types = {}
for x in valid_inputs: for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs) input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None assert extra_info is not None
if x not in inputs: if x not in inputs:
if input_category == "required": if input_category == "required":
@ -590,7 +590,7 @@ def validate_inputs(prompt, item, validated):
continue continue
val = inputs[x] val = inputs[x]
info = (type_input, extra_info) info = (input_type, extra_info)
if isinstance(val, list): if isinstance(val, list):
if len(val) != 2: if len(val) != 2:
error = { error = {
@ -611,8 +611,8 @@ def validate_inputs(prompt, item, validated):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
received_type = r[val[1]] received_type = r[val[1]]
received_types[x] = received_type received_types[x] = received_type
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input): if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type):
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})" details = f"{x}, received_type({received_type}) mismatch input_type({input_type})"
error = { error = {
"type": "return_type_mismatch", "type": "return_type_mismatch",
"message": "Return type mismatch between linked nodes", "message": "Return type mismatch between linked nodes",
@ -660,22 +660,22 @@ def validate_inputs(prompt, item, validated):
val = val["__value__"] val = val["__value__"]
inputs[x] = val inputs[x] = val
if type_input == "INT": if input_type == "INT":
val = int(val) val = int(val)
inputs[x] = val inputs[x] = val
if type_input == "FLOAT": if input_type == "FLOAT":
val = float(val) val = float(val)
inputs[x] = val inputs[x] = val
if type_input == "STRING": if input_type == "STRING":
val = str(val) val = str(val)
inputs[x] = val inputs[x] = val
if type_input == "BOOLEAN": if input_type == "BOOLEAN":
val = bool(val) val = bool(val)
inputs[x] = val inputs[x] = val
except Exception as ex: except Exception as ex:
error = { error = {
"type": "invalid_input_type", "type": "invalid_input_type",
"message": f"Failed to convert an input value to a {type_input} value", "message": f"Failed to convert an input value to a {input_type} value",
"details": f"{x}, {val}, {ex}", "details": f"{x}, {val}, {ex}",
"extra_info": { "extra_info": {
"input_name": x, "input_name": x,
@ -715,18 +715,19 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if isinstance(type_input, list): if isinstance(input_type, list):
if val not in type_input: combo_options = input_type
if val not in combo_options:
input_config = info input_config = info
list_info = "" list_info = ""
# Don't send back gigantic lists like if they're lots of # Don't send back gigantic lists like if they're lots of
# scanned model filepaths # scanned model filepaths
if len(type_input) > 20: if len(combo_options) > 20:
list_info = f"(list of length {len(type_input)})" list_info = f"(list of length {len(combo_options)})"
input_config = None input_config = None
else: else:
list_info = str(type_input) list_info = str(combo_options)
error = { error = {
"type": "value_not_in_list", "type": "value_not_in_list",