Small refactor of custom node loading code.

This commit is contained in:
comfyanonymous 2023-03-11 12:49:41 -05:00
parent 1f717903bc
commit 8c4ccb55d1

View File

@ -948,17 +948,11 @@ NODE_CLASS_MAPPINGS = {
"VAEDecodeTiled": VAEDecodeTiled,
}
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
def load_custom_nodes():
possible_modules = os.listdir(CUSTOM_NODE_PATH)
if "__pycache__" in possible_modules:
possible_modules.remove("__pycache__")
for possible_module in possible_modules:
module_path = os.path.join(CUSTOM_NODE_PATH, possible_module)
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
module_name = possible_module
def load_custom_node(module_path):
module_name = os.path.basename(module_path)
if os.path.isfile(module_path):
sp = os.path.splitext(module_path)
module_name = sp[0]
try:
if os.path.isfile(module_path):
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
@ -970,9 +964,20 @@ def load_custom_nodes():
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
else:
print(f"Skip {possible_module} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
except Exception as e:
print(traceback.format_exc())
print(f"Cannot import {possible_module} module for custom nodes:", e)
print(f"Cannot import {module_path} module for custom nodes:", e)
def load_custom_nodes():
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
possible_modules = os.listdir(CUSTOM_NODE_PATH)
if "__pycache__" in possible_modules:
possible_modules.remove("__pycache__")
for possible_module in possible_modules:
module_path = os.path.join(CUSTOM_NODE_PATH, possible_module)
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
load_custom_node(module_path)
load_custom_nodes()