From f03ece18f26b56d61477ada2243865dde4822605 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Sun, 2 Mar 2025 11:58:04 -0800 Subject: [PATCH] Add remaining patch --- comfy/ldm/cosmos/blocks.py | 19 +++++++++++-------- comfy/ldm/modules/attention.py | 6 +++--- comfy/model_patcher.py | 21 ++++++++++++--------- comfy/sd.py | 23 ++++++++++++++++++++++- execution.py | 29 ++++++++++++++++++++--------- folder_paths.py | 29 +++++++++++++++++++++++++++++ nodes.py | 1 + 7 files changed, 98 insertions(+), 30 deletions(-) diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index 84fd6d83..c8427e4a 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -797,12 +797,15 @@ class GeneralDITTransformerBlock(nn.Module): adaln_lora_B_3D: Optional[torch.Tensor] = None, ) -> torch.Tensor: for block in self.blocks: - x = block( - x, - emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - adaln_lora_B_3D=adaln_lora_B_3D, - ) + if self.training: + x = torch.utils.checkpoint.checkpoint(block, x, emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, use_reentrant=False) + else: + x = block( + x, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + ) return x diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index ede50646..0149f4c4 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -750,7 +750,7 @@ class BasicTransformerBlock(nn.Module): for p in patch: n = p(n, extra_options) - x += n + x = n + x if "middle_patch" in transformer_patches: patch = transformer_patches["middle_patch"] for p in patch: @@ -790,12 +790,12 @@ class BasicTransformerBlock(nn.Module): for p in patch: n = p(n, extra_options) - x += n + x = n + x if self.is_res: x_skip = x x = self.ff(self.norm3(x)) if self.is_res: - x += x_skip + x = x_skip + x return x diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b7cb12df..58923051 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -17,23 +17,26 @@ """ from __future__ import annotations -from typing import Optional, Callable -import torch + +import collections import copy import inspect import logging -import uuid -import collections import math +import uuid +from typing import Callable, Optional + +import torch -import comfy.utils import comfy.float -import comfy.model_management -import comfy.lora import comfy.hooks +import comfy.lora +import comfy.model_management import comfy.patcher_extension -from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection +import comfy.utils from comfy.comfy_types import UnetWrapperFunction +from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP + def string_to_seed(data): crc = 0xFFFFFFFF @@ -263,7 +266,7 @@ class ModelPatcher: def lowvram_patch_counter(self): return self.model.lowvram_patch_counter - + def clone(self): n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) n.patches = {} diff --git a/comfy/sd.py b/comfy/sd.py index d096f496..decc71f3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -986,7 +986,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format +def load_diffusion_model_state_dict(sd, model_options={}): + """ + Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. + + Args: + sd (dict): State dictionary containing model weights and configuration + model_options (dict, optional): Additional options for model loading. Supports: + - dtype: Override model data type + - custom_operations: Custom model operations + - fp8_optimizations: Enable FP8 optimizations + + Returns: + ModelPatcher: A wrapped model instance that handles device management and weight loading. + Returns None if the model configuration cannot be detected. + + The function: + 1. Detects and handles different model formats (regular, diffusers, mmdit) + 2. Configures model dtype based on parameters and device capabilities + 3. Handles weight conversion and device placement + 4. Manages model optimization settings + 5. Loads weights and returns a device-managed model instance + """ dtype = model_options.get("dtype", None) #Allow loading unets from checkpoint files diff --git a/execution.py b/execution.py index fcb4f6f4..bfcaaa72 100644 --- a/execution.py +++ b/execution.py @@ -1,23 +1,34 @@ -import sys import copy -import logging -import threading import heapq +import inspect +import logging +import sys +import threading import time import traceback from enum import Enum -import inspect from typing import List, Literal, NamedTuple, Optional import torch -import nodes import comfy.model_management -from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker -from comfy_execution.graph_utils import is_link, GraphBuilder -from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID +import nodes +from comfy_execution.caching import ( + CacheKeySetID, + CacheKeySetInputSignature, + HierarchicalCache, + LRUCache, +) +from comfy_execution.graph import ( + DynamicPrompt, + ExecutionBlocker, + ExecutionList, + get_input_info, +) +from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input + class ExecutionResult(Enum): SUCCESS = 0 FAILURE = 1 @@ -573,7 +584,7 @@ def validate_inputs(prompt, item, validated): val = inputs[x] info = (type_input, extra_info) if isinstance(val, list): - if len(val) != 2: + if len(val) != 2 and not extra_info.get("allow_batch", False): error = { "type": "bad_linked_input", "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]", diff --git a/folder_paths.py b/folder_paths.py index 72c70f59..915bead5 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -272,6 +272,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str]) def get_full_path(folder_name: str, filename: str) -> str | None: + """ + Get the full path of a file in a folder, has to be a file + """ global folder_names_and_paths folder_name = map_legacy(folder_name) if folder_name not in folder_names_and_paths: @@ -289,6 +292,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None: def get_full_path_or_raise(folder_name: str, filename: str) -> str: + """ + Get the full path of a file in a folder, has to be a file + """ full_path = get_full_path(folder_name, filename) if full_path is None: raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.") @@ -390,3 +396,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im os.makedirs(full_output_folder, exist_ok=True) counter = 1 return full_output_folder, filename, counter, subfolder, filename_prefix + +def get_input_subfolders() -> list[str]: + """Returns a list of all subfolder paths in the input directory, recursively. + + Returns: + List of folder paths relative to the input directory, excluding the root directory + """ + input_dir = get_input_directory() + folders = [] + + try: + if not os.path.exists(input_dir): + return [] + + for root, dirs, _ in os.walk(input_dir): + rel_path = os.path.relpath(root, input_dir) + if rel_path != ".": # Only include non-root directories + # Normalize path separators to forward slashes + folders.append(rel_path.replace(os.sep, '/')) + + return sorted(folders) + except FileNotFoundError: + return [] diff --git a/nodes.py b/nodes.py index 272c2a25..1055371f 100644 --- a/nodes.py +++ b/nodes.py @@ -2229,6 +2229,7 @@ def init_builtin_extra_nodes(): "nodes_model_downscale.py", "nodes_images.py", "nodes_video_model.py", + "nodes_train.py", "nodes_sag.py", "nodes_perpneg.py", "nodes_stable3d.py",