mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-19 05:57:04 +08:00
Add remaining patch
This commit is contained in:
parent
2cd3c8a2fb
commit
f03ece18f2
@ -797,6 +797,9 @@ class GeneralDITTransformerBlock(nn.Module):
|
|||||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
|
if self.training:
|
||||||
|
x = torch.utils.checkpoint.checkpoint(block, x, emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, use_reentrant=False)
|
||||||
|
else:
|
||||||
x = block(
|
x = block(
|
||||||
x,
|
x,
|
||||||
emb_B_D,
|
emb_B_D,
|
||||||
|
@ -750,7 +750,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
for p in patch:
|
for p in patch:
|
||||||
n = p(n, extra_options)
|
n = p(n, extra_options)
|
||||||
|
|
||||||
x += n
|
x = n + x
|
||||||
if "middle_patch" in transformer_patches:
|
if "middle_patch" in transformer_patches:
|
||||||
patch = transformer_patches["middle_patch"]
|
patch = transformer_patches["middle_patch"]
|
||||||
for p in patch:
|
for p in patch:
|
||||||
@ -790,12 +790,12 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
for p in patch:
|
for p in patch:
|
||||||
n = p(n, extra_options)
|
n = p(n, extra_options)
|
||||||
|
|
||||||
x += n
|
x = n + x
|
||||||
if self.is_res:
|
if self.is_res:
|
||||||
x_skip = x
|
x_skip = x
|
||||||
x = self.ff(self.norm3(x))
|
x = self.ff(self.norm3(x))
|
||||||
if self.is_res:
|
if self.is_res:
|
||||||
x += x_skip
|
x = x_skip + x
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -17,23 +17,26 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Optional, Callable
|
|
||||||
import torch
|
import collections
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
|
||||||
import collections
|
|
||||||
import math
|
import math
|
||||||
|
import uuid
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
import comfy.utils
|
|
||||||
import comfy.float
|
import comfy.float
|
||||||
import comfy.model_management
|
|
||||||
import comfy.lora
|
|
||||||
import comfy.hooks
|
import comfy.hooks
|
||||||
|
import comfy.lora
|
||||||
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
|
import comfy.utils
|
||||||
from comfy.comfy_types import UnetWrapperFunction
|
from comfy.comfy_types import UnetWrapperFunction
|
||||||
|
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||||
|
|
||||||
|
|
||||||
def string_to_seed(data):
|
def string_to_seed(data):
|
||||||
crc = 0xFFFFFFFF
|
crc = 0xFFFFFFFF
|
||||||
|
23
comfy/sd.py
23
comfy/sd.py
@ -986,7 +986,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
return (model_patcher, clip, vae, clipvision)
|
return (model_patcher, clip, vae, clipvision)
|
||||||
|
|
||||||
|
|
||||||
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
|
def load_diffusion_model_state_dict(sd, model_options={}):
|
||||||
|
"""
|
||||||
|
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sd (dict): State dictionary containing model weights and configuration
|
||||||
|
model_options (dict, optional): Additional options for model loading. Supports:
|
||||||
|
- dtype: Override model data type
|
||||||
|
- custom_operations: Custom model operations
|
||||||
|
- fp8_optimizations: Enable FP8 optimizations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelPatcher: A wrapped model instance that handles device management and weight loading.
|
||||||
|
Returns None if the model configuration cannot be detected.
|
||||||
|
|
||||||
|
The function:
|
||||||
|
1. Detects and handles different model formats (regular, diffusers, mmdit)
|
||||||
|
2. Configures model dtype based on parameters and device capabilities
|
||||||
|
3. Handles weight conversion and device placement
|
||||||
|
4. Manages model optimization settings
|
||||||
|
5. Loads weights and returns a device-managed model instance
|
||||||
|
"""
|
||||||
dtype = model_options.get("dtype", None)
|
dtype = model_options.get("dtype", None)
|
||||||
|
|
||||||
#Allow loading unets from checkpoint files
|
#Allow loading unets from checkpoint files
|
||||||
|
29
execution.py
29
execution.py
@ -1,23 +1,34 @@
|
|||||||
import sys
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
import heapq
|
import heapq
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import inspect
|
|
||||||
from typing import List, Literal, NamedTuple, Optional
|
from typing import List, Literal, NamedTuple, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import nodes
|
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
import nodes
|
||||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
from comfy_execution.caching import (
|
||||||
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
CacheKeySetID,
|
||||||
|
CacheKeySetInputSignature,
|
||||||
|
HierarchicalCache,
|
||||||
|
LRUCache,
|
||||||
|
)
|
||||||
|
from comfy_execution.graph import (
|
||||||
|
DynamicPrompt,
|
||||||
|
ExecutionBlocker,
|
||||||
|
ExecutionList,
|
||||||
|
get_input_info,
|
||||||
|
)
|
||||||
|
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
|
|
||||||
|
|
||||||
class ExecutionResult(Enum):
|
class ExecutionResult(Enum):
|
||||||
SUCCESS = 0
|
SUCCESS = 0
|
||||||
FAILURE = 1
|
FAILURE = 1
|
||||||
@ -573,7 +584,7 @@ def validate_inputs(prompt, item, validated):
|
|||||||
val = inputs[x]
|
val = inputs[x]
|
||||||
info = (type_input, extra_info)
|
info = (type_input, extra_info)
|
||||||
if isinstance(val, list):
|
if isinstance(val, list):
|
||||||
if len(val) != 2:
|
if len(val) != 2 and not extra_info.get("allow_batch", False):
|
||||||
error = {
|
error = {
|
||||||
"type": "bad_linked_input",
|
"type": "bad_linked_input",
|
||||||
"message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
|
"message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
|
||||||
|
@ -272,6 +272,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
|
|||||||
|
|
||||||
|
|
||||||
def get_full_path(folder_name: str, filename: str) -> str | None:
|
def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||||
|
"""
|
||||||
|
Get the full path of a file in a folder, has to be a file
|
||||||
|
"""
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
if folder_name not in folder_names_and_paths:
|
if folder_name not in folder_names_and_paths:
|
||||||
@ -289,6 +292,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
|||||||
|
|
||||||
|
|
||||||
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the full path of a file in a folder, has to be a file
|
||||||
|
"""
|
||||||
full_path = get_full_path(folder_name, filename)
|
full_path = get_full_path(folder_name, filename)
|
||||||
if full_path is None:
|
if full_path is None:
|
||||||
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
|
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
|
||||||
@ -390,3 +396,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
|
|||||||
os.makedirs(full_output_folder, exist_ok=True)
|
os.makedirs(full_output_folder, exist_ok=True)
|
||||||
counter = 1
|
counter = 1
|
||||||
return full_output_folder, filename, counter, subfolder, filename_prefix
|
return full_output_folder, filename, counter, subfolder, filename_prefix
|
||||||
|
|
||||||
|
def get_input_subfolders() -> list[str]:
|
||||||
|
"""Returns a list of all subfolder paths in the input directory, recursively.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of folder paths relative to the input directory, excluding the root directory
|
||||||
|
"""
|
||||||
|
input_dir = get_input_directory()
|
||||||
|
folders = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not os.path.exists(input_dir):
|
||||||
|
return []
|
||||||
|
|
||||||
|
for root, dirs, _ in os.walk(input_dir):
|
||||||
|
rel_path = os.path.relpath(root, input_dir)
|
||||||
|
if rel_path != ".": # Only include non-root directories
|
||||||
|
# Normalize path separators to forward slashes
|
||||||
|
folders.append(rel_path.replace(os.sep, '/'))
|
||||||
|
|
||||||
|
return sorted(folders)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return []
|
||||||
|
1
nodes.py
1
nodes.py
@ -2229,6 +2229,7 @@ def init_builtin_extra_nodes():
|
|||||||
"nodes_model_downscale.py",
|
"nodes_model_downscale.py",
|
||||||
"nodes_images.py",
|
"nodes_images.py",
|
||||||
"nodes_video_model.py",
|
"nodes_video_model.py",
|
||||||
|
"nodes_train.py",
|
||||||
"nodes_sag.py",
|
"nodes_sag.py",
|
||||||
"nodes_perpneg.py",
|
"nodes_perpneg.py",
|
||||||
"nodes_stable3d.py",
|
"nodes_stable3d.py",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user