This commit is contained in:
Chenlei Hu 2025-04-11 08:18:20 -04:00 committed by GitHub
commit 11b23d5da6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 18 deletions

View File

@ -1,6 +1,9 @@
import nodes
from __future__ import annotations
from typing import Type, Literal
import nodes
from comfy_execution.graph_utils import is_link
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
class DependencyCycleError(Exception):
pass
@ -54,7 +57,22 @@ class DynamicPrompt:
def get_original_prompt(self):
return self.original_prompt
def get_input_info(class_def, input_name, valid_inputs=None):
def get_input_info(
class_def: Type[ComfyNodeABC],
input_name: str,
valid_inputs: InputTypeDict | None = None
) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
"""Get the input type, category, and extra info for a given input name.
Arguments:
class_def: The class definition of the node.
input_name: The name of the input to get info for.
valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES().
Returns:
tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name.
"""
valid_inputs = valid_inputs or class_def.INPUT_TYPES()
input_info = None
input_category = None
@ -126,7 +144,7 @@ class TopologicalSort:
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
_, _, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)

View File

@ -111,7 +111,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
missing_keys = {}
for x in inputs:
input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
def mark_missing():
missing_keys[x] = True
input_data_all[x] = (None,)
@ -574,7 +574,7 @@ def validate_inputs(prompt, item, validated):
received_types = {}
for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None
if x not in inputs:
if input_category == "required":
@ -590,7 +590,7 @@ def validate_inputs(prompt, item, validated):
continue
val = inputs[x]
info = (type_input, extra_info)
info = (input_type, extra_info)
if isinstance(val, list):
if len(val) != 2:
error = {
@ -611,8 +611,8 @@ def validate_inputs(prompt, item, validated):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
received_type = r[val[1]]
received_types[x] = received_type
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type):
details = f"{x}, received_type({received_type}) mismatch input_type({input_type})"
error = {
"type": "return_type_mismatch",
"message": "Return type mismatch between linked nodes",
@ -660,22 +660,22 @@ def validate_inputs(prompt, item, validated):
val = val["__value__"]
inputs[x] = val
if type_input == "INT":
if input_type == "INT":
val = int(val)
inputs[x] = val
if type_input == "FLOAT":
if input_type == "FLOAT":
val = float(val)
inputs[x] = val
if type_input == "STRING":
if input_type == "STRING":
val = str(val)
inputs[x] = val
if type_input == "BOOLEAN":
if input_type == "BOOLEAN":
val = bool(val)
inputs[x] = val
except Exception as ex:
error = {
"type": "invalid_input_type",
"message": f"Failed to convert an input value to a {type_input} value",
"message": f"Failed to convert an input value to a {input_type} value",
"details": f"{x}, {val}, {ex}",
"extra_info": {
"input_name": x,
@ -715,18 +715,19 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if isinstance(type_input, list):
if val not in type_input:
if isinstance(input_type, list):
combo_options = input_type
if val not in combo_options:
input_config = info
list_info = ""
# Don't send back gigantic lists like if they're lots of
# scanned model filepaths
if len(type_input) > 20:
list_info = f"(list of length {len(type_input)})"
if len(combo_options) > 20:
list_info = f"(list of length {len(combo_options)})"
input_config = None
else:
list_info = str(type_input)
list_info = str(combo_options)
error = {
"type": "value_not_in_list",