Merge branch 'comfyanonymous:master' into master

This commit is contained in:
Kallen Ding 2024-12-12 11:52:06 +08:00 committed by GitHub
commit a0b35e60a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
160 changed files with 75812 additions and 21495 deletions

View File

@ -3,8 +3,8 @@ name: Python Linting
on: [push, pull_request]
jobs:
pylint:
name: Run Pylint
ruff:
name: Run Ruff
runs-on: ubuntu-latest
steps:
@ -16,8 +16,8 @@ jobs:
with:
python-version: 3.x
- name: Install Pylint
run: pip install pylint
- name: Install Ruff
run: pip install ruff
- name: Run Pylint
run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")
- name: Run Ruff
run: ruff check .

View File

@ -20,7 +20,8 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [macos, linux, windows]
# os: [macos, linux, windows]
os: [macos, linux]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
@ -31,9 +32,9 @@ jobs:
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, Windows]
flags: ""
# - os: windows
# runner_label: [self-hosted, Windows]
# flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
@ -45,28 +46,28 @@ jobs:
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
test-win-nightly:
strategy:
fail-fast: true
matrix:
os: [windows]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
- os: windows
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
# test-win-nightly:
# strategy:
# fail-fast: true
# matrix:
# os: [windows]
# python_version: ["3.9", "3.10", "3.11", "3.12"]
# cuda_version: ["12.1"]
# torch_version: ["nightly"]
# include:
# - os: windows
# runner_label: [self-hosted, Windows]
# flags: ""
# runs-on: ${{ matrix.runner_label }}
# steps:
# - name: Test Workflows
# uses: comfy-org/comfy-action@main
# with:
# os: ${{ matrix.os }}
# python_version: ${{ matrix.python_version }}
# torch_version: ${{ matrix.torch_version }}
# google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
# comfyui_flags: ${{ matrix.flags }}
test-unix-nightly:
strategy:

View File

@ -28,7 +28,7 @@ jobs:
- name: Start ComfyUI server
run: |
python main.py --cpu 2>&1 | tee console_output.log &
wait-for-it --service 127.0.0.1:8188 -t 600
wait-for-it --service 127.0.0.1:8188 -t 30
working-directory: ComfyUI
- name: Check for unhandled exceptions in server log
run: |

View File

@ -1,3 +0,0 @@
[MESSAGES CONTROL]
disable=all
enable=eval-used

View File

@ -1 +1,22 @@
* @comfyanonymous
# Admins
* @comfyanonymous
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
# Frontend assets
/web/ @huchenlei @webfiltered @pythongosssss
# Extra nodes
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink

View File

@ -39,6 +39,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- Asynchronous Queue system
@ -74,37 +75,39 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Keybind | Explanation |
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + Alt + Enter | Cancel current generation |
| Ctrl + Z/Ctrl + Y | Undo/Redo |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Alt + C | Collapse/uncollapse selected nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| Shift + Drag | Move multiple selected nodes at the same time |
| Ctrl + D | Load default graph |
| Alt + `+` | Canvas Zoom in |
| Alt + `-` | Canvas Zoom out |
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
| P | Pin/Unpin selected nodes |
| Ctrl + G | Group selected nodes |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
| `Ctrl` + `Enter` | Queue up current graph for generation |
| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
| `Ctrl` + `S` | Save workflow |
| `Ctrl` + `O` | Load workflow |
| `Ctrl` + `A` | Select all nodes |
| `Alt `+ `C` | Collapse/uncollapse selected nodes |
| `Ctrl` + `M` | Mute/unmute selected nodes |
| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| `Delete`/`Backspace` | Delete selected nodes |
| `Ctrl` + `Backspace` | Delete the current graph |
| `Space` | Move the canvas around when held and moving the cursor |
| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| `Shift` + `Drag` | Move multiple selected nodes at the same time |
| `Ctrl` + `D` | Load default graph |
| `Alt` + `+` | Canvas Zoom in |
| `Alt` + `-` | Canvas Zoom out |
| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
| `P` | Pin/Unpin selected nodes |
| `Ctrl` + `G` | Group selected nodes |
| `Q` | Toggle visibility of the queue |
| `H` | Toggle visibility of history |
| `R` | Refresh graph |
| `F` | Show/Hide menu |
| `.` | Fit view to selection (Whole graph when nothing is selected) |
| Double-Click LMB | Open node quick search palette |
| Shift + Drag | Move multiple wires at once |
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
| `Shift` + Drag | Move multiple wires at once |
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
Ctrl can also be replaced with Cmd instead for macOS users
`Ctrl` can also be replaced with `Cmd` instead for macOS users
# Installing
@ -140,7 +143,7 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
@ -212,6 +215,14 @@ For 6700, 6600 and maybe other RDNA2 or older: ```HSA_OVERRIDE_GFX_VERSION=10.3.
For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py```
### AMD ROCm Tips
You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
# Notes
Only parts of the graph that have an output with all the correct inputs will be executed.

View File

@ -10,7 +10,6 @@ class InternalRoutes:
The top level web router for internal routes: /internal/*
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
Check README.md for more information.
'''
def __init__(self, prompt_server):

View File

@ -1,5 +1,6 @@
from app.logger import on_flush
import os
import shutil
class TerminalService:
@ -10,15 +11,27 @@ class TerminalService:
self.subscriptions = set()
on_flush(self.send_messages)
def get_terminal_size(self):
try:
size = os.get_terminal_size()
return (size.columns, size.lines)
except OSError:
try:
size = shutil.get_terminal_size()
return (size.columns, size.lines)
except OSError:
return (80, 24) # fallback to 80x24
def update_size(self):
sz = os.get_terminal_size()
columns, lines = self.get_terminal_size()
changed = False
if sz.columns != self.cols:
self.cols = sz.columns
if columns != self.cols:
self.cols = columns
changed = True
if sz.lines != self.rows:
self.rows = sz.lines
if lines != self.rows:
self.rows = lines
changed = True
if changed:

167
app/model_manager.py Normal file
View File

@ -0,0 +1,167 @@
from __future__ import annotations
import os
import time
import logging
import folder_paths
import glob
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
class ModelFileManager:
def __init__(self) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
return self.cache.get(key, default)
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
self.cache[key] = value
def clear_cache(self):
self.cache.clear()
def add_routes(self, routes):
# NOTE: This is an experiment to replace `/models`
@routes.get("/experiment/models")
async def get_model_folders(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
folder_black_list = ["configs", "custom_nodes"]
output_folders: list[dict] = []
for folder in model_types:
if folder in folder_black_list:
continue
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
return web.json_response(output_folders)
# NOTE: This is an experiment to replace `/models/{folder}`
@routes.get("/experiment/models/{folder}")
async def get_all_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = self.get_model_file_list(folder)
return web.json_response(files)
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request):
folder_name = request.match_info.get("folder", None)
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if not folder_name in folder_paths.folder_names_and_paths:
return web.Response(status=404)
folders = folder_paths.folder_names_and_paths[folder_name]
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
preview_files = self.get_model_previews(full_filename)
default_preview_file = preview_files[0] if len(preview_files) > 0 else None
if default_preview_file is None or not os.path.isfile(default_preview_file):
return web.Response(status=404)
try:
with Image.open(default_preview_file) as img:
img_bytes = BytesIO()
img.save(img_bytes, format="WEBP")
img_bytes.seek(0)
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
except:
return web.Response(status=404)
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
output_list: list[dict] = []
for index, folder in enumerate(folders[0]):
if not os.path.isdir(folder):
continue
out = self.cache_model_file_list_(folder)
if out is None:
out = self.recursive_search_models_(folder, index)
self.set_cache(folder, out)
output_list.extend(out[0])
return output_list
def cache_model_file_list_(self, folder: str):
model_file_list_cache = self.get_cache(folder)
if model_file_list_cache is None:
return None
if not os.path.isdir(folder):
return None
if os.path.getmtime(folder) != model_file_list_cache[1]:
return None
for x in model_file_list_cache[1]:
time_modified = model_file_list_cache[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
return model_file_list_cache
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
if not os.path.isdir(directory):
return [], {}, time.perf_counter()
excluded_dir_names = [".git"]
# TODO use settings
include_hidden_files = False
result: list[str] = []
dirs: dict[str, float] = {}
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
if not include_hidden_files:
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
filenames = [f for f in filenames if not f.startswith(".")]
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
for file_name in filenames:
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
continue
for d in subdirs:
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
result: list[str] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
return result
def __exit__(self, exc_type, exc_value, traceback):
self.clear_cache()

View File

@ -36,7 +36,7 @@ class UserManager():
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.mkdir(user_directory)
os.makedirs(user_directory, exist_ok=True)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

View File

@ -2,11 +2,9 @@
#and modified
import torch
import torch as th
import torch.nn as nn
from ..ldm.modules.diffusionmodules.util import (
zero_module,
timestep_embedding,
)

120
comfy/cldm/dit_embedder.py Normal file
View File

@ -0,0 +1,120 @@
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
class ControlNetEmbedder(nn.Module):
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
attention_head_dim: int,
num_attention_heads: int,
adm_in_channels: int,
num_layers: int,
main_model_double: int,
double_y_emb: bool,
device: torch.device,
dtype: torch.dtype,
pos_embed_max_size: Optional[int] = None,
operations = None,
):
super().__init__()
self.main_model_double = main_model_double
self.dtype = dtype
self.hidden_size = num_attention_heads * attention_head_dim
self.patch_size = patch_size
self.x_embedder = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=pos_embed_max_size is None,
device=device,
dtype=dtype,
operations=operations,
)
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
self.double_y_emb = double_y_emb
if self.double_y_emb:
self.orig_y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.y_embedder = VectorEmbedder(
self.hidden_size, self.hidden_size, dtype, device, operations=operations
)
else:
self.y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.transformer_blocks = nn.ModuleList(
DismantledBlock(
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(num_layers)
)
# self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
# TODO double check this logic when 8b
self.use_y_embedder = True
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=False,
device=device,
dtype=dtype,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
hint = None,
) -> Tuple[Tensor, List[Tensor]]:
x_shape = list(x.shape)
x = self.x_embedder(x)
if not self.double_y_emb:
h = (x_shape[-2] + 1) // self.patch_size
w = (x_shape[-1] + 1) // self.patch_size
x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
c = self.t_embedder(timesteps, dtype=x.dtype)
if y is not None and self.y_embedder is not None:
if self.double_y_emb:
y = self.orig_y_embedder(y)
y = self.y_embedder(y)
c = c + y
x = x + self.pos_embed_input(hint)
block_out = ()
repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
for i in range(len(self.transformer_blocks)):
out = self.transformer_blocks[i](x, c)
if not self.double_y_emb:
x = out
block_out += (self.controlnet_blocks[i](out),) * repeat
return {"output": block_out}

View File

@ -1,5 +1,5 @@
import torch
from typing import Dict, Optional
from typing import Optional
import comfy.ldm.modules.diffusionmodules.mmdit
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):

View File

@ -60,8 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")

View File

@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
}
class CLIPMLP(torch.nn.Module):
@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
super().__init__()
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
num_patches = (image_size // patch_size) ** 2
if model_type == "siglip_vision_model":
self.class_embedding = None
patch_bias = True
else:
num_patches = num_patches + 1
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
patch_bias = False
self.patch_embedding = operations.Conv2d(
in_channels=num_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=False,
bias=patch_bias,
dtype=dtype,
device=device
)
num_patches = (image_size // patch_size) ** 2
num_positions = num_patches + 1
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
if self.class_embedding is not None:
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
class CLIPVision(torch.nn.Module):
@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
model_type = config_dict["model_type"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type == "siglip_vision_model":
self.pre_layrnorm = lambda a: a
self.output_layernorm = True
else:
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.output_layernorm = False
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim)
@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
x = self.pre_layrnorm(x)
#TODO: attention_mask?
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
pooled_output = self.post_layernorm(x[:, 0, :])
if self.output_layernorm:
x = self.post_layernorm(x)
pooled_output = x
else:
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
if "projection_dim" in config_dict:
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
else:
self.visual_projection = lambda a: a
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)

View File

@ -16,13 +16,18 @@ class Output:
def __setitem__(self, key, item):
setattr(self, key, item)
def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
scale = (size / min(image.shape[2], image.shape[3]))
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
if crop:
scale = (size / min(image.shape[2], image.shape[3]))
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
@ -35,6 +40,8 @@ class ClipVisionModel():
config = json.load(f)
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@ -49,9 +56,9 @@ class ClipVisionModel():
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image):
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output()
@ -94,7 +101,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")

View File

@ -0,0 +1,13 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@ -0,0 +1,43 @@
# Comfy Typing
## Type hinting for ComfyUI Node development
This module provides type hinting and concrete convenience types for node developers.
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
```python
from comfy_types import IO, ComfyNodeABC, CheckLazyMixin
class ExampleNode(ComfyNodeABC):
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {"required": {}}
```
Full example is in [examples/example_nodes.py](examples/example_nodes.py).
# Types
A few primary types are documented below. More complete information is available via the docstrings on each type.
## `IO`
A string enum of built-in and a few custom data types. Includes the following special types and their requisite plumbing:
- `ANY`: `"*"`
- `NUMBER`: `"FLOAT,INT"`
- `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"`
## `ComfyNodeABC`
An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings.
### Type hinting for `INPUT_TYPES`
![INPUT_TYPES auto-completion in Visual Studio Code](examples/input_types.png)
### `INPUT_TYPES` return dict
![INPUT_TYPES return value type hinting in Visual Studio Code](examples/required_hint.png)
### Options for individual inputs
![INPUT_TYPES return value option auto-completion in Visual Studio Code](examples/input_options.png)

View File

@ -1,5 +1,6 @@
import torch
from typing import Callable, Protocol, TypedDict, Optional, List
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
class UnetApplyFunction(Protocol):
@ -30,3 +31,15 @@ class UnetParams(TypedDict):
UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
__all__ = [
"UnetWrapperFunction",
UnetApplyConds.__name__,
UnetParams.__name__,
UnetApplyFunction.__name__,
IO.__name__,
InputTypeDict.__name__,
ComfyNodeABC.__name__,
CheckLazyMixin.__name__,
]

View File

@ -0,0 +1,28 @@
from comfy_types import IO, ComfyNodeABC, InputTypeDict
from inspect import cleandoc
class ExampleNode(ComfyNodeABC):
"""An example node that just adds 1 to an input integer.
* Requires an IDE configured with analysis paths etc to be worth looking at.
* Not intended for use in ComfyUI.
"""
DESCRIPTION = cleandoc(__doc__)
CATEGORY = "examples"
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"input_int": (IO.INT, {"defaultInput": True}),
}
}
RETURN_TYPES = (IO.INT,)
RETURN_NAMES = ("input_plus_one",)
FUNCTION = "execute"
def execute(self, input_int: int):
return (input_int + 1,)

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

View File

@ -0,0 +1,274 @@
"""Comfy-specific type hinting"""
from __future__ import annotations
from typing import Literal, TypedDict
from abc import ABC, abstractmethod
from enum import Enum
class StrEnum(str, Enum):
"""Base class for string enums. Python's StrEnum is not available until 3.11."""
def __str__(self) -> str:
return self.value
class IO(StrEnum):
"""Node input/output data types.
Includes functionality for ``"*"`` (`ANY`) and ``"MULTI,TYPES"``.
"""
STRING = "STRING"
IMAGE = "IMAGE"
MASK = "MASK"
LATENT = "LATENT"
BOOLEAN = "BOOLEAN"
INT = "INT"
FLOAT = "FLOAT"
CONDITIONING = "CONDITIONING"
SAMPLER = "SAMPLER"
SIGMAS = "SIGMAS"
GUIDER = "GUIDER"
NOISE = "NOISE"
CLIP = "CLIP"
CONTROL_NET = "CONTROL_NET"
VAE = "VAE"
MODEL = "MODEL"
CLIP_VISION = "CLIP_VISION"
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
STYLE_MODEL = "STYLE_MODEL"
GLIGEN = "GLIGEN"
UPSCALE_MODEL = "UPSCALE_MODEL"
AUDIO = "AUDIO"
WEBCAM = "WEBCAM"
POINT = "POINT"
FACE_ANALYSIS = "FACE_ANALYSIS"
BBOX = "BBOX"
SEGS = "SEGS"
ANY = "*"
"""Always matches any type, but at a price.
Causes some functionality issues (e.g. reroutes, link types), and should be avoided whenever possible.
"""
NUMBER = "FLOAT,INT"
"""A float or an int - could be either"""
PRIMITIVE = "STRING,FLOAT,INT,BOOLEAN"
"""Could be any of: string, float, int, or bool"""
def __ne__(self, value: object) -> bool:
if self == "*" or value == "*":
return False
if not isinstance(value, str):
return True
a = frozenset(self.split(","))
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class InputTypeOptions(TypedDict):
"""Provides type hinting for the return type of the INPUT_TYPES node function.
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
"""
default: bool | str | float | int | list | tuple
"""The default value of the widget"""
defaultInput: bool
"""Defaults to an input slot rather than a widget"""
forceInput: bool
"""`defaultInput` and also don't allow converting to a widget"""
lazy: bool
"""Declares that this input uses lazy evaluation"""
rawLink: bool
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
tooltip: str
"""Tooltip for the input (or widget), shown on pointer hover"""
# class InputTypeNumber(InputTypeOptions):
# default: float | int
min: float
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
max: float
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
step: float
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
round: float
"""Floats are rounded by this value (``FLOAT``)"""
# class InputTypeBoolean(InputTypeOptions):
# default: bool
label_on: str
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
label_on: str
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
# class InputTypeString(InputTypeOptions):
# default: str
multiline: bool
"""Use a multiline text box (``STRING``)"""
placeholder: str
"""Placeholder text to display in the UI when empty (``STRING``)"""
# Deprecated:
# defaultVal: str
dynamicPrompts: bool
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
class HiddenInputTypeDict(TypedDict):
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
node_id: Literal["UNIQUE_ID"]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
unique_id: Literal["UNIQUE_ID"]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
prompt: Literal["PROMPT"]
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
extra_pnginfo: Literal["EXTRA_PNGINFO"]
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
dynprompt: Literal["DYNPROMPT"]
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
class InputTypeDict(TypedDict):
"""Provides type hinting for node INPUT_TYPES.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
"""
required: dict[str, tuple[IO, InputTypeOptions]]
"""Describes all inputs that must be connected for the node to execute."""
optional: dict[str, tuple[IO, InputTypeOptions]]
"""Describes inputs which do not need to be connected."""
hidden: HiddenInputTypeDict
"""Offers advanced functionality and server-client communication.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
"""
class ComfyNodeABC(ABC):
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
"""
DESCRIPTION: str
"""Node description, shown as a tooltip when hovering over the node.
Usage::
# Explicitly define the description
DESCRIPTION = "Example description here."
# Use the docstring of the node class.
DESCRIPTION = cleandoc(__doc__)
"""
CATEGORY: str
"""The category of the node, as per the "Add Node" menu.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
"""
EXPERIMENTAL: bool
"""Flags a node as experimental, informing users that it may change or not work as expected."""
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
@classmethod
@abstractmethod
def INPUT_TYPES(s) -> InputTypeDict:
"""Defines node inputs.
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
* The ``optional`` key can be added to describe inputs which do not need to be connected.
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
"""
return {"required": {}}
OUTPUT_NODE: bool
"""Flags this node as an output node, causing any inputs it requires to be executed.
If a node is not connected to any output nodes, that node will not be executed. Usage::
OUTPUT_NODE = True
From the docs:
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
"""
INPUT_IS_LIST: bool
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
From the docs:
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
"""
OUTPUT_IS_LIST: tuple[bool]
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
A ``tuple[bool]``, where the items match those in `RETURN_TYPES`::
RETURN_TYPES = (IO.INT, IO.INT, IO.STRING)
OUTPUT_IS_LIST = (True, True, False) # The string output will be handled normally
From the docs:
In order to tell Comfy that the list being returned should not be wrapped, but treated as a series of data for sequential processing,
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
specifying which outputs which should be so treated.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
"""
RETURN_TYPES: tuple[IO]
"""A tuple representing the outputs of this node.
Usage::
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
"""
RETURN_NAMES: tuple[str]
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
"""
OUTPUT_TOOLTIPS: tuple[str]
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
FUNCTION: str
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
"""
class CheckLazyMixin:
"""Provides a basic check_lazy_status implementation and type hinting for nodes that use lazy inputs."""
def check_lazy_status(self, **kwargs) -> list[str]:
"""Returns a list of input names that should be evaluated.
This basic mixin impl. requires all inputs.
:kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \
When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``.
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
"""
need = [name for name in kwargs if kwargs[name] is None]
return need

View File

@ -35,6 +35,10 @@ import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.hooks import HookGroup
def broadcast_image_to(tensor, target_batch_size, batched_number):
@ -78,6 +82,8 @@ class ControlBase:
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
@ -114,6 +120,14 @@ class ControlBase:
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
out.append(self.extra_hooks)
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c):
c.cond_hint_original = self.cond_hint_original
@ -129,6 +143,8 @@ class ControlBase:
c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
c.preprocess_image = self.preprocess_image
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@ -181,7 +197,7 @@ class ControlBase:
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
super().__init__()
self.control_model = control_model
self.load_device = load_device
@ -196,11 +212,12 @@ class ControlNet(ControlBase):
self.extra_conds += extra_conds
self.strength_type = strength_type
self.concat_mask = concat_mask
self.preprocess_image = preprocess_image
def get_control(self, x_noisy, t, cond, batched_number):
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
@ -224,6 +241,7 @@ class ControlNet(ControlBase):
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
self.cond_hint = self.preprocess_image(self.cond_hint)
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
@ -427,6 +445,7 @@ def controlnet_load_state_dict(control_model, sd):
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model
def load_controlnet_mmdit(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
@ -448,6 +467,82 @@ def load_controlnet_mmdit(sd, model_options={}):
return control
class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)
def copy(self):
c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
def load_controlnet_sd35(sd, model_options={}):
control_type = -1
if "control_type" in sd:
control_type = round(sd.pop("control_type").item())
# blur_cnet = control_type == 0
canny_cnet = control_type == 1
depth_cnet = control_type == 2
new_sd = {}
for k in comfy.utils.MMDIT_MAP_BASIC:
if k[1] in sd:
new_sd[k[0]] = sd.pop(k[1])
for k in sd:
new_sd[k] = sd[k]
sd = new_sd
y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
depth = y_emb_shape[0] // 64
hidden_size = 64 * depth
num_heads = depth
head_dim = hidden_size // num_heads
num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.unet_offload_device()
unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
patch_size=2,
in_chans=16,
num_layers=num_blocks,
main_model_double=depth,
double_y_emb=y_emb_shape[0] == y_emb_shape[1],
attention_head_dim=head_dim,
num_attention_heads=num_heads,
adm_in_channels=2048,
device=offload_device,
dtype=unet_dtype,
operations=operations)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = comfy.latent_formats.SD3()
preprocess_image = lambda a: a
if canny_cnet:
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
elif depth_cnet:
preprocess_image = lambda a: 1.0 - a
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
return control
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
@ -560,7 +655,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
else:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
@ -674,10 +772,10 @@ class T2IAdapter(ControlBase):
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
return width, height
def get_control(self, x_noisy, t, cond, batched_number):
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:

View File

@ -1,10 +1,9 @@
#code taken from: https://github.com/wl-zhao/UniPC and modified
import torch
import torch.nn.functional as F
import math
from tqdm.auto import trange, tqdm
from tqdm.auto import trange
class NoiseScheduleVP:

690
comfy/hooks.py Normal file
View File

@ -0,0 +1,690 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable
import enum
import math
import torch
import numpy as np
import itertools
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher, PatcherInjection
from comfy.model_base import BaseModel
from comfy.sd import CLIP
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from node_helpers import conditioning_set_values
class EnumHookMode(enum.Enum):
MinVram = "minvram"
MaxSpeed = "maxspeed"
class EnumHookType(enum.Enum):
Weight = "weight"
Patch = "patch"
ObjectPatch = "object_patch"
AddModels = "add_models"
Callbacks = "callbacks"
Wrappers = "wrappers"
SetInjections = "add_injections"
class EnumWeightTarget(enum.Enum):
Model = "model"
Clip = "clip"
class _HookRef:
pass
# NOTE: this is an example of how the should_register function should look
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return True
class Hook:
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
hook_keyframe: 'HookKeyframeGroup'=None):
self.hook_type = hook_type
self.hook_ref = hook_ref if hook_ref else _HookRef()
self.hook_id = hook_id
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
self.custom_should_register = default_should_register
self.auto_apply_to_nonpositive = False
@property
def strength(self):
return self.hook_keyframe.strength
def initialize_timesteps(self, model: 'BaseModel'):
self.reset()
self.hook_keyframe.initialize_timesteps(model)
def reset(self):
self.hook_keyframe.reset()
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: Hook = subtype()
c.hook_type = self.hook_type
c.hook_ref = self.hook_ref
c.hook_id = self.hook_id
c.hook_keyframe = self.hook_keyframe
c.custom_should_register = self.custom_should_register
# TODO: make this do something
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
return c
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return self.custom_should_register(self, model, model_options, target, registered)
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def __eq__(self, other: 'Hook'):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
def __hash__(self):
return hash(self.hook_ref)
class WeightHook(Hook):
def __init__(self, strength_model=1.0, strength_clip=1.0):
super().__init__(hook_type=EnumHookType.Weight)
self.weights: dict = None
self.weights_clip: dict = None
self.need_weight_init = True
self._strength_model = strength_model
self._strength_clip = strength_clip
@property
def strength_model(self):
return self._strength_model * self.strength
@property
def strength_clip(self):
return self._strength_clip * self.strength
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
return False
weights = None
if target == EnumWeightTarget.Model:
strength = self._strength_model
else:
strength = self._strength_clip
if self.need_weight_init:
key_map = {}
if target == EnumWeightTarget.Model:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Model:
weights = self.weights
else:
weights = self.weights_clip
k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self)
return True
# TODO: add logs about any keys that were not applied
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WeightHook = super().clone(subtype)
c.weights = self.weights
c.weights_clip = self.weights_clip
c.need_weight_init = self.need_weight_init
c._strength_model = self._strength_model
c._strength_clip = self._strength_clip
return c
class PatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.Patch)
self.patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: PatchHook = super().clone(subtype)
c.patches = self.patches
return c
# TODO: add functionality
class ObjectPatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: ObjectPatchHook = super().clone(subtype)
c.object_patches = self.object_patches
return c
# TODO: add functionality
class AddModelsHook(Hook):
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
super().__init__(hook_type=EnumHookType.AddModels)
self.key = key
self.models = models
self.append_when_same = True
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddModelsHook = super().clone(subtype)
c.key = self.key
c.models = self.models.copy() if self.models else self.models
c.append_when_same = self.append_when_same
return c
# TODO: add functionality
class CallbackHook(Hook):
def __init__(self, key: str=None, callback: Callable=None):
super().__init__(hook_type=EnumHookType.Callbacks)
self.key = key
self.callback = callback
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: CallbackHook = super().clone(subtype)
c.key = self.key
c.callback = self.callback
return c
# TODO: add functionality
class WrapperHook(Hook):
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
super().__init__(hook_type=EnumHookType.Wrappers)
self.wrappers_dict = wrappers_dict
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WrapperHook = super().clone(subtype)
c.wrappers_dict = self.wrappers_dict
return c
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
return False
add_model_options = {"transformer_options": self.wrappers_dict}
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self)
return True
class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
super().__init__(hook_type=EnumHookType.SetInjections)
self.key = key
self.injections = injections
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: SetInjectionsHook = super().clone(subtype)
c.key = self.key
c.injections = self.injections.copy() if self.injections else self.injections
return c
def add_hook_injections(self, model: 'ModelPatcher'):
# TODO: add functionality
pass
class HookGroup:
def __init__(self):
self.hooks: list[Hook] = []
def add(self, hook: Hook):
if hook not in self.hooks:
self.hooks.append(hook)
def contains(self, hook: Hook):
return hook in self.hooks
def clone(self):
c = HookGroup()
for hook in self.hooks:
c.add(hook.clone())
return c
def clone_and_combine(self, other: 'HookGroup'):
c = self.clone()
if other is not None:
for hook in other.hooks:
c.add(hook.clone())
return c
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
if hook_kf is None:
hook_kf = HookKeyframeGroup()
else:
hook_kf = hook_kf.clone()
for hook in self.hooks:
hook.hook_keyframe = hook_kf
def get_dict_repr(self):
d: dict[EnumHookType, dict[Hook, None]] = {}
for hook in self.hooks:
with_type = d.setdefault(hook.hook_type, {})
with_type[hook] = None
return d
def get_hooks_for_clip_schedule(self):
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
for hook in self.hooks:
# only care about WeightHooks, for now
if hook.hook_type == EnumHookType.Weight:
hook_schedule = []
# if no hook keyframes, assign default value
if len(hook.hook_keyframe.keyframes) == 0:
hook_schedule.append(((0.0, 1.0), None))
scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
scheduled_hooks[hook] = hook_schedule
# hooks should not have their schedules in a list of tuples
all_ranges: list[tuple[float, float]] = []
for range_kfs in scheduled_hooks.values():
for t_range, keyframe in range_kfs:
all_ranges.append(t_range)
# turn list of ranges into boundaries
boundaries_set = set(itertools.chain.from_iterable(all_ranges))
boundaries_set.add(0.0)
boundaries = sorted(boundaries_set)
real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
# with real ranges defined, give appropriate hooks w/ keyframes for each range
scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
for t_range in real_ranges:
hooks_schedule = []
for hook, val in scheduled_hooks.items():
keyframe = None
# check if is a keyframe that works for the current t_range
for stored_range, stored_kf in val:
# if stored start is less than current end, then fits - give it assigned keyframe
if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
keyframe = stored_kf
break
hooks_schedule.append((hook, keyframe))
scheduled_keyframes.append((t_range, hooks_schedule))
return scheduled_keyframes
def reset(self):
for hook in self.hooks:
hook.reset()
@staticmethod
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
actual: list[HookGroup] = []
for group in hooks_list:
if group is not None:
actual.append(group)
if len(actual) < require_count:
raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
# if no hooks, then return None
if len(actual) == 0:
return None
# if only 1 hook, just return itself without cloning
elif len(actual) == 1:
return actual[0]
final_hook: HookGroup = None
for hook in actual:
if final_hook is None:
final_hook = hook.clone()
else:
final_hook = final_hook.clone_and_combine(hook)
return final_hook
class HookKeyframe:
def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
self.strength = strength
# scheduling
self.start_percent = float(start_percent)
self.start_t = 999999999.9
self.guarantee_steps = guarantee_steps
def clone(self):
c = HookKeyframe(strength=self.strength,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
c.start_t = self.start_t
return c
class HookKeyframeGroup:
def __init__(self):
self.keyframes: list[HookKeyframe] = []
self._current_keyframe: HookKeyframe = None
self._current_used_steps = 0
self._current_index = 0
self._current_strength = None
self._curr_t = -1.
# properties shadow those of HookWeightsKeyframe
@property
def strength(self):
if self._current_keyframe is not None:
return self._current_keyframe.strength
return 1.0
def reset(self):
self._current_keyframe = None
self._current_used_steps = 0
self._current_index = 0
self._current_strength = None
self.curr_t = -1.
self._set_first_as_current()
def add(self, keyframe: HookKeyframe):
# add to end of list, then sort
self.keyframes.append(keyframe)
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
self._set_first_as_current()
def _set_first_as_current(self):
if len(self.keyframes) > 0:
self._current_keyframe = self.keyframes[0]
else:
self._current_keyframe = None
def has_index(self, index: int):
return index >= 0 and index < len(self.keyframes)
def is_empty(self):
return len(self.keyframes) == 0
def clone(self):
c = HookKeyframeGroup()
for keyframe in self.keyframes:
c.keyframes.append(keyframe.clone())
c._set_first_as_current()
return c
def initialize_timesteps(self, model: 'BaseModel'):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
def prepare_current_keyframe(self, curr_t: float) -> bool:
if self.is_empty():
return False
if curr_t == self._curr_t:
return False
prev_index = self._current_index
prev_strength = self._current_strength
# if met guaranteed steps, look for next keyframe in case need to switch
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
# if has next index, loop through and see if need to switch
if self.has_index(self._current_index+1):
for i in range(self._current_index+1, len(self.keyframes)):
eval_c = self.keyframes[i]
# check if start_t is greater or equal to curr_t
# NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
if eval_c.start_t >= curr_t:
self._current_index = i
self._current_strength = eval_c.strength
self._current_keyframe = eval_c
self._current_used_steps = 0
# if guarantee_steps greater than zero, stop searching for other keyframes
if self._current_keyframe.guarantee_steps > 0:
break
# if eval_c is outside the percent range, stop looking further
else: break
# update steps current context is used
self._current_used_steps += 1
# update current timestep this was performed on
self._curr_t = curr_t
# return True if keyframe changed, False if no change
return prev_index != self._current_index and prev_strength != self._current_strength
class InterpolationMethod:
LINEAR = "linear"
EASE_IN = "ease_in"
EASE_OUT = "ease_out"
EASE_IN_OUT = "ease_in_out"
_LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
@classmethod
def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
diff = num_to - num_from
if method == cls.LINEAR:
weights = torch.linspace(num_from, num_to, length)
elif method == cls.EASE_IN:
index = torch.linspace(0, 1, length)
weights = diff * np.power(index, 2) + num_from
elif method == cls.EASE_OUT:
index = torch.linspace(0, 1, length)
weights = diff * (1 - np.power(1 - index, 2)) + num_from
elif method == cls.EASE_IN_OUT:
index = torch.linspace(0, 1, length)
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
else:
raise ValueError(f"Unrecognized interpolation method '{method}'.")
if reverse:
weights = weights.flip(dims=(0,))
return weights
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
if not objects:
return objects
elif len(objects) <= 1:
return [x for x in objects]
# now that we know we have to sort, do it following these rules:
# a) if objects have same value of attribute, maintain their relative order
# b) perform sorting of the groups of objects with same attributes
unique_attrs = {}
for o in objects:
val_attr = getattr(o, attr)
attr_list: list = unique_attrs.get(val_attr, list())
attr_list.append(o)
if val_attr not in unique_attrs:
unique_attrs[val_attr] = attr_list
# now that we have the unique attr values grouped together in relative order, sort them by key
sorted_attrs = dict(sorted(unique_attrs.items()))
# now flatten out the dict into a list to return
sorted_list = []
for object_list in sorted_attrs.values():
sorted_list.extend(object_list)
return sorted_list
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
hook_group.add(hook)
hook.weights = lora
return hook_group
def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
hook_group.add(hook)
patches_model = None
patches_clip = None
if weights_model is not None:
patches_model = {}
for key in weights_model:
patches_model[key] = ("model_as_lora", (weights_model[key],))
if weights_clip is not None:
patches_clip = {}
for key in weights_clip:
patches_clip[key] = ("model_as_lora", (weights_clip[key],))
hook.weights = patches_model
hook.weights_clip = patches_clip
hook.need_weight_init = False
return hook_group
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
if model is None:
return None
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
if discard_model_sampling:
# do not include ANY model_sampling components of the model that should act as a patch
for key in list(patches_model.keys()):
if key.startswith("model_sampling"):
patches_model.pop(key, None)
return patches_model
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
strength_model: float, strength_clip: float):
key_map = {}
if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
hook_group = HookGroup()
hook = WeightHook()
hook_group.add(hook)
loaded: dict[str] = comfy.lora.load_lora(lora, key_map)
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)
else:
k = ()
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
else:
k1 = ()
new_clip = None
k = set(k)
k1 = set(k1)
for x in loaded:
if (x not in k) and (x not in k1):
print(f"NOT LOADED {x}")
return (new_modelpatcher, new_clip, hook_group)
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
hooks_key = 'hooks'
# if hooks only exist in one dict, do what's needed so that it ends up in c_dict
if hooks_key not in values:
return
if hooks_key not in c_dict:
hooks_value = values.get(hooks_key, None)
if hooks_value is not None:
c_dict[hooks_key] = hooks_value
return
# otherwise, need to combine with minimum duplication via cache
hooks_tuple = (c_dict[hooks_key], values[hooks_key])
cached_hooks = cache.get(hooks_tuple, None)
if cached_hooks is None:
new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
cache[hooks_tuple] = new_hooks
c_dict[hooks_key] = new_hooks
else:
c_dict[hooks_key] = cache[hooks_tuple]
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
c = []
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
if append_hooks and k == 'hooks':
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
else:
n[1][k] = values[k]
c.append(n)
return c
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
if hooks is None:
return cond
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
if timestep_range is None:
return cond
return conditioning_set_values(cond, {"start_percent": timestep_range[0],
"end_percent": timestep_range[1]})
def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
if mask is None:
return cond
set_area_to_bounds = False
if set_cond_area != 'default':
set_area_to_bounds = True
if len(mask.shape) < 3:
mask = mask.unsqueeze(0)
return conditioning_set_values(cond, {'mask': mask,
'set_area_to_bounds': set_area_to_bounds,
'mask_strength': strength})
def combine_conditioning(conds: list):
combined_conds = []
for cond in conds:
combined_conds.extend(cond)
return combined_conds
def combine_with_new_conds(conds: list, new_conds: list):
combined_conds = []
for c, new_c in zip(conds, new_conds):
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds
def set_conds_props(conds: list, strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
final_conds = []
for c in conds:
# first, apply lora_hook to conditioning, if provided
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
# next, apply mask to conditioning
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
# apply timesteps, if present
c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range)
# finally, apply mask to conditioning and store
final_conds.append(c)
return final_conds
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
for c, masked_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
# next, apply mask to new conditioning, if provided
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
# apply timesteps, if present
masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, masked_c]))
return combined_conds
def set_default_conds_and_combine(conds: list, new_conds: list,
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
for c, new_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
# next, add default_cond key to cond so that during sampling, it can be identified
new_c = conditioning_set_values(new_c, {'default': True})
# apply timesteps, if present
new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds

View File

@ -175,12 +175,14 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
if sigma_down == 0:
x = denoised
else:
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
@ -192,19 +194,22 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
# Euler method
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
if sigmas[i + 1] > 0 and eta > 0:
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
if sigmas[i + 1] == 0:
x = denoised
else:
downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
sigma_down = sigmas[i + 1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i + 1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
# Euler method
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
if eta > 0:
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
@torch.no_grad()
@ -280,6 +285,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@ -306,6 +314,38 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
if sigma_down == 0:
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver-2
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
dt_1 = sigma_mid - sigmas[i]
dt_2 = sigma_down - sigmas[i]
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
def linear_multistep_coeff(order, t, i, j):
if order - 1 > i:

View File

@ -216,3 +216,139 @@ class Mochi(LatentFormat):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class LTXV(LatentFormat):
latent_channels = 128
def __init__(self):
self.latent_rgb_factors = [
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
[ 8.6031e-02, 6.5813e-02, 9.5409e-04],
[-1.2576e-02, -7.5734e-03, -4.0528e-03],
[ 9.4063e-03, -2.1688e-03, 2.6093e-03],
[ 3.7636e-03, 1.2765e-02, 9.1548e-03],
[ 2.1024e-02, -5.2973e-03, 3.4373e-03],
[-8.8896e-03, -1.9703e-02, -1.8761e-02],
[-1.3160e-02, -1.0523e-02, 1.9709e-03],
[-1.5152e-03, -6.9891e-03, -7.5810e-03],
[-1.7247e-03, 4.6560e-04, -3.3839e-03],
[ 1.3617e-02, 4.7077e-03, -2.0045e-03],
[ 1.0256e-02, 7.7318e-03, 1.3948e-02],
[-1.6108e-02, -6.2151e-03, 1.1561e-03],
[ 7.3407e-03, 1.5628e-02, 4.4865e-04],
[ 9.5357e-04, -2.9518e-03, -1.4760e-02],
[ 1.9143e-02, 1.0868e-02, 1.2264e-02],
[ 4.4575e-03, 3.6682e-05, -6.8508e-03],
[-4.5681e-04, 3.2570e-03, 7.7929e-03],
[ 3.3902e-02, 3.3405e-02, 3.7454e-02],
[-2.3001e-02, -2.4877e-03, -3.1033e-03],
[ 5.0265e-02, 3.8841e-02, 3.3539e-02],
[-4.1018e-03, -1.1095e-03, 1.5859e-03],
[-1.2689e-01, -1.3107e-01, -2.1005e-01],
[ 2.6276e-02, 1.4189e-02, -3.5963e-03],
[-4.8679e-03, 8.8486e-03, 7.8029e-03],
[-1.6610e-03, -4.8597e-03, -5.2060e-03],
[-2.1010e-03, 2.3610e-03, 9.3796e-03],
[-2.2482e-02, -2.1305e-02, -1.5087e-02],
[-1.5753e-02, -1.0646e-02, -6.5083e-03],
[-4.6975e-03, 5.0288e-03, -6.7390e-03],
[ 1.1951e-02, 2.0712e-02, 1.6191e-02],
[-6.3704e-03, -8.4827e-03, -9.5483e-03],
[ 7.2610e-03, -9.9326e-03, -2.2978e-02],
[-9.1904e-04, 6.2882e-03, 9.5720e-03],
[-3.7178e-02, -3.7123e-02, -5.6713e-02],
[-1.3373e-01, -1.0720e-01, -5.3801e-02],
[-5.3702e-03, 8.1256e-03, 8.8397e-03],
[-1.5247e-01, -2.1437e-01, -2.1843e-01],
[ 3.1441e-02, 7.0335e-03, -9.7541e-03],
[ 2.1528e-03, -8.9817e-03, -2.1023e-02],
[ 3.8461e-03, -5.8957e-03, -1.5014e-02],
[-4.3470e-03, -1.2940e-02, -1.5972e-02],
[-5.4781e-03, -1.0842e-02, -3.0204e-03],
[-6.5347e-03, 3.0806e-03, -1.0163e-02],
[-5.0414e-03, -7.1503e-03, -8.9686e-04],
[-8.5851e-03, -2.4351e-03, 1.0674e-03],
[-9.0016e-03, -9.6493e-03, 1.5692e-03],
[ 5.0914e-03, 1.2099e-02, 1.9968e-02],
[ 1.3758e-02, 1.1669e-02, 8.1958e-03],
[-1.0518e-02, -1.1575e-02, -4.1307e-03],
[-2.8410e-02, -3.1266e-02, -2.2149e-02],
[ 2.9336e-03, 3.6511e-02, 1.8717e-02],
[-1.6703e-02, -1.6696e-02, -4.4529e-03],
[ 4.8818e-02, 4.0063e-02, 8.7410e-03],
[-1.5066e-02, -5.7328e-04, 2.9785e-03],
[-1.7613e-02, -8.1034e-03, 1.3086e-02],
[-9.2633e-03, 1.0803e-02, -6.3489e-03],
[ 3.0851e-03, 4.7750e-04, 1.2347e-02],
[-2.2785e-02, -2.3043e-02, -2.6005e-02],
[-2.4787e-02, -1.5389e-02, -2.2104e-02],
[-2.3572e-02, 1.0544e-03, 1.2361e-02],
[-7.8915e-03, -1.2271e-03, -6.0968e-03],
[-1.1478e-02, -1.2543e-03, 6.2679e-03],
[-5.4229e-02, 2.6644e-02, 6.3394e-03],
[ 4.4216e-03, -7.3338e-03, -1.0464e-02],
[-4.5013e-03, 1.6082e-03, 1.4420e-02],
[ 1.3673e-02, 8.8877e-03, 4.1253e-03],
[-1.0145e-02, 9.0072e-03, 1.5695e-02],
[-5.6234e-03, 1.1847e-03, 8.1261e-03],
[-3.7171e-03, -5.3538e-03, 1.2590e-03],
[ 2.9476e-02, 2.1424e-02, 3.0424e-02],
[-3.4925e-02, -2.4340e-02, -2.5316e-02],
[-3.4127e-02, -2.2406e-02, -1.0589e-02],
[-1.7342e-02, -1.3249e-02, -1.0719e-02],
[-2.1478e-03, -8.6051e-03, -2.9878e-03],
[ 1.2089e-03, -4.2391e-03, -6.8569e-03],
[ 9.0411e-04, -6.6886e-03, -6.7547e-05],
[ 1.6048e-02, -1.0057e-02, -2.8929e-02],
[ 1.2290e-03, 1.0163e-02, 1.8861e-02],
[ 1.7264e-02, 2.7257e-04, 1.3785e-02],
[-1.3482e-02, -3.6427e-03, 6.7481e-04],
[ 4.6782e-03, -5.2423e-03, 2.4467e-03],
[-5.9113e-03, -6.2244e-03, -1.8162e-03],
[ 1.5496e-02, 1.4582e-02, 1.9514e-03],
[ 7.4958e-03, 1.5886e-03, -8.2305e-03],
[ 1.9086e-02, 1.6360e-03, -3.9674e-03],
[-5.7021e-03, -2.7307e-03, -4.1066e-03],
[ 1.7450e-03, 1.4602e-02, 2.5794e-02],
[-8.2788e-04, 2.2902e-03, 4.5161e-03],
[ 1.1632e-02, 8.9193e-03, -7.2813e-03],
[ 7.5721e-03, 2.6784e-03, 1.1393e-02],
[ 5.1939e-03, 3.6903e-03, 1.4049e-02],
[-1.8383e-02, -2.2529e-02, -2.4477e-02],
[ 5.8842e-04, -5.7874e-03, -1.4770e-02],
[-1.6125e-02, -8.6101e-03, -1.4533e-02],
[ 2.0540e-02, 2.0729e-02, 6.4338e-03],
[ 3.3587e-03, -1.1226e-02, -1.6444e-02],
[-1.4742e-03, -1.0489e-02, 1.7097e-03],
[ 2.8130e-02, 2.3546e-02, 3.2791e-02],
[-1.8532e-02, -1.2842e-02, -8.7756e-03],
[-8.0533e-03, -1.0771e-02, -1.7536e-02],
[-3.9009e-03, 1.6150e-02, 3.3359e-02],
[-7.4554e-03, -1.4154e-02, -6.1910e-03],
[ 3.4734e-03, -1.1370e-02, -1.0581e-02],
[ 1.1476e-02, 3.9281e-03, 2.8231e-03],
[ 7.1639e-03, -1.4741e-03, -3.8066e-03],
[ 2.2250e-03, -8.7552e-03, -9.5719e-03],
[ 2.4146e-02, 2.1696e-02, 2.8056e-02],
[-5.4365e-03, -2.4291e-02, -1.7802e-02],
[ 7.4263e-03, 1.0510e-02, 1.2705e-02],
[ 6.2669e-03, 6.2658e-03, 1.9211e-02],
[ 1.6378e-02, 9.4933e-03, 6.6971e-03],
[ 1.7173e-02, 2.3601e-02, 2.3296e-02],
[-1.4568e-02, -9.8279e-03, -1.1556e-02],
[ 1.4431e-02, 1.4430e-02, 6.6362e-03],
[-6.8230e-03, 1.8863e-02, 1.4555e-02],
[ 6.1156e-03, 3.4700e-03, -2.6662e-03],
[-2.6983e-03, -5.9402e-03, -9.2276e-03],
[ 1.0235e-02, 7.4173e-03, -7.6243e-03],
[-1.3255e-02, 1.9322e-02, -9.2153e-04],
[ 2.4222e-03, -4.8039e-03, -1.5759e-02],
[ 2.6244e-02, 2.5951e-02, 2.0249e-02],
[ 1.5711e-02, 1.8498e-02, 2.7407e-03],
[-2.1714e-03, 4.7214e-03, -2.2443e-02],
[-7.4747e-03, 7.4166e-03, 1.4430e-02],
[-8.3906e-03, -7.9776e-03, 9.7927e-03],
[ 3.8321e-02, 9.6622e-03, -1.9268e-02],
[-1.4605e-02, -6.7032e-03, 3.9675e-03]
]
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]

View File

@ -2,7 +2,7 @@
import torch
from torch import nn
from typing import Literal, Dict, Any
from typing import Literal
import math
import comfy.ops
ops = comfy.ops.disable_weight_init

View File

@ -612,7 +612,9 @@ class ContinuousTransformer(nn.Module):
return_info = False,
**kwargs
):
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]
info = {
"hidden_states": [],
@ -643,9 +645,19 @@ class ContinuousTransformer(nn.Module):
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
x = x + self.pos_emb(x)
blocks_replace = patches_replace.get("dit", {})
# Iterate over the transformer layers
for layer in self.layers:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
for i, layer in enumerate(self.layers):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info:
@ -874,7 +886,6 @@ class AudioDiffusionTransformer(nn.Module):
mask=None,
return_info=False,
control=None,
transformer_options={},
**kwargs):
return self._forward(
x,

View File

@ -2,8 +2,8 @@
import torch
import torch.nn as nn
from torch import Tensor, einsum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from torch import Tensor
from typing import List, Union
from einops import rearrange
import math
import comfy.ops

View File

@ -16,7 +16,6 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import torchvision
from torch import nn
from .common import LayerNorm2d_op

View File

@ -2,7 +2,7 @@ import torch
import comfy.ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
padding_mode = "reflect"
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]

View File

@ -6,9 +6,7 @@ import math
from torch import Tensor, nn
from einops import rearrange, repeat
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from .layers import (timestep_embedding)
from .model import Flux
import comfy.ldm.common_dit

View File

@ -20,6 +20,7 @@ import comfy.ldm.common_dit
@dataclass
class FluxParams:
in_channels: int
out_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
@ -29,6 +30,7 @@ class FluxParams:
depth_single_blocks: int
axes_dim: list
theta: int
patch_size: int
qkv_bias: bool
guidance_embed: bool
@ -43,8 +45,9 @@ class Flux(nn.Module):
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels
self.patch_size = params.patch_size
self.in_channels = params.in_channels * params.patch_size * params.patch_size
self.out_channels = params.out_channels * params.patch_size * params.patch_size
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@ -165,7 +168,7 @@ class Flux(nn.Module):
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

25
comfy/ldm/flux/redux.py Normal file
View File

@ -0,0 +1,25 @@
import torch
import comfy.ops
ops = comfy.ops.manual_cast
class ReduxImageEncoder(torch.nn.Module):
def __init__(
self,
redux_dim: int = 1152,
txt_in_features: int = 4096,
device=None,
dtype=None,
) -> None:
super().__init__()
self.redux_dim = redux_dim
self.device = device
self.dtype = dtype
self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
def forward(self, sigclip_embeds) -> torch.Tensor:
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
return projected_x

View File

@ -1,7 +1,7 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.nn as nn

View File

@ -1,7 +1,7 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Callable, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
from functools import partial
import math

View File

@ -1,24 +1,17 @@
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import checkpoint
from comfy.ldm.modules.diffusionmodules.mmdit import (
Mlp,
TimestepEmbedder,
PatchEmbed,
RMSNorm,
)
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from .poolers import AttentionPool
import comfy.latent_formats
from .models import HunYuanDiTBlock, calc_rope
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
class HunYuanControlNet(nn.Module):

View File

@ -1,8 +1,6 @@
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
@ -287,7 +285,7 @@ class HunYuanDiT(nn.Module):
style=None,
return_dict=False,
control=None,
transformer_options=None,
transformer_options={},
):
"""
Forward pass of the encoder.
@ -315,8 +313,7 @@ class HunYuanDiT(nn.Module):
return_dict: bool
Whether to return a dictionary.
"""
#import pdb
#pdb.set_trace()
patches_replace = transformer_options.get("patches_replace", {})
encoder_hidden_states = context
text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
@ -364,6 +361,8 @@ class HunYuanDiT(nn.Module):
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
blocks_replace = patches_replace.get("dit", {})
controls = None
if control:
controls = control.get("output", None)
@ -375,9 +374,20 @@ class HunYuanDiT(nn.Module):
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
else:
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
skip = None
if ("double_block", layer) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
return out
out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
if layer < (self.depth // 2 - 1):
skips.append(x)

View File

@ -1,6 +1,5 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops

View File

@ -0,0 +1,527 @@
import torch
from torch import nn
import comfy.ldm.modules.attention
from comfy.ldm.genmo.joint_model.layers import RMSNorm
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
Args
timesteps (torch.Tensor):
a 1-D Tensor of N indices, one per batch element. These may be fractional.
embedding_dim (int):
the dimension of the output.
flip_sin_to_cos (bool):
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
downscale_freq_shift (float):
Controls the delta between frequencies between dimensions
scale (float):
Scaling factor applied to the embeddings.
max_period (int):
Controls the maximum frequency of the embeddings
Returns
torch.Tensor: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
dtype=None, device=None, operations=None,
):
super().__init__()
self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)
if cond_proj_dim is not None:
self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
else:
self.cond_proj = None
self.act = nn.SiLU()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
if post_act_fn is None:
self.post_act = None
# else:
# self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
return t_emb
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
"""
For PixArt-Alpha.
Reference:
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
return timesteps_emb
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
)
self.silu = nn.SiLU()
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
batch_size: Optional[int] = None,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# No modulation happening here.
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class GELU_approx(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)
def forward(self, x):
return torch.nn.functional.gelu(self.proj(x), approximate="tanh")
class FeedForward(nn.Module):
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
super().__init__()
inner_dim = int(dim * mult)
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
)
def forward(self, x):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
cos_freqs = freqs_cis[0]
sin_freqs = freqs_cis[1]
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
self.attn_precision = attn_precision
self.heads = heads
self.dim_head = dim_head
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None, pe=None):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
v = self.to_v(context)
q = self.q_norm(q)
k = self.k_norm(k)
if pe is not None:
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
self.attn_precision = attn_precision
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
return x
def get_fractional_positions(indices_grid, max_pos):
fractional_positions = torch.stack(
[
indices_grid[:, i] / max_pos[i]
for i in range(3)
],
dim=-1,
)
return fractional_positions
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32 #self.dtype
fractional_positions = get_fractional_positions(indices_grid, max_pos)
start = 1
end = theta
device = fractional_positions.device
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
dim // 6,
device=device,
dtype=dtype,
)
)
indices = indices.to(dtype=dtype)
indices = indices * math.pi / 2
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
if dim % 6 != 0:
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
class LTXVModel(torch.nn.Module):
def __init__(self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=64,
num_attention_heads=32,
caption_channels=4096,
num_layers=28,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.generator = None
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
)
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
# attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
for d in range(num_layers)
]
)
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
orig_height=x.shape[3],
orig_width=x.shape[4],
batch_size=x.shape[0],
scale_grid=((1 / frame_rate) * 8, 32, 32),
device=x.device,
)
if guiding_latent is not None:
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
if guiding_latent_noise_scale > 0:
if self.generator is None:
self.generator = torch.Generator(device=x.device).manual_seed(42)
elif self.generator.device != x.device:
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
scale = guiding_latent_noise_scale * (input_ts ** 2)
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
orig_shape = list(x.shape)
x = self.patchifier.patchify(x)
x = self.patchify_proj(x)
timestep = timestep * 1000.0
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=x.dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(
batch_size, -1, embedded_timestep.shape[-1]
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = x.shape[0]
context = self.caption_projection(context)
context = context.view(
batch_size, -1, x.shape[-1]
)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
x,
context=context,
attention_mask=attention_mask,
timestep=timestep,
pe=pe
)
# 3. Output
scale_shift_values = (
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = x * (1 + scale) + shift
x = self.proj_out(x)
x = self.patchifier.unpatchify(
latents=x,
output_height=orig_shape[3],
output_width=orig_shape[4],
output_num_frames=orig_shape[2],
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
)
if guiding_latent is not None:
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
# print("res", x)
return x

View File

@ -0,0 +1,105 @@
from abc import ABC, abstractmethod
from typing import Tuple
import torch
from einops import rearrange
from torch import Tensor
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
elif dims_to_append == 0:
return x
return x[(...,) + (None,) * dims_to_append]
class Patchifier(ABC):
def __init__(self, patch_size: int):
super().__init__()
self._patch_size = (1, patch_size, patch_size)
@abstractmethod
def patchify(
self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
) -> Tuple[Tensor, Tensor]:
pass
@abstractmethod
def unpatchify(
self,
latents: Tensor,
output_height: int,
output_width: int,
output_num_frames: int,
out_channels: int,
) -> Tuple[Tensor, Tensor]:
pass
@property
def patch_size(self):
return self._patch_size
def get_grid(
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
):
f = orig_num_frames // self._patch_size[0]
h = orig_height // self._patch_size[1]
w = orig_width // self._patch_size[2]
grid_h = torch.arange(h, dtype=torch.float32, device=device)
grid_w = torch.arange(w, dtype=torch.float32, device=device)
grid_f = torch.arange(f, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w)
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
if scale_grid is not None:
for i in range(3):
if isinstance(scale_grid[i], Tensor):
scale = append_dims(scale_grid[i], grid.ndim - 1)
else:
scale = scale_grid[i]
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
return grid
class SymmetricPatchifier(Patchifier):
def patchify(
self,
latents: Tensor,
) -> Tuple[Tensor, Tensor]:
latents = rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
p1=self._patch_size[0],
p2=self._patch_size[1],
p3=self._patch_size[2],
)
return latents
def unpatchify(
self,
latents: Tensor,
output_height: int,
output_width: int,
output_num_frames: int,
out_channels: int,
) -> Tuple[Tensor, Tensor]:
output_height = output_height // self._patch_size[1]
output_width = output_width // self._patch_size[2]
latents = rearrange(
latents,
"b (f h w) (c p q) -> b c f (h p) (w q) ",
f=output_num_frames,
h=output_height,
w=output_width,
p=self._patch_size[1],
q=self._patch_size[2],
)
return latents

View File

@ -0,0 +1,64 @@
from typing import Tuple, Union
import torch
import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size: int = 3,
stride: Union[int, Tuple[int]] = 1,
dilation: int = 1,
groups: int = 1,
**kwargs,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
kernel_size = (kernel_size, kernel_size, kernel_size)
self.time_kernel_size = kernel_size[0]
dilation = (dilation, 1, 1)
height_pad = kernel_size[1] // 2
width_pad = kernel_size[2] // 2
padding = (0, height_pad, width_pad)
self.conv = ops.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
padding_mode="zeros",
groups=groups,
)
def forward(self, x, causal: bool = True):
if causal:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.time_kernel_size - 1, 1, 1)
)
x = torch.concatenate((first_frame_pad, x), dim=2)
else:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
last_frame_pad = x[:, :, -1:, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
x = self.conv(x)
return x
@property
def weight(self):
return self.conv.weight

View File

@ -0,0 +1,698 @@
import torch
from torch import nn
from functools import partial
import math
from einops import rearrange
from typing import Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm
class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
The number of dimensions to use in convolutions.
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
The blocks to use. Each block is a tuple of the block name and the number of layers.
base_channels (`int`, *optional*, defaults to 128):
The number of output channels for the first convolutional layer.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
latent_log_var (`str`, *optional*, defaults to `per_channel`):
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]] = 3,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
base_channels: int = 128,
norm_num_groups: int = 32,
patch_size: Union[int, Tuple[int]] = 1,
norm_layer: str = "group_norm", # group_norm, pixel_norm
latent_log_var: str = "per_channel",
):
super().__init__()
self.patch_size = patch_size
self.norm_layer = norm_layer
self.latent_channels = out_channels
self.latent_log_var = latent_log_var
self.blocks_desc = blocks
in_channels = in_channels * patch_size**2
output_channel = base_channels
self.conv_in = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=output_channel,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
self.down_blocks = nn.ModuleList([])
for block_name, block_params in blocks:
input_channel = output_channel
if isinstance(block_params, int):
block_params = {"num_layers": block_params}
if block_name == "res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "res_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
block = ResnetBlock3D(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "compress_time":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 1, 1),
causal=True,
)
elif block_name == "compress_space":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(1, 2, 2),
causal=True,
)
elif block_name == "compress_all":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 2, 2),
causal=True,
)
elif block_name == "compress_all_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 2, 2),
causal=True,
)
else:
raise ValueError(f"unknown block: {block_name}")
self.down_blocks.append(block)
# out
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
elif norm_layer == "layer_norm":
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = out_channels
if latent_log_var == "per_channel":
conv_out_channels *= 2
elif latent_log_var == "uniform":
conv_out_channels += 1
elif latent_log_var != "none":
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
self.conv_out = make_conv_nd(
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
)
self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
sample = self.conv_in(sample)
checkpoint_fn = (
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
if self.gradient_checkpointing and self.training
else lambda x: x
)
for down_block in self.down_blocks:
sample = checkpoint_fn(down_block)(sample)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if self.latent_log_var == "uniform":
last_channel = sample[:, -1:, ...]
num_dims = sample.dim()
if num_dims == 4:
# For shape (B, C, H, W)
repeated_last_channel = last_channel.repeat(
1, sample.shape[1] - 2, 1, 1
)
sample = torch.cat([sample, repeated_last_channel], dim=1)
elif num_dims == 5:
# For shape (B, C, F, H, W)
repeated_last_channel = last_channel.repeat(
1, sample.shape[1] - 2, 1, 1, 1
)
sample = torch.cat([sample, repeated_last_channel], dim=1)
else:
raise ValueError(f"Invalid input shape: {sample.shape}")
return sample
class Decoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
The number of dimensions to use in convolutions.
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
The blocks to use. Each block is a tuple of the block name and the number of layers.
base_channels (`int`, *optional*, defaults to 128):
The number of output channels for the first convolutional layer.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
causal (`bool`, *optional*, defaults to `True`):
Whether to use causal convolutions or not.
"""
def __init__(
self,
dims,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
base_channels: int = 128,
layers_per_block: int = 2,
norm_num_groups: int = 32,
patch_size: int = 1,
norm_layer: str = "group_norm",
causal: bool = True,
):
super().__init__()
self.patch_size = patch_size
self.layers_per_block = layers_per_block
out_channels = out_channels * patch_size**2
self.causal = causal
self.blocks_desc = blocks
# Compute output channel to be product of all channel-multiplier blocks
output_channel = base_channels
for block_name, block_params in list(reversed(blocks)):
block_params = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
output_channel = output_channel * block_params.get("multiplier", 2)
self.conv_in = make_conv_nd(
dims,
in_channels,
output_channel,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
self.up_blocks = nn.ModuleList([])
for block_name, block_params in list(reversed(blocks)):
input_channel = output_channel
if isinstance(block_params, int):
block_params = {"num_layers": block_params}
if block_name == "res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
block = ResnetBlock3D(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
)
elif block_name == "compress_all":
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 2, 2),
residual=block_params.get("residual", False),
)
else:
raise ValueError(f"unknown layer: {block_name}")
self.up_blocks.append(block)
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
elif norm_layer == "layer_norm":
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims, output_channel, out_channels, 3, padding=1, causal=True
)
self.gradient_checkpointing = False
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
# assert target_shape is not None, "target_shape must be provided"
sample = self.conv_in(sample, causal=self.causal)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
checkpoint_fn = (
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
if self.gradient_checkpointing and self.training
else lambda x: x
)
sample = sample.to(upscale_dtype)
for up_block in self.up_blocks:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
class UNetMidBlock3D(nn.Module):
"""
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
Args:
in_channels (`int`): The number of input channels.
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
in_channels, height, width)`.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]],
in_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_groups: int = 32,
norm_layer: str = "group_norm",
):
super().__init__()
resnet_groups = (
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
)
self.res_blocks = nn.ModuleList(
[
ResnetBlock3D(
dims=dims,
in_channels=in_channels,
out_channels=in_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
norm_layer=norm_layer,
)
for _ in range(num_layers)
]
)
def forward(
self, hidden_states: torch.FloatTensor, causal: bool = True
) -> torch.FloatTensor:
for resnet in self.res_blocks:
hidden_states = resnet(hidden_states, causal=causal)
return hidden_states
class DepthToSpaceUpsample(nn.Module):
def __init__(self, dims, in_channels, stride, residual=False):
super().__init__()
self.stride = stride
self.out_channels = math.prod(stride) * in_channels
self.conv = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
causal=True,
)
self.residual = residual
def forward(self, x, causal: bool = True):
if self.residual:
# Reshape and duplicate the input to match the output shape
x_in = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1)
if self.stride[0] == 2:
x_in = x_in[:, :, 1:, :, :]
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2:
x = x[:, :, 1:, :, :]
if self.residual:
x = x + x_in
return x
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
super().__init__()
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, x):
x = rearrange(x, "b c d h w -> b d h w c")
x = self.norm(x)
x = rearrange(x, "b d h w c -> b c d h w")
return x
class ResnetBlock3D(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv layer. If None, same as `in_channels`.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]],
in_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
groups: int = 32,
eps: float = 1e-6,
norm_layer: str = "group_norm",
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
if norm_layer == "group_norm":
self.norm1 = nn.GroupNorm(
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
)
elif norm_layer == "pixel_norm":
self.norm1 = PixelNorm()
elif norm_layer == "layer_norm":
self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
self.non_linearity = nn.SiLU()
self.conv1 = make_conv_nd(
dims,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
if norm_layer == "group_norm":
self.norm2 = nn.GroupNorm(
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
)
elif norm_layer == "pixel_norm":
self.norm2 = PixelNorm()
elif norm_layer == "layer_norm":
self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = make_conv_nd(
dims,
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
self.conv_shortcut = (
make_linear_nd(
dims=dims, in_channels=in_channels, out_channels=out_channels
)
if in_channels != out_channels
else nn.Identity()
)
self.norm3 = (
LayerNorm(in_channels, eps=eps, elementwise_affine=True)
if in_channels != out_channels
else nn.Identity()
)
def forward(
self,
input_tensor: torch.FloatTensor,
causal: bool = True,
) -> torch.FloatTensor:
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.conv1(hidden_states, causal=causal)
hidden_states = self.norm2(hidden_states)
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, causal=causal)
input_tensor = self.norm3(input_tensor)
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states
return output_tensor
def patchify(x, patch_size_hw, patch_size_t=1):
if patch_size_hw == 1 and patch_size_t == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
)
elif x.dim() == 5:
x = rearrange(
x,
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
p=patch_size_t,
q=patch_size_hw,
r=patch_size_hw,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size_hw, patch_size_t=1):
if patch_size_hw == 1 and patch_size_t == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
)
elif x.dim() == 5:
x = rearrange(
x,
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
p=patch_size_t,
q=patch_size_hw,
r=patch_size_hw,
)
return x
class processor(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
self.register_buffer("mean-of-stds", torch.empty(128))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
self.register_buffer("channel", torch.empty(128))
def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
def normalize(self, x):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
def __init__(self):
super().__init__()
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"blocks": [
["res_x", 4],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x", 3],
["res_x", 4],
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
}
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("blocks")),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("blocks")),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
)
self.per_channel_statistics = processor()
def encode(self, x):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x):
return self.decoder(self.per_channel_statistics.un_normalize(x))

View File

@ -0,0 +1,82 @@
from typing import Tuple, Union
from .dual_conv3d import DualConv3d
from .causal_conv3d import CausalConv3d
import comfy.ops
ops = comfy.ops.disable_weight_init
def make_conv_nd(
dims: Union[int, Tuple[int, int]],
in_channels: int,
out_channels: int,
kernel_size: int,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
causal=False,
):
if dims == 2:
return ops.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
elif dims == 3:
if causal:
return CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
return ops.Conv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
elif dims == (2, 1):
return DualConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
else:
raise ValueError(f"unsupported dimensions: {dims}")
def make_linear_nd(
dims: int,
in_channels: int,
out_channels: int,
bias=True,
):
if dims == 2:
return ops.Conv2d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
)
elif dims == 3 or dims == (2, 1):
return ops.Conv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
)
else:
raise ValueError(f"unsupported dimensions: {dims}")

View File

@ -0,0 +1,195 @@
import math
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class DualConv3d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups=1,
bias=True,
):
super(DualConv3d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if kernel_size == (1, 1, 1):
raise ValueError(
"kernel_size must be greater than 1. Use make_linear_nd instead."
)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(padding, int):
padding = (padding, padding, padding)
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)
# Set parameters for convolutions
self.groups = groups
self.bias = bias
# Define the size of the channels after the first convolution
intermediate_channels = (
out_channels if in_channels < out_channels else in_channels
)
# Define parameters for the first convolution
self.weight1 = nn.Parameter(
torch.Tensor(
intermediate_channels,
in_channels // groups,
1,
kernel_size[1],
kernel_size[2],
)
)
self.stride1 = (1, stride[1], stride[2])
self.padding1 = (0, padding[1], padding[2])
self.dilation1 = (1, dilation[1], dilation[2])
if bias:
self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
else:
self.register_parameter("bias1", None)
# Define parameters for the second convolution
self.weight2 = nn.Parameter(
torch.Tensor(
out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
)
)
self.stride2 = (stride[0], 1, 1)
self.padding2 = (padding[0], 0, 0)
self.dilation2 = (dilation[0], 1, 1)
if bias:
self.bias2 = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter("bias2", None)
# Initialize weights and biases
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
if self.bias:
fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
bound1 = 1 / math.sqrt(fan_in1)
nn.init.uniform_(self.bias1, -bound1, bound1)
fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
bound2 = 1 / math.sqrt(fan_in2)
nn.init.uniform_(self.bias2, -bound2, bound2)
def forward(self, x, use_conv3d=False, skip_time_conv=False):
if use_conv3d:
return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
else:
return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
def forward_with_3d(self, x, skip_time_conv):
# First convolution
x = F.conv3d(
x,
self.weight1,
self.bias1,
self.stride1,
self.padding1,
self.dilation1,
self.groups,
)
if skip_time_conv:
return x
# Second convolution
x = F.conv3d(
x,
self.weight2,
self.bias2,
self.stride2,
self.padding2,
self.dilation2,
self.groups,
)
return x
def forward_with_2d(self, x, skip_time_conv):
b, c, d, h, w = x.shape
# First 2D convolution
x = rearrange(x, "b c d h w -> (b d) c h w")
# Squeeze the depth dimension out of weight1 since it's 1
weight1 = self.weight1.squeeze(2)
# Select stride, padding, and dilation for the 2D convolution
stride1 = (self.stride1[1], self.stride1[2])
padding1 = (self.padding1[1], self.padding1[2])
dilation1 = (self.dilation1[1], self.dilation1[2])
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
_, _, h, w = x.shape
if skip_time_conv:
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
return x
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
# Reshape weight2 to match the expected dimensions for conv1d
weight2 = self.weight2.squeeze(-1).squeeze(-1)
# Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
stride2 = self.stride2[0]
padding2 = self.padding2[0]
dilation2 = self.dilation2[0]
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
return x
@property
def weight(self):
return self.weight2
def test_dual_conv3d_consistency():
# Initialize parameters
in_channels = 3
out_channels = 5
kernel_size = (3, 3, 3)
stride = (2, 2, 2)
padding = (1, 1, 1)
# Create an instance of the DualConv3d class
dual_conv3d = DualConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=True,
)
# Example input tensor
test_input = torch.randn(1, 3, 10, 10, 10)
# Perform forward passes with both 3D and 2D settings
output_conv3d = dual_conv3d(test_input, use_conv3d=True)
output_2d = dual_conv3d(test_input, use_conv3d=False)
# Assert that the outputs from both methods are sufficiently close
assert torch.allclose(
output_conv3d, output_2d, atol=1e-6
), "Outputs are not consistent between 3D and 2D convolutions."

View File

@ -0,0 +1,12 @@
import torch
from torch import nn
class PixelNorm(nn.Module):
def __init__(self, dim=1, eps=1e-8):
super(PixelNorm, self).__init__()
self.dim = dim
self.eps = eps
def forward(self, x):
return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)

View File

@ -1,6 +1,6 @@
import torch
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Tuple, Union
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution

View File

@ -299,7 +299,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
if len(mask.shape) == 2:
s1 += mask[i:end]
else:
s1 += mask[:, i:end]
if mask.shape[1] == 1:
s1 += mask
else:
s1 += mask[:, i:end]
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1
@ -372,10 +375,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
)
if mask is not None:
pad = 8 - q.shape[1] % 8
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
mask_out[:, :, :mask.shape[-1]] = mask
mask = mask_out[:, :, :mask.shape[-1]]
pad = 8 - mask.shape[-1] % 8
mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
mask_out[..., :mask.shape[-1]] = mask
mask = mask_out[..., :mask.shape[-1]]
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

View File

@ -1,5 +1,3 @@
import logging
import math
from typing import Dict, Optional, List
import numpy as np

View File

@ -3,7 +3,6 @@ import math
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Any
import logging
from comfy import model_management

View File

@ -9,12 +9,12 @@ import logging
from .util import (
checkpoint,
avg_pool_nd,
zero_module,
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.patcher_extension
import comfy.ops
ops = comfy.ops.disable_weight_init
@ -47,6 +47,15 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue
x = layer(x)
return x
@ -819,6 +828,13 @@ class UNetModel(nn.Module):
)
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timesteps, context, y, control, transformer_options, **kwargs)
def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.

View File

@ -4,7 +4,6 @@ import numpy as np
from functools import partial
from .util import extract_into_tensor, make_beta_schedule
from comfy.ldm.util import default
class AbstractLowScaleModel(nn.Module):

View File

@ -8,7 +8,6 @@
# thanks!
import os
import math
import torch
import torch.nn as nn

View File

@ -234,6 +234,8 @@ def efficient_dot_product_attention(
def get_mask_chunk(chunk_idx: int) -> Tensor:
if mask is None:
return None
if mask.shape[1] == 1:
return mask
chunk = min(query_chunk_size, q_tokens)
return mask[:,chunk_idx:chunk_idx + chunk]

View File

@ -1,5 +1,5 @@
import functools
from typing import Callable, Iterable, Union
from typing import Iterable, Union
import torch
from einops import rearrange, repeat

View File

@ -33,7 +33,7 @@ LORA_CLIP_MAP = {
}
def load_lora(lora, to_load):
def load_lora(lora, to_load, log_missing=True):
patch_dict = {}
loaded_keys = set()
for x in to_load:
@ -49,10 +49,20 @@ def load_lora(lora, to_load):
dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name)
reshape_name = "{}.reshape_weight".format(x)
reshape = None
if reshape_name in lora.keys():
try:
reshape = lora[reshape_name].tolist()
loaded_keys.add(reshape_name)
except:
pass
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
diffusers2_lora = "{}.lora_B.weight".format(x)
diffusers3_lora = "{}.lora.up.weight".format(x)
mochi_lora = "{}.lora_B".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
@ -72,6 +82,10 @@ def load_lora(lora, to_load):
A_name = diffusers3_lora
B_name = "{}.lora.down.weight".format(x)
mid_name = None
elif mochi_lora in lora.keys():
A_name = mochi_lora
B_name = "{}.lora_A".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x)
@ -82,7 +96,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
loaded_keys.add(A_name)
loaded_keys.add(B_name)
@ -193,9 +207,16 @@ def load_lora(lora, to_load):
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
set_weight_name = "{}.set_weight".format(x)
set_weight = lora.get(set_weight_name, None)
if set_weight is not None:
patch_dict[to_load[x]] = ("set", (set_weight,))
loaded_keys.add(set_weight_name)
if log_missing:
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
return patch_dict
@ -282,11 +303,14 @@ def model_lora_keys_unet(model, key_map={}):
sdk = sd.keys()
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
else:
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys:
@ -344,6 +368,12 @@ def model_lora_keys_unet(model, key_map={}):
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
if isinstance(model, comfy.model_base.GenmoMochi):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
return key_map
@ -400,7 +430,7 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
return padded_tensor
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches:
strength = p[0]
v = p[1]
@ -440,10 +470,22 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
else:
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
elif patch_type == "set":
weight.copy_(v[0])
elif patch_type == "model_as_lora":
target_weight: torch.Tensor = v[0]
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
dora_scale = v[4]
reshape = v[5]
if reshape is not None:
weight = pad_tensor_to_shape(weight, reshape)
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:

17
comfy/lora_convert.py Normal file
View File

@ -0,0 +1,17 @@
import torch
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
return sd_out
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
return sd

View File

@ -30,14 +30,19 @@ import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.model_management
import comfy.patcher_extension
import comfy.conds
import comfy.ops
from enum import Enum
from . import utils
import comfy.latent_formats
import math
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class ModelType(Enum):
EPS = 1
@ -94,6 +99,7 @@ class BaseModel(torch.nn.Module):
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
self.device = device
self.current_patcher: 'ModelPatcher' = None
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
@ -119,6 +125,13 @@ class BaseModel(torch.nn.Module):
self.memory_usage_factor = model_config.memory_usage_factor
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._apply_model,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options)
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:
@ -153,8 +166,7 @@ class BaseModel(torch.nn.Module):
def encode_adm(self, **kwargs):
return None
def extra_conds(self, **kwargs):
out = {}
def concat_cond(self, **kwargs):
if len(self.concat_keys) > 0:
cond_concat = []
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
@ -193,7 +205,14 @@ class BaseModel(torch.nn.Module):
elif ck == "masked_image":
cond_concat.append(self.blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1)
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
return data
return None
def extra_conds(self, **kwargs):
out = {}
concat_cond = self.concat_cond(**kwargs)
if concat_cond is not None:
out['c_concat'] = comfy.conds.CONDNoiseShape(concat_cond)
adm = self.encode_adm(**kwargs)
if adm is not None:
@ -523,9 +542,7 @@ class SD_X4Upscaler(BaseModel):
return out
class IP2P:
def extra_conds(self, **kwargs):
out = {}
def concat_cond(self, **kwargs):
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
@ -537,18 +554,15 @@ class IP2P:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0])
return self.process_ip2p_image_in(image)
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)
return out
class SD15_instructpix2pix(IP2P, BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
self.process_ip2p_image_in = lambda image: image
class SDXL_instructpix2pix(IP2P, SDXL):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
@ -709,6 +723,44 @@ class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
def concat_cond(self, **kwargs):
try:
#Handle Flux control loras dynamically changing the img_in weight.
num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
except:
#Some cases like tensorrt might not have the weights accessible
num_channels = self.model_config.unet_config["in_channels"]
out_channels = self.model_config.unet_config["out_channels"]
if num_channels <= out_channels:
return None
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if image is None:
image = torch.zeros_like(noise)
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0])
image = self.process_latent_in(image)
if num_channels <= out_channels * 2:
return image
#inpaint model
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.ones_like(noise)[:, :1]
mask = torch.mean(mask, dim=1, keepdim=True)
print(mask.shape)
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((image, mask), dim=1)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
@ -734,3 +786,27 @@ class GenmoMochi(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class LTXV(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guiding_latent = kwargs.get("guiding_latent", None)
if guiding_latent is not None:
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
if guiding_latent_noise_scale is not None:
out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
return out

View File

@ -137,6 +137,12 @@ def detect_unet_config(state_dict, key_prefix):
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16
patch_size = 2
dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys:
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
@ -177,6 +183,10 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["rope_theta"] = 10000.0
return dit_config
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

View File

@ -23,6 +23,8 @@ from comfy.cli_args import args
import torch
import sys
import platform
import weakref
import gc
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@ -287,11 +289,27 @@ def module_size(module):
class LoadedModel:
def __init__(self, model):
self.model = model
self._set_model(model)
self.device = model.load_device
self.weights_loaded = False
self.real_model = None
self.currently_used = True
self.model_finalizer = None
self._patcher_finalizer = None
def _set_model(self, model):
self._model = weakref.ref(model)
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
def _switch_parent(self):
model = self._parent_model()
if model is not None:
self._set_model(model)
@property
def model(self):
return self._model()
def model_memory(self):
return self.model.model_size()
@ -306,32 +324,23 @@ class LoadedModel:
return self.model_memory()
def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
patch_model_to = self.device
self.model.model_patches_to(self.device)
self.model.model_patches_to(self.model.model_dtype())
load_weights = not self.weights_loaded
# if self.model.loaded_size() > 0:
use_more_vram = lowvram_model_memory
if use_more_vram == 0:
use_more_vram = 1e32
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
real_model = self.model.model
if self.model.loaded_size() > 0:
use_more_vram = lowvram_model_memory
if use_more_vram == 0:
use_more_vram = 1e32
self.model_use_more_vram(use_more_vram)
else:
try:
self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
with torch.no_grad():
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
self.weights_loaded = True
return self.real_model
self.real_model = weakref.ref(real_model)
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
return real_model
def should_reload_model(self, force_patch_weights=False):
if force_patch_weights and self.model.lowvram_patch_counter() > 0:
@ -344,18 +353,26 @@ class LoadedModel:
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
if freed >= memory_to_free:
return False
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights
self.model.detach(unpatch_weights)
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
return True
def model_use_more_vram(self, extra_memory):
return self.model.partially_load(self.device, extra_memory)
def model_use_more_vram(self, extra_memory, force_patch_weights=False):
return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
def __eq__(self, other):
return self.model is other.model
def __del__(self):
if self._patcher_finalizer is not None:
self._patcher_finalizer.detach()
def is_dead(self):
return self.real_model() is not None and self.model is None
def use_more_memory(extra_memory, loaded_models, device):
for m in loaded_models:
if m.device == device:
@ -386,38 +403,8 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
to_unload = []
for i in range(len(current_loaded_models)):
if model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
if len(to_unload) == 0:
return True
same_weights = 0
for i in to_unload:
if model.clone_has_same_weights(current_loaded_models[i].model):
same_weights += 1
if same_weights == len(to_unload):
unload_weight = False
else:
unload_weight = True
if not force_unload:
if unload_weights_only and unload_weight == False:
return None
else:
unload_weight = True
for i in to_unload:
logging.debug("unload clone {} {}".format(i, unload_weight))
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
return unload_weight
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
unloaded_model = []
can_unload = []
unloaded_models = []
@ -425,7 +412,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
@ -454,6 +441,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
return unloaded_models
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
inference_memory = minimum_inference_memory()
@ -466,11 +454,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
models = set(models)
models_to_load = []
models_already_loaded = []
for x in models:
loaded_model = LoadedModel(x)
loaded = None
try:
loaded_model_index = current_loaded_models.index(loaded_model)
except:
@ -478,51 +464,35 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
if loaded_model_index is not None:
loaded = current_loaded_models[loaded_model_index]
if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
loaded = None
else:
loaded.currently_used = True
models_already_loaded.append(loaded)
if loaded is None:
loaded.currently_used = True
models_to_load.append(loaded)
else:
if hasattr(x, "model"):
logging.info(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model)
if len(models_to_load) == 0:
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
free_mem = get_free_memory(d)
if free_mem < minimum_memory_required:
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
models_to_load = free_memory(minimum_memory_required, d)
logging.info("{} models unloaded.".format(len(models_to_load)))
else:
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
if len(models_to_load) == 0:
return
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
for loaded_model in models_to_load:
to_unload = []
for i in range(len(current_loaded_models)):
if loaded_model.model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
for i in to_unload:
current_loaded_models.pop(i).model.detach(unpatch_all=False)
total_memory_required = {}
for loaded_model in models_to_load:
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for loaded_model in models_already_loaded:
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
for device in total_memory_required:
if device != torch.device("cpu"):
free_mem = get_free_memory(device)
if free_mem < minimum_memory_required:
models_l = free_memory(minimum_memory_required, device)
logging.info("{} models unloaded.".format(len(models_l)))
for loaded_model in models_to_load:
model = loaded_model.model
@ -544,17 +514,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_mem = get_free_memory(d)
if free_mem > minimum_memory_required:
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
return
def load_model_gpu(model):
return load_models_gpu([model])
@ -568,21 +529,35 @@ def loaded_models(only_currently_used=False):
output.append(m.model)
return output
def cleanup_models(keep_clone_weights_loaded=False):
def cleanup_models_gc():
do_gc = False
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.is_dead():
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
do_gc = True
break
if do_gc:
gc.collect()
soft_empty_cache()
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.is_dead():
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
def cleanup_models():
to_delete = []
for i in range(len(current_loaded_models)):
#TODO: very fragile function needs improvement
num_refs = sys.getrefcount(current_loaded_models[i].model)
if num_refs <= 2:
if not keep_clone_weights_loaded:
to_delete = [i] + to_delete
#TODO: find a less fragile way to do this.
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
to_delete = [i] + to_delete
if current_loaded_models[i].real_model() is None:
to_delete = [i] + to_delete
for i in to_delete:
x = current_loaded_models.pop(i)
x.model_unload()
del x
def dtype_size(dtype):
@ -628,6 +603,10 @@ def maximum_vram_for_weights(device=None):
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if model_params < 0:
model_params = 1000000000000000000000
if args.fp32_unet:
return torch.float32
if args.fp64_unet:
return torch.float64
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
@ -674,7 +653,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if weight_dtype == torch.float32:
if weight_dtype == torch.float32 or weight_dtype == torch.float64:
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)

File diff suppressed because it is too large Load Diff

View File

@ -243,7 +243,7 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
return time_snr_shift(self.shift, 1.0 - percent)
class StableCascadeSampling(ModelSamplingDiscrete):
def __init__(self, model_config=None):
@ -336,4 +336,4 @@ class ModelSamplingFlux(torch.nn.Module):
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
return flux_time_shift(self.shift, 1.0, 1.0 - percent)

View File

@ -269,7 +269,7 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
inn = input.reshape(-1, input.shape[2]).to(dtype)
inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype)
else:
scale_input = scale_input.to(input.device)
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)

156
comfy/patcher_extension.py Normal file
View File

@ -0,0 +1,156 @@
from __future__ import annotations
from typing import Callable
class CallbacksMP:
ON_CLONE = "on_clone"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"
ON_PRE_RUN = "on_pre_run"
ON_PREPARE_STATE = "on_prepare_state"
ON_APPLY_HOOKS = "on_apply_hooks"
ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
ON_INJECT_MODEL = "on_inject_model"
ON_EJECT_MODEL = "on_eject_model"
# callbacks dict is in the format:
# {"call_type": {"key": [Callable1, Callable2, ...]} }
@classmethod
def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
return {}
def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
add_callback_with_key(call_type, None, callback, transformer_options, is_model_options)
def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.setdefault("transformer_options", {})
callbacks: dict[str, dict[str, list]] = transformer_options.setdefault("callbacks", {})
c = callbacks.setdefault(call_type, {}).setdefault(key, [])
c.append(callback)
def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
c_list = []
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
c_list.extend(callbacks.get(call_type, {}).get(key, []))
return c_list
def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
c_list = []
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
for c in callbacks.get(call_type, {}).values():
c_list.extend(c)
return c_list
class WrappersMP:
OUTER_SAMPLE = "outer_sample"
SAMPLER_SAMPLE = "sampler_sample"
CALC_COND_BATCH = "calc_cond_batch"
APPLY_MODEL = "apply_model"
DIFFUSION_MODEL = "diffusion_model"
# wrappers dict is in the format:
# {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
@classmethod
def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
return {}
def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options)
def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.setdefault("transformer_options", {})
wrappers: dict[str, dict[str, list]] = transformer_options.setdefault("wrappers", {})
w = wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
w.append(wrapper)
def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
w_list = []
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
w_list.extend(wrappers.get(wrapper_type, {}).get(key, []))
return w_list
def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
w_list = []
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
for w in wrappers.get(wrapper_type, {}).values():
w_list.extend(w)
return w_list
class WrapperExecutor:
"""Handles call stack of wrappers around a function in an ordered manner."""
def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
# NOTE: class_obj exists so that wrappers surrounding a class method can access
# the class instance at runtime via executor.class_obj
self.original = original
self.class_obj = class_obj
self.wrappers = wrappers.copy()
self.idx = idx
self.is_last = idx == len(wrappers)
def __call__(self, *args, **kwargs):
"""Calls the next wrapper or original function, whichever is appropriate."""
new_executor = self._create_next_executor()
return new_executor.execute(*args, **kwargs)
def execute(self, *args, **kwargs):
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
args = list(args)
kwargs = dict(kwargs)
if self.is_last:
return self.original(*args, **kwargs)
return self.wrappers[self.idx](self, *args, **kwargs)
def _create_next_executor(self) -> 'WrapperExecutor':
new_idx = self.idx + 1
if new_idx > len(self.wrappers):
raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
if self.class_obj is None:
return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)
@classmethod
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
@classmethod
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
return cls(original, class_obj, wrappers, idx=idx)
class PatcherInjection:
def __init__(self, inject: Callable, eject: Callable):
self.inject = inject
self.eject = eject
def copy_nested_dicts(input_dict: dict):
new_dict = input_dict.copy()
for key, value in input_dict.items():
if isinstance(value, dict):
new_dict[key] = copy_nested_dicts(value)
elif isinstance(value, list):
new_dict[key] = value.copy()
return new_dict
def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
if copy_dict1:
merged_dict = copy_nested_dicts(dict1)
else:
merged_dict = dict1
for key, value in dict2.items():
if isinstance(value, dict):
curr_value = merged_dict.setdefault(key, {})
merged_dict[key] = merge_nested_dicts(value, curr_value)
elif isinstance(value, list):
merged_dict.setdefault(key, []).extend(value)
else:
merged_dict[key] = value
return merged_dict

View File

@ -1,7 +1,15 @@
import torch
from __future__ import annotations
import uuid
import comfy.model_management
import comfy.conds
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
def prepare_mask(noise_mask, shape, device):
return comfy.utils.reshape_mask(noise_mask, shape).to(device)
@ -10,9 +18,43 @@ def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c:
models += [c[model_type]]
if isinstance(c[model_type], list):
models += c[model_type]
else:
models += [c[model_type]]
return models
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
cnets: list[ControlBase] = []
for c in cond:
if 'hooks' in c:
for hook in c['hooks'].hooks:
hook: comfy.hooks.Hook
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
if 'control' in c:
cnets.append(c['control'])
def get_extra_hooks_from_cnet(cnet: ControlBase, _list: list):
if cnet.extra_hooks is not None:
_list.append(cnet.extra_hooks)
if cnet.previous_controlnet is None:
return _list
return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list)
hooks_list = []
cnets = set(cnets)
for base_cnet in cnets:
get_extra_hooks_from_cnet(base_cnet, hooks_list)
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
if extra_hooks is not None:
for hook in extra_hooks.hooks:
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
return hooks_dict
def convert_cond(cond):
out = []
for c in cond:
@ -22,17 +64,22 @@ def convert_cond(cond):
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
temp["uuid"] = uuid.uuid4()
out.append(temp)
return out
def get_additional_models(conds, dtype):
"""loads additional models in conditioning"""
cnets = []
cnets: list[ControlBase] = []
gligen = []
add_models = []
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
for k in conds:
cnets += get_models_from_cond(conds[k], "control")
gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models")
get_hooks_from_cond(conds[k], hooks)
control_nets = set(cnets)
@ -43,7 +90,9 @@ def get_additional_models(conds, dtype):
inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen]
models = control_models + gligen
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
models = control_models + gligen + add_models + hook_models
return models, inference_memory
def cleanup_additional_models(models):
@ -53,10 +102,11 @@ def cleanup_additional_models(models):
m.cleanup()
def prepare_sampling(model, noise_shape, conds):
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
device = model.load_device
real_model = None
real_model: 'BaseModel' = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
@ -72,3 +122,14 @@ def cleanup_models(conds, models):
control_cleanup += get_models_from_cond(conds[k], "control")
cleanup_additional_models(set(control_cleanup))
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
# check for hooks in conds - if not registered, see if can be applied
hooks = {}
for k in conds:
get_hooks_from_cond(conds[k], hooks)
# add wrappers and callbacks from ModelPatcher to transformer_options
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
# register hooks on model/model_options
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)

View File

@ -1,11 +1,21 @@
from __future__ import annotations
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
import torch
import collections
from comfy import model_management
import math
import logging
import comfy.samplers
import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import scipy.stats
import numpy
@ -70,6 +80,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
hooks = conds.get('hooks', None)
control = conds.get('control', None)
patches = None
@ -85,8 +96,8 @@ def get_area_and_mult(conds, x_in, timestep_in):
patches['middle_patch'] = [gligen_patch]
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
return cond_obj(input_x, mult, conditioning, area, control, patches)
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks'])
return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks)
def cond_equal_size(c1, c2):
if c1 is c2:
@ -138,110 +149,184 @@ def cond_cat(c_list):
return out
def calc_cond_batch(model, conds, x_in, timestep, model_options):
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
# need to figure out remaining unmasked area for conds
default_mults = []
for _ in default_conds:
default_mults.append(torch.ones_like(x_in))
# look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond
for lora_hooks, to_run in hooked_to_run.items():
for cond_obj, i in to_run:
# if no default_cond for cond_type, do nothing
if len(default_conds[i]) == 0:
continue
area: list[int] = cond_obj.area
if area is not None:
curr_default_mult: torch.Tensor = default_mults[i]
dims = len(area) // 2
for i in range(dims):
curr_default_mult = curr_default_mult.narrow(i + 2, area[i + dims], area[i])
curr_default_mult -= cond_obj.mult
else:
default_mults[i] -= cond_obj.mult
# for each default_mult, ReLU to make negatives=0, and then check for any nonzeros
for i, mult in enumerate(default_mults):
# if no default_cond for cond type, do nothing
if len(default_conds[i]) == 0:
continue
torch.nn.functional.relu(mult, inplace=True)
# if mult is all zeros, then don't add default_cond
if torch.max(mult) == 0.0:
continue
cond = default_conds[i]
for x in cond:
# do get_area_and_mult to get all the expected values
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
if p is None:
continue
# replace p's mult with calculated mult
p = p._replace(mult=mult)
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_calc_cond_batch,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
)
return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
out_conds = []
out_counts = []
to_run = []
# separate conds by matching hooks
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
default_conds = []
has_default_conds = False
for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37)
cond = conds[i]
default_c = []
if cond is not None:
for x in cond:
p = get_area_and_mult(x, x_in, timestep)
if 'default' in x:
default_c.append(x)
has_default_conds = True
continue
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)
to_run += [(p, i)]
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
model.current_patcher.prepare_state(timestep)
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
input_x = []
mult = []
c = []
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
control = p.control
patches = p.patches
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
input_x = []
mult = []
c = []
cond_or_uncond = []
uuids = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
uuids.append(p.uuid)
control = p.control
patches = p.patches
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
copy_dict1=False)
if patches is not None:
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
transformer_options["patches"] = cur_patches
if patches is not None:
# TODO: replace with merge_nested_dicts function
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
transformer_options["patches"] = cur_patches
else:
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
transformer_options["patches"] = patches
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
@ -261,7 +346,7 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "cond_scale": cond_scale, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args)
@ -500,10 +585,15 @@ def calculate_start_end_timesteps(model, conds):
timestep_start = None
timestep_end = None
if 'start_percent' in x:
timestep_start = s.percent_to_sigma(x['start_percent'])
if 'end_percent' in x:
timestep_end = s.percent_to_sigma(x['end_percent'])
# handle clip hook schedule, if needed
if 'clip_start_percent' in x:
timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0)))
timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0)))
else:
if 'start_percent' in x:
timestep_start = s.percent_to_sigma(x['start_percent'])
if 'end_percent' in x:
timestep_end = s.percent_to_sigma(x['end_percent'])
if (timestep_start is not None) or (timestep_end is not None):
n = x.copy()
@ -673,6 +763,12 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
if k != kk:
create_cond_with_same_area_if_none(conds[kk], c)
for k in conds:
for c in conds[k]:
if 'hooks' in c:
for hook in c['hooks'].hooks:
hook.initialize_timesteps(model)
for k in conds:
pre_run_control(model, conds[k])
@ -685,9 +781,46 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
return conds
def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
# determine which ControlNets have extra_hooks that should be combined with normal hooks
hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {}
for k in conds:
for kk in conds[k]:
if 'control' in kk:
control: 'ControlBase' = kk['control']
extra_hooks = control.get_extra_hooks()
if len(extra_hooks) > 0:
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
to_replace = hook_replacement.setdefault((control, hooks), [])
to_replace.append(kk)
# if nothing to replace, do nothing
if len(hook_replacement) == 0:
return
# for optimal sampling performance, common ControlNets + hook combos should have identical hooks
# on the cond dicts
for key, conds_to_modify in hook_replacement.items():
control = key[0]
hooks = key[1]
hooks = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks])
# if combined hooks are not None, set as new hooks for all relevant conds
if hooks is not None:
for cond in conds_to_modify:
cond['hooks'] = hooks
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
hooks_set = set()
for k in conds:
for kk in conds[k]:
hooks_set.add(kk.get('hooks', None))
return len(hooks_set)
class CFGGuider:
def __init__(self, model_patcher):
self.model_patcher = model_patcher
self.model_patcher: 'ModelPatcher' = model_patcher
self.model_options = model_patcher.model_options
self.original_conds = {}
self.cfg = 1.0
@ -714,19 +847,17 @@ class CFGGuider:
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
extra_args = {"model_options": self.model_options, "seed":seed}
extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
sampler.sample,
sampler,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
)
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
if sigmas.shape[-1] == 0:
return latent_image
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
device = self.model_patcher.load_device
@ -737,14 +868,48 @@ class CFGGuider:
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
try:
self.model_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
self.model_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
del self.conds
del self.loaded_models
return output
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
if sigmas.shape[-1] == 0:
return latent_image
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
preprocess_conds_hooks(self.conds)
try:
orig_model_options = self.model_options
self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
# if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step)
orig_hook_mode = self.model_patcher.hook_mode
if get_total_hook_groups_in_conds(self.conds) <= 1:
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
self.outer_sample,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
self.model_options = orig_model_options
self.model_patcher.hook_mode = orig_hook_mode
self.model_patcher.restore_hook_patches()
del self.conds
return output
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
cfg_guider = CFGGuider(model)

View File

@ -1,13 +1,16 @@
from __future__ import annotations
import torch
from enum import Enum
import logging
from comfy import model_management
from comfy.utils import ProgressBar
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import yaml
import comfy.utils
@ -27,12 +30,17 @@ import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.model_patcher
import comfy.lora
import comfy.lora_convert
import comfy.hooks
import comfy.t2i_adapter.adapter
import comfy.taesd.taesd
import comfy.ldm.flux.redux
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = {}
if model is not None:
@ -40,6 +48,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
lora = comfy.lora_convert.convert_lora(lora)
loaded = comfy.lora.load_lora(lora, key_map)
if model is not None:
new_modelpatcher = model.clone()
@ -92,9 +101,13 @@ class CLIP:
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None
if params['device'] == load_device:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
self.use_clip_schedule = False
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
def clone(self):
@ -103,6 +116,8 @@ class CLIP:
n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer
n.layer_idx = self.layer_idx
n.use_clip_schedule = self.use_clip_schedule
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
@ -114,6 +129,69 @@ class CLIP:
def tokenize(self, text, return_word_ids=False):
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def add_hooks_to_dict(self, pooled_dict: dict[str]):
if self.apply_hooks_to_conds:
pooled_dict["hooks"] = self.apply_hooks_to_conds
return pooled_dict
def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]={}, show_pbar=True):
all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
all_hooks = self.patcher.forced_hooks
if all_hooks is None or not self.use_clip_schedule:
# if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict
return_pooled = "unprojected" if unprojected else True
pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
cond = pooled_dict.pop("cond")
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
all_cond_pooled.append([cond, pooled_dict])
else:
scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()
self.cond_stage_model.reset_clip_options()
if self.layer_idx is not None:
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
if unprojected:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
pbar = ProgressBar(len(scheduled_keyframes))
for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0]
# don't bother encoding any conds outside of start_percent and end_percent bounds
if "start_percent" in add_dict:
if t_range[1] < add_dict["start_percent"]:
continue
if "end_percent" in add_dict:
if t_range[0] > add_dict["end_percent"]:
continue
hooks_keyframes = scheduled_opts[1]
for hook, keyframe in hooks_keyframes:
hook.hook_keyframe._current_keyframe = keyframe
# apply appropriate hooks with values that match new hook_keyframe
self.patcher.patch_hooks(all_hooks)
# perform encoding as normal
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
pooled_dict = {"pooled_output": pooled}
# add clip_start_percent and clip_end_percent in pooled
pooled_dict["clip_start_percent"] = t_range[0]
pooled_dict["clip_end_percent"] = t_range[1]
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
# add hooks stored on clip
self.add_hooks_to_dict(pooled_dict)
all_cond_pooled.append([cond, pooled_dict])
if show_pbar:
pbar.update(1)
model_management.throw_exception_if_processing_interrupted()
all_hooks.reset()
return all_cond_pooled
def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
self.cond_stage_model.reset_clip_options()
@ -131,6 +209,7 @@ class CLIP:
if len(o) > 2:
for k in o[2]:
out[k] = o[2][k]
self.add_hooks_to_dict(out)
return out
if return_pooled:
@ -257,6 +336,14 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
self.working_dtypes = [torch.float16, torch.float32]
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
self.latent_channels = 128
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -356,7 +443,9 @@ class VAE:
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
pixel_samples = self.decode_tiled_3d(samples_in)
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
@ -420,6 +509,12 @@ class VAE:
def get_sd(self):
return self.first_stage_model.state_dict()
def spacial_compression_decode(self):
try:
return self.upscale_ratio[-1]
except:
return self.upscale_ratio
class StyleModel:
def __init__(self, model, device="cpu"):
self.model = model
@ -433,6 +528,8 @@ def load_style_model(ckpt_path):
keys = model_data.keys()
if "style_embedding" in keys:
model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
elif "redux_down.weight" in keys:
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
model.load_state_dict(model_data)
@ -446,6 +543,7 @@ class CLIPType(Enum):
HUNYUAN_DIT = 5
FLUX = 6
MOCHI = 7
LTXV = 8
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
@ -524,6 +622,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer

View File

@ -90,8 +90,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
with open(textmodel_json_config) as f:
config = json.load(f)
if isinstance(textmodel_json_config, dict):
config = textmodel_json_config
else:
with open(textmodel_json_config) as f:
config = json.load(f)
operations = model_options.get("custom_operations", None)
scaled_fp8 = None
@ -196,11 +199,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
attention_mask = None
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
attention_mask = torch.zeros_like(tokens)
end_token = self.special_tokens.get("end", -1)
end_token = self.special_tokens.get("end", None)
if end_token is None:
cmp_token = self.special_tokens.get("pad", -1)
else:
cmp_token = end_token
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == end_token:
if tokens[x, y] == cmp_token:
if end_token is None:
attention_mask[x, y] = 0
break
attention_mask_model = None
@ -411,22 +421,25 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.max_length = max_length
self.min_length = min_length
self.end_token = None
empty = self.tokenizer('')["input_ids"]
if has_start_token:
self.tokens_start = 1
self.start_token = empty[0]
self.end_token = empty[1]
if has_end_token:
self.end_token = empty[1]
else:
self.tokens_start = 0
self.start_token = None
self.end_token = empty[0]
if has_end_token:
self.end_token = empty[0]
if pad_token is not None:
self.pad_token = pad_token
@ -451,13 +464,16 @@ class SDTokenizer:
Takes a potential embedding name and tries to retrieve it.
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
'''
split_embed = embedding_name.split(' ')
embedding_name = split_embed[0]
leftover = ' '.join(split_embed[1:])
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
if embed is None:
stripped = embedding_name.strip(',')
if len(stripped) < len(embedding_name):
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
return (embed, embedding_name[len(stripped):])
return (embed, "")
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, leftover)
def tokenize_with_weights(self, text:str, return_word_ids=False):
@ -474,7 +490,12 @@ class SDTokenizer:
#tokenize words
tokens = []
for weighted_segment, weight in parsed_weights:
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
to_tokenize = unescape_important(weighted_segment).replace("\n", " ")
split = to_tokenize.split(' {}'.format(self.embedding_identifier))
to_tokenize = [split[0]]
for i in range(1, len(split)):
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
to_tokenize = [x for x in to_tokenize if x != ""]
for word in to_tokenize:
#if we find an embedding, deal with the embedding
@ -493,8 +514,11 @@ class SDTokenizer:
word = leftover
else:
continue
end = 999999999999
if self.end_token is not None:
end = -1
#parse word
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])
#reshape token array to CLIP input size
batched_tokens = []
@ -505,18 +529,24 @@ class SDTokenizer:
for i, t_group in enumerate(tokens):
#determine if we're going to try and keep the tokens in a single batch
is_large = len(t_group) >= self.max_word_length
if self.end_token is not None:
has_end_token = 1
else:
has_end_token = 0
while len(t_group) > 0:
if len(t_group) + len(batch) > self.max_length - 1:
remaining_length = self.max_length - len(batch) - 1
if len(t_group) + len(batch) > self.max_length - has_end_token:
remaining_length = self.max_length - len(batch) - has_end_token
#break word in two and add end token
if is_large:
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
batch.append((self.end_token, 1.0, 0))
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
t_group = t_group[remaining_length:]
#add end token and pad
else:
batch.append((self.end_token, 1.0, 0))
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
#start new batch
@ -529,7 +559,8 @@ class SDTokenizer:
t_group = []
#fill last batch
batch.append((self.end_token, 1.0, 0))
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length:

View File

@ -11,6 +11,7 @@ import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
from . import supported_models_base
from . import latent_formats
@ -658,6 +659,15 @@ class Flux(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
class FluxInpaint(Flux):
unet_config = {
"image_model": "flux",
"guidance_embed": True,
"in_channels": 96,
}
supported_inference_dtypes = [torch.bfloat16, torch.float32]
class FluxSchnell(Flux):
unet_config = {
"image_model": "flux",
@ -702,7 +712,34 @@ class GenmoMochi(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
class LTXV(supported_models_base.BASE):
unet_config = {
"image_model": "ltxv",
}
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi]
sampling_settings = {
"shift": 2.37,
}
unet_extra_config = {}
latent_format = latent_formats.LTXV
memory_usage_factor = 2.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXV(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
models += [SVD_img2vid]

18
comfy/text_encoders/lt.py Normal file
View File

@ -0,0 +1,18 @@
from comfy import sd1_clip
import os
from transformers import T5TokenizerFast
import comfy.text_encoders.genmo
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def ltxv_te(*args, **kwargs):
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)

View File

@ -1,4 +1,3 @@
import os
import torch
class SPieceTokenizer:

View File

@ -209,6 +209,11 @@ class T5Stack(torch.nn.Module):
intermediate = None
optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
past_bias = None
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.block) + intermediate_output
for i, l in enumerate(self.block):
x, past_bias = l(x, mask, past_bias, optimized_attention)
if i == intermediate_output:

View File

@ -46,7 +46,13 @@ def load_torch_file(ckpt, safe_load=False, device=None):
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
if len(pl_sd) == 1:
key = list(pl_sd.keys())[0]
sd = pl_sd[key]
if not isinstance(sd, dict):
sd = pl_sd
else:
sd = pl_sd
return sd
def save_torch_file(sd, ckpt, metadata=None):

View File

@ -0,0 +1,39 @@
from __future__ import annotations
def validate_node_input(
received_type: str, input_type: str, strict: bool = False
) -> bool:
"""
received_type and input_type are both strings of the form "T1,T2,...".
If strict is True, the input_type must contain the received_type.
For example, if received_type is "STRING" and input_type is "STRING,INT",
this will return True. But if received_type is "STRING,INT" and input_type is
"INT", this will return False.
If strict is False, the input_type must have overlap with the received_type.
For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT",
this will return True.
Supports pre-union type extension behaviour of ``__ne__`` overrides.
"""
# If the types are exactly the same, we can return immediately
# Use pre-union behaviour: inverse of `__ne__`
if not received_type != input_type:
return True
# Not equal, and not strings
if not isinstance(received_type, str) or not isinstance(input_type, str):
return False
# Split the type strings into sets for comparison
received_types = set(t.strip() for t in received_type.split(","))
input_types = set(t.strip() for t in input_type.split(","))
if strict:
# In strict mode, all received types must be in the input types
return received_types.issubset(input_types)
else:
# In non-strict mode, there must be at least one type in common
return len(received_types.intersection(input_types)) > 0

View File

@ -2,8 +2,7 @@ import comfy.samplers
import comfy.utils
import torch
import numpy as np
from tqdm.auto import trange, tqdm
import math
from tqdm.auto import trange
@torch.no_grad()

View File

@ -1,4 +1,3 @@
import torch
from nodes import MAX_RESOLUTION
class CLIPTextEncodeSDXLRefiner:
@ -17,8 +16,7 @@ class CLIPTextEncodeSDXLRefiner:
def encode(self, clip, ascore, width, height, text):
tokens = clip.tokenize(text)
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], )
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), )
class CLIPTextEncodeSDXL:
@classmethod
@ -47,8 +45,7 @@ class CLIPTextEncodeSDXL:
tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"]
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], )
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,

View File

@ -1,4 +1,3 @@
import numpy as np
import torch
import comfy.utils
from enum import Enum

View File

@ -18,10 +18,7 @@ class CLIPTextEncodeFlux:
tokens = clip.tokenize(clip_l)
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
cond = output.pop("cond")
output["guidance"] = guidance
return ([[cond, output]], )
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )
class FluxGuidance:
@classmethod

744
comfy_extras/nodes_hooks.py Normal file
View File

@ -0,0 +1,744 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
import torch
from collections.abc import Iterable
if TYPE_CHECKING:
from comfy.sd import CLIP
import comfy.hooks
import comfy.sd
import comfy.utils
import folder_paths
###########################################
# Mask, Combine, and Hook Conditioning
#------------------------------------------
class PairConditioningSetProperties:
NodeId = 'PairConditioningSetProperties'
NodeName = 'Cond Pair Set Props'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"positive_NEW": ("CONDITIONING", ),
"negative_NEW": ("CONDITIONING", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"set_cond_area": (["default", "mask bounds"],),
},
"optional": {
"mask": ("MASK", ),
"hooks": ("HOOKS",),
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
CATEGORY = "advanced/hooks/cond pair"
FUNCTION = "set_properties"
def set_properties(self, positive_NEW, negative_NEW,
strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_NEW, negative_NEW],
strength=strength, set_cond_area=set_cond_area,
mask=mask, hooks=hooks, timesteps_range=timesteps)
return (final_positive, final_negative)
class PairConditioningSetPropertiesAndCombine:
NodeId = 'PairConditioningSetPropertiesAndCombine'
NodeName = 'Cond Pair Set Props Combine'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"positive_NEW": ("CONDITIONING", ),
"negative_NEW": ("CONDITIONING", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"set_cond_area": (["default", "mask bounds"],),
},
"optional": {
"mask": ("MASK", ),
"hooks": ("HOOKS",),
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
CATEGORY = "advanced/hooks/cond pair"
FUNCTION = "set_properties"
def set_properties(self, positive, negative, positive_NEW, negative_NEW,
strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
strength=strength, set_cond_area=set_cond_area,
mask=mask, hooks=hooks, timesteps_range=timesteps)
return (final_positive, final_negative)
class ConditioningSetProperties:
NodeId = 'ConditioningSetProperties'
NodeName = 'Cond Set Props'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"cond_NEW": ("CONDITIONING", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"set_cond_area": (["default", "mask bounds"],),
},
"optional": {
"mask": ("MASK", ),
"hooks": ("HOOKS",),
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING",)
CATEGORY = "advanced/hooks/cond single"
FUNCTION = "set_properties"
def set_properties(self, cond_NEW,
strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
(final_cond,) = comfy.hooks.set_conds_props(conds=[cond_NEW],
strength=strength, set_cond_area=set_cond_area,
mask=mask, hooks=hooks, timesteps_range=timesteps)
return (final_cond,)
class ConditioningSetPropertiesAndCombine:
NodeId = 'ConditioningSetPropertiesAndCombine'
NodeName = 'Cond Set Props Combine'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"cond": ("CONDITIONING", ),
"cond_NEW": ("CONDITIONING", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"set_cond_area": (["default", "mask bounds"],),
},
"optional": {
"mask": ("MASK", ),
"hooks": ("HOOKS",),
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING",)
CATEGORY = "advanced/hooks/cond single"
FUNCTION = "set_properties"
def set_properties(self, cond, cond_NEW,
strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
(final_cond,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_NEW],
strength=strength, set_cond_area=set_cond_area,
mask=mask, hooks=hooks, timesteps_range=timesteps)
return (final_cond,)
class PairConditioningCombine:
NodeId = 'PairConditioningCombine'
NodeName = 'Cond Pair Combine'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"positive_A": ("CONDITIONING",),
"negative_A": ("CONDITIONING",),
"positive_B": ("CONDITIONING",),
"negative_B": ("CONDITIONING",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
CATEGORY = "advanced/hooks/cond pair"
FUNCTION = "combine"
def combine(self, positive_A, negative_A, positive_B, negative_B):
final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],)
return (final_positive, final_negative,)
class PairConditioningSetDefaultAndCombine:
NodeId = 'PairConditioningSetDefaultCombine'
NodeName = 'Cond Pair Set Default Combine'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"positive_DEFAULT": ("CONDITIONING",),
"negative_DEFAULT": ("CONDITIONING",),
},
"optional": {
"hooks": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
CATEGORY = "advanced/hooks/cond pair"
FUNCTION = "set_default_and_combine"
def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT,
hooks: comfy.hooks.HookGroup=None):
final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
hooks=hooks)
return (final_positive, final_negative)
class ConditioningSetDefaultAndCombine:
NodeId = 'ConditioningSetDefaultCombine'
NodeName = 'Cond Set Default Combine'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"cond": ("CONDITIONING",),
"cond_DEFAULT": ("CONDITIONING",),
},
"optional": {
"hooks": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING",)
CATEGORY = "advanced/hooks/cond single"
FUNCTION = "set_default_and_combine"
def set_default_and_combine(self, cond, cond_DEFAULT,
hooks: comfy.hooks.HookGroup=None):
(final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
hooks=hooks)
return (final_conditioning,)
class SetClipHooks:
NodeId = 'SetClipHooks'
NodeName = 'Set CLIP Hooks'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"clip": ("CLIP",),
"apply_to_conds": ("BOOLEAN", {"default": True}),
"schedule_clip": ("BOOLEAN", {"default": False})
},
"optional": {
"hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CLIP",)
CATEGORY = "advanced/hooks/clip"
FUNCTION = "apply_hooks"
def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
if hooks is not None:
clip = clip.clone()
if apply_to_conds:
clip.apply_hooks_to_conds = hooks
clip.patcher.forced_hooks = hooks.clone()
clip.use_clip_schedule = schedule_clip
if not clip.use_clip_schedule:
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
return (clip,)
class ConditioningTimestepsRange:
NodeId = 'ConditioningTimestepsRange'
NodeName = 'Timesteps Range'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
CATEGORY = "advanced/hooks"
FUNCTION = "create_range"
def create_range(self, start_percent: float, end_percent: float):
return ((start_percent, end_percent), (0.0, start_percent), (end_percent, 1.0))
#------------------------------------------
###########################################
###########################################
# Create Hooks
#------------------------------------------
class CreateHookLora:
NodeId = 'CreateHookLora'
NodeName = 'Create Hook LoRA'
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
},
"optional": {
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
FUNCTION = "create_hook"
def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None):
if prev_hooks is None:
prev_hooks = comfy.hooks.HookGroup()
prev_hooks.clone()
if strength_model == 0 and strength_clip == 0:
return (prev_hooks,)
lora_path = folder_paths.get_full_path("loras", lora_name)
lora = None
if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
else:
temp = self.loaded_lora
self.loaded_lora = None
del temp
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=strength_model, strength_clip=strength_clip)
return (prev_hooks.clone_and_combine(hooks),)
class CreateHookLoraModelOnly(CreateHookLora):
NodeId = 'CreateHookLoraModelOnly'
NodeName = 'Create Hook LoRA (MO)'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
},
"optional": {
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
FUNCTION = "create_hook_model_only"
def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None):
return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks)
class CreateHookModelAsLora:
NodeId = 'CreateHookModelAsLora'
NodeName = 'Create Hook Model as LoRA'
def __init__(self):
# when not None, will be in following format:
# (ckpt_path: str, weights_model: dict, weights_clip: dict)
self.loaded_weights = None
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
},
"optional": {
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
FUNCTION = "create_hook"
def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float,
prev_hooks: comfy.hooks.HookGroup=None):
if prev_hooks is None:
prev_hooks = comfy.hooks.HookGroup()
prev_hooks.clone()
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
weights_model = None
weights_clip = None
if self.loaded_weights is not None:
if self.loaded_weights[0] == ckpt_path:
weights_model = self.loaded_weights[1]
weights_clip = self.loaded_weights[2]
else:
temp = self.loaded_weights
self.loaded_weights = None
del temp
if weights_model is None:
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
weights_clip = comfy.hooks.get_patch_weights_from_model(out[1].patcher if out[1] else out[1])
self.loaded_weights = (ckpt_path, weights_model, weights_clip)
hooks = comfy.hooks.create_hook_model_as_lora(weights_model=weights_model, weights_clip=weights_clip,
strength_model=strength_model, strength_clip=strength_clip)
return (prev_hooks.clone_and_combine(hooks),)
class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
NodeId = 'CreateHookModelAsLoraModelOnly'
NodeName = 'Create Hook Model as LoRA (MO)'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
},
"optional": {
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
FUNCTION = "create_hook_model_only"
def create_hook_model_only(self, ckpt_name: str, strength_model: float,
prev_hooks: comfy.hooks.HookGroup=None):
return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks)
#------------------------------------------
###########################################
###########################################
# Schedule Hooks
#------------------------------------------
class SetHookKeyframes:
NodeId = 'SetHookKeyframes'
NodeName = 'Set Hook Keyframes'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"hooks": ("HOOKS",),
},
"optional": {
"hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/scheduling"
FUNCTION = "set_hook_keyframes"
def set_hook_keyframes(self, hooks: comfy.hooks.HookGroup, hook_kf: comfy.hooks.HookKeyframeGroup=None):
if hook_kf is not None:
hooks = hooks.clone()
hooks.set_keyframes_on_hooks(hook_kf=hook_kf)
return (hooks,)
class CreateHookKeyframe:
NodeId = 'CreateHookKeyframe'
NodeName = 'Create Hook Keyframe'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"strength_mult": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
},
"optional": {
"prev_hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
CATEGORY = "advanced/hooks/scheduling"
FUNCTION = "create_hook_keyframe"
def create_hook_keyframe(self, strength_mult: float, start_percent: float, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
if prev_hook_kf is None:
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
prev_hook_kf = prev_hook_kf.clone()
keyframe = comfy.hooks.HookKeyframe(strength=strength_mult, start_percent=start_percent)
prev_hook_kf.add(keyframe)
return (prev_hook_kf,)
class CreateHookKeyframesInterpolated:
NodeId = 'CreateHookKeyframesInterpolated'
NodeName = 'Create Hook Keyframes Interp.'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"strength_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
"strength_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
"interpolation": (comfy.hooks.InterpolationMethod._LIST, ),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"keyframes_count": ("INT", {"default": 5, "min": 2, "max": 100, "step": 1}),
"print_keyframes": ("BOOLEAN", {"default": False}),
},
"optional": {
"prev_hook_kf": ("HOOK_KEYFRAMES",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
CATEGORY = "advanced/hooks/scheduling"
FUNCTION = "create_hook_keyframes"
def create_hook_keyframes(self, strength_start: float, strength_end: float, interpolation: str,
start_percent: float, end_percent: float, keyframes_count: int,
print_keyframes=False, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
if prev_hook_kf is None:
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
prev_hook_kf = prev_hook_kf.clone()
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=keyframes_count,
method=comfy.hooks.InterpolationMethod.LINEAR)
strengths = comfy.hooks.InterpolationMethod.get_weights(num_from=strength_start, num_to=strength_end, length=keyframes_count, method=interpolation)
is_first = True
for percent, strength in zip(percents, strengths):
guarantee_steps = 0
if is_first:
guarantee_steps = 1
is_first = False
prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
if print_keyframes:
print(f"Hook Keyframe - start_percent:{percent} = {strength}")
return (prev_hook_kf,)
class CreateHookKeyframesFromFloats:
NodeId = 'CreateHookKeyframesFromFloats'
NodeName = 'Create Hook Keyframes From Floats'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"print_keyframes": ("BOOLEAN", {"default": False}),
},
"optional": {
"prev_hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
CATEGORY = "advanced/hooks/scheduling"
FUNCTION = "create_hook_keyframes"
def create_hook_keyframes(self, floats_strength: Union[float, list[float]],
start_percent: float, end_percent: float,
prev_hook_kf: comfy.hooks.HookKeyframeGroup=None, print_keyframes=False):
if prev_hook_kf is None:
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
prev_hook_kf = prev_hook_kf.clone()
if type(floats_strength) in (float, int):
floats_strength = [float(floats_strength)]
elif isinstance(floats_strength, Iterable):
pass
else:
raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
method=comfy.hooks.InterpolationMethod.LINEAR)
is_first = True
for percent, strength in zip(percents, floats_strength):
guarantee_steps = 0
if is_first:
guarantee_steps = 1
is_first = False
prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
if print_keyframes:
print(f"Hook Keyframe - start_percent:{percent} = {strength}")
return (prev_hook_kf,)
#------------------------------------------
###########################################
class SetModelHooksOnCond:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"conditioning": ("CONDITIONING",),
"hooks": ("HOOKS",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING",)
CATEGORY = "advanced/hooks/manual"
FUNCTION = "attach_hook"
def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup):
return (comfy.hooks.set_hooks_for_conditioning(conditioning, hooks),)
###########################################
# Combine Hooks
#------------------------------------------
class CombineHooks:
NodeId = 'CombineHooks2'
NodeName = 'Combine Hooks [2]'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
},
"optional": {
"hooks_A": ("HOOKS",),
"hooks_B": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"
FUNCTION = "combine_hooks"
def combine_hooks(self,
hooks_A: comfy.hooks.HookGroup=None,
hooks_B: comfy.hooks.HookGroup=None):
candidates = [hooks_A, hooks_B]
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
class CombineHooksFour:
NodeId = 'CombineHooks4'
NodeName = 'Combine Hooks [4]'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
},
"optional": {
"hooks_A": ("HOOKS",),
"hooks_B": ("HOOKS",),
"hooks_C": ("HOOKS",),
"hooks_D": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"
FUNCTION = "combine_hooks"
def combine_hooks(self,
hooks_A: comfy.hooks.HookGroup=None,
hooks_B: comfy.hooks.HookGroup=None,
hooks_C: comfy.hooks.HookGroup=None,
hooks_D: comfy.hooks.HookGroup=None):
candidates = [hooks_A, hooks_B, hooks_C, hooks_D]
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
class CombineHooksEight:
NodeId = 'CombineHooks8'
NodeName = 'Combine Hooks [8]'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
},
"optional": {
"hooks_A": ("HOOKS",),
"hooks_B": ("HOOKS",),
"hooks_C": ("HOOKS",),
"hooks_D": ("HOOKS",),
"hooks_E": ("HOOKS",),
"hooks_F": ("HOOKS",),
"hooks_G": ("HOOKS",),
"hooks_H": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"
FUNCTION = "combine_hooks"
def combine_hooks(self,
hooks_A: comfy.hooks.HookGroup=None,
hooks_B: comfy.hooks.HookGroup=None,
hooks_C: comfy.hooks.HookGroup=None,
hooks_D: comfy.hooks.HookGroup=None,
hooks_E: comfy.hooks.HookGroup=None,
hooks_F: comfy.hooks.HookGroup=None,
hooks_G: comfy.hooks.HookGroup=None,
hooks_H: comfy.hooks.HookGroup=None):
candidates = [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H]
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
#------------------------------------------
###########################################
node_list = [
# Create
CreateHookLora,
CreateHookLoraModelOnly,
CreateHookModelAsLora,
CreateHookModelAsLoraModelOnly,
# Scheduling
SetHookKeyframes,
CreateHookKeyframe,
CreateHookKeyframesInterpolated,
CreateHookKeyframesFromFloats,
# Combine
CombineHooks,
CombineHooksFour,
CombineHooksEight,
# Attach
ConditioningSetProperties,
ConditioningSetPropertiesAndCombine,
PairConditioningSetProperties,
PairConditioningSetPropertiesAndCombine,
ConditioningSetDefaultAndCombine,
PairConditioningSetDefaultAndCombine,
PairConditioningCombine,
SetClipHooks,
# Other
ConditioningTimestepsRange,
]
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
for node in node_list:
NODE_CLASS_MAPPINGS[node.NodeId] = node
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName

View File

@ -15,9 +15,7 @@ class CLIPTextEncodeHunyuanDiT:
tokens = clip.tokenize(bert)
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
cond = output.pop("cond")
return ([[cond, output]], )
return (clip.encode_from_tokens_scheduled(tokens), )
NODE_CLASS_MAPPINGS = {

184
comfy_extras/nodes_lt.py Normal file
View File

@ -0,0 +1,184 @@
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
import math
class EmptyLTXVLatentVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/video/ltxv"
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples": latent}, )
class LTXVImgToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE",),
"image": ("IMAGE",),
"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 1.0, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size, image_noise_scale):
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
latent[:, :, :t.shape[2]] = t
return (positive, negative, {"samples": latent}, )
class LTXVConditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "append"
CATEGORY = "conditioning/video_models"
def append(self, positive, negative, frame_rate):
positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
return (positive, negative)
class ModelSamplingLTXV:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
},
"optional": {"latent": ("LATENT",), }
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, max_shift, base_shift, latent=None):
m = model.clone()
if latent is None:
tokens = 4096
else:
tokens = math.prod(latent["samples"].shape[2:])
x1 = 1024
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
shift = (tokens) * mm + b
sampling_base = comfy.model_sampling.ModelSamplingFlux
sampling_type = comfy.model_sampling.CONST
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift=shift)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class LTXVScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
"stretch": ("BOOLEAN", {
"default": True,
"tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
}),
"terminal": (
"FLOAT",
{
"default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
"tooltip": "The terminal value of the sigmas after stretching."
},
),
},
"optional": {"latent": ("LATENT",), }
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
if latent is None:
tokens = 4096
else:
tokens = math.prod(latent["samples"].shape[2:])
sigmas = torch.linspace(1.0, 0.0, steps + 1)
x1 = 1024
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
sigma_shift = (tokens) * mm + b
power = 1
sigmas = torch.where(
sigmas != 0,
math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
0,
)
# Stretch sigmas so that its final value matches the given terminal value.
if stretch:
non_zero_mask = sigmas != 0
non_zero_sigmas = sigmas[non_zero_mask]
one_minus_z = 1.0 - non_zero_sigmas
scale_factor = one_minus_z[-1] / (1.0 - terminal)
stretched = 1.0 - (one_minus_z / scale_factor)
sigmas[non_zero_mask] = stretched
return (sigmas,)
NODE_CLASS_MAPPINGS = {
"EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
"LTXVImgToVideo": LTXVImgToVideo,
"ModelSamplingLTXV": ModelSamplingLTXV,
"LTXVConditioning": LTXVConditioning,
"LTXVScheduler": LTXVScheduler,
}

View File

@ -0,0 +1,41 @@
import torch
import torch.nn.functional as F
class Mahiro:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt."
def patch(self, model):
m = model.clone()
def mahiro_normd(args):
scale: float = args['cond_scale']
cond_p: torch.Tensor = args['cond_denoised']
uncond_p: torch.Tensor = args['uncond_denoised']
#naive leap
leap = cond_p * scale
#sim with uncond leap
u_leap = uncond_p * scale
cfg = args["denoised"]
merge = (leap + cfg) / 2
normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
normm = torch.sqrt(merge.abs()) * merge.sign()
sim = F.cosine_similarity(normu, normm).mean()
simsc = 2 * (sim+1)
wm = (simsc*cfg + (4-simsc)*leap) / 4
return wm
m.set_model_sampler_post_cfg_function(mahiro_normd)
return (m, )
NODE_CLASS_MAPPINGS = {
"Mahiro": Mahiro
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
}

View File

@ -1,4 +1,3 @@
import folder_paths
import comfy.sd
import comfy.model_sampling
import comfy.latent_formats

View File

@ -1,4 +1,3 @@
import torch
import comfy.utils
class PatchModelAddDownscale:

View File

@ -174,6 +174,28 @@ class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["patchify_proj."] = argument
arg_dict["adaln_single."] = argument
arg_dict["caption_projection."] = argument
for i in range(28):
arg_dict["transformer_blocks.{}.".format(i)] = argument
arg_dict["scale_shift_table"] = argument
arg_dict["proj_out."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@ -183,4 +205,5 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeFlux1": ModelMergeFlux1,
"ModelMergeSD35_Large": ModelMergeSD35_Large,
"ModelMergeMochiPreview": ModelMergeMochiPreview,
"ModelMergeLTXV": ModelMergeLTXV,
}

View File

@ -82,8 +82,7 @@ class CLIPTextEncodeSD3:
tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"]
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
return ([[cond, {"pooled_output": pooled}]], )
return (clip.encode_from_tokens_scheduled(tokens), )
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):

View File

@ -16,7 +16,8 @@ class SkipLayerGuidanceDiT:
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
"rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance"
@ -26,7 +27,7 @@ class SkipLayerGuidanceDiT:
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers=""):
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
# check if layer is comma separated integers
def skip(args, extra_args):
return args
@ -65,6 +66,11 @@ class SkipLayerGuidanceDiT:
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
cfg_result = cfg_result + (cond_pred - slg) * scale
if rescaling_scale != 0:
factor = cond_pred.std() / cfg_result.std()
factor = rescaling_scale * factor + (1 - rescaling_scale)
cfg_result *= factor
return cfg_result
m = model.clone()

View File

@ -1,4 +1,3 @@
import os
import logging
from spandrel import ModelLoader, ImageModelDescriptor
from comfy import model_management

View File

@ -1,7 +1,5 @@
from PIL import Image, ImageOps
from io import BytesIO
from PIL import Image
import numpy as np
import struct
import comfy.utils
import time

View File

@ -16,7 +16,7 @@ import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy.cli_args import args
from comfy_execution.validation import validate_node_input
class ExecutionResult(Enum):
SUCCESS = 0
@ -480,7 +480,7 @@ class PromptExecutor:
if self.caches.outputs.get(node_id) is not None:
cached_nodes.append(node_id)
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
comfy.model_management.cleanup_models_gc()
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
@ -527,7 +527,6 @@ class PromptExecutor:
comfy.model_management.unload_all_models()
def validate_inputs(prompt, item, validated):
unique_id = item
if unique_id in validated:
@ -589,8 +588,8 @@ def validate_inputs(prompt, item, validated):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
received_type = r[val[1]]
received_types[x] = received_type
if 'input_types' not in validate_function_inputs and received_type != type_input:
details = f"{x}, {received_type} != {type_input}"
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
error = {
"type": "return_type_mismatch",
"message": "Return type mismatch between linked nodes",

View File

@ -5,20 +5,24 @@ import ctypes
import logging
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
lib_folder = os.path.join(folder, "lib")
test_file = os.path.join(lib_folder, "fbgemm.dll")
dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
if os.path.exists(dest):
break
with open(test_file, 'rb') as f:
contents = f.read()
if b"libomp140.x86_64.dll" not in contents:
def fix_pytorch_libomp():
"""
Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.
"""
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
lib_folder = os.path.join(folder, "lib")
test_file = os.path.join(lib_folder, "fbgemm.dll")
dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
if os.path.exists(dest):
break
try:
mydll = ctypes.cdll.LoadLibrary(test_file)
except FileNotFoundError as e:
logging.warning("Detected pytorch version with libomp issue, patching.")
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
with open(test_file, "rb") as f:
contents = f.read()
if b"libomp140.x86_64.dll" not in contents:
break
try:
mydll = ctypes.cdll.LoadLibrary(test_file)
except FileNotFoundError as e:
logging.warning("Detected pytorch version with libomp issue, patching.")
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)

View File

@ -4,7 +4,7 @@ import os
import time
import mimetypes
import logging
from typing import Set, List, Dict, Tuple, Literal
from typing import Literal
from collections.abc import Collection
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
@ -133,7 +133,7 @@ def get_directory_by_type(type_name: str) -> str | None:
return get_input_directory()
return None
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
"""
Example:
files = os.listdir(folder_paths.get_input_directory())

View File

@ -1,7 +1,5 @@
import torch
from PIL import Image
import struct
import numpy as np
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import comfy.model_management

View File

@ -8,6 +8,11 @@ import time
from comfy.cli_args import args
from app.logger import setup_logger
if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
os.environ['DO_NOT_TRACK'] = '1'
setup_logger(log_level=args.verbose)
@ -82,7 +87,8 @@ if __name__ == "__main__":
if args.windows_standalone_build:
try:
import fix_torch
from fix_torch import fix_pytorch_libomp
fix_pytorch_libomp()
except:
pass
@ -154,7 +160,6 @@ def prompt_worker(q, server):
if need_gc:
current_time = time.perf_counter()
if (current_time - last_gc_collect) > gc_collect_interval:
comfy.model_management.cleanup_models()
gc.collect()
comfy.model_management.soft_empty_cache()
last_gc_collect = current_time

View File

@ -1,2 +0,0 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename

View File

@ -1,234 +0,0 @@
#NOTE: This was an experiment and WILL BE REMOVED
from __future__ import annotations
import aiohttp
import os
import traceback
import logging
from folder_paths import folder_names_and_paths, get_folder_paths
import re
from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum
import time
from dataclasses import dataclass
class DownloadStatusType(Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
ERROR = "error"
@dataclass
class DownloadModelStatus():
status: str
progress_percentage: float
message: str
already_existed: bool = False
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
self.status = status.value # Store the string value of the Enum
self.progress_percentage = progress_percentage
self.message = message
self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]:
return {
"status": self.status,
"progress_percentage": self.progress_percentage,
"message": self.message,
"already_existed": self.already_existed
}
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
model_url: str,
model_directory: str,
folder_path: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus:
"""
Download a model file from a given URL into the models directory.
Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str):
The name of the model file to be downloaded. This will be the filename on disk.
model_url (str):
The URL from which to download the model.
model_directory (str):
The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates.
folder_path (str);
Path to which model folder should be used as the root.
Returns:
DownloadModelStatus: The result of the download operation.
"""
if not validate_filename(model_name):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model name",
False
)
if not model_directory in folder_names_and_paths:
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
False
)
if not folder_path in get_folder_paths(model_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
False
)
file_path = create_model_path(model_name, folder_path)
existing_file = await check_file_exists(file_path, model_name, progress_callback)
if existing_file:
return existing_file
try:
logging.info(f"Downloading {model_name} from {model_url}")
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(model_name, status)
response = await model_download_request(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(model_name, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
except Exception as e:
logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback)
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
os.makedirs(folder_path, exist_ok=True)
file_path = os.path.join(folder_path, model_name)
# Ensure the resulting path is still within the base directory
abs_file_path = os.path.abspath(file_path)
abs_base_dir = os.path.abspath(folder_path)
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
return file_path
async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(model_name, status)
return status
return None
async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
interval: float = 1.0) -> DownloadModelStatus:
try:
total_size = int(response.headers.get('Content-Length', 0))
downloaded = 0
last_update_time = time.time()
async def update_progress():
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(model_name, status)
last_update_time = time.time()
temp_file_path = file_path + '.tmp'
with open(temp_file_path, 'wb') as f:
chunk_iterator = response.content.iter_chunked(8192)
while True:
try:
chunk = await chunk_iterator.__anext__()
except StopAsyncIteration:
break
f.write(chunk)
downloaded += len(chunk)
if time.time() - last_update_time >= interval:
await update_progress()
os.rename(temp_file_path, file_path)
await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(model_name, status)
return status
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback)
async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any]
) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(model_name, status)
return status
def validate_filename(filename: str)-> bool:
"""
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
Args:
filename (str): The filename to validate
Returns:
bool: True if the filename is valid, False otherwise
"""
if not filename.lower().endswith(('.sft', '.safetensors')):
return False
# Check if the filename is empty, None, or just whitespace
if not filename or not filename.strip():
return False
# Check for any directory traversal attempts or invalid characters
if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
return False
# Check if the filename starts with a dot (hidden file)
if filename.startswith('.'):
return False
# Use a whitelist of allowed characters
if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename):
return False
# Ensure the filename isn't too long
if len(filename) > 255:
return False
return True

View File

@ -1,3 +1,4 @@
from __future__ import annotations
import torch
import os
@ -10,7 +11,7 @@ import time
import random
import logging
from PIL import Image, ImageOps, ImageSequence, ImageFile
from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
import numpy as np
@ -24,6 +25,7 @@ import comfy.sample
import comfy.sd
import comfy.utils
import comfy.controlnet
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
import comfy.clip_vision
@ -44,16 +46,16 @@ def interrupt_processing(value=True):
MAX_RESOLUTION=16384
class CLIPTextEncode:
class CLIPTextEncode(ComfyNodeABC):
@classmethod
def INPUT_TYPES(s):
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"text": ("STRING", {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
"clip": ("CLIP", {"tooltip": "The CLIP model used for encoding the text."})
"text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
}
}
RETURN_TYPES = ("CONDITIONING",)
RETURN_TYPES = (IO.CONDITIONING,)
OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
FUNCTION = "encode"
@ -62,9 +64,8 @@ class CLIPTextEncode:
def encode(self, clip, text):
tokens = clip.tokenize(text)
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
cond = output.pop("cond")
return ([[cond, output]], )
return (clip.encode_from_tokens_scheduled(tokens), )
class ConditioningCombine:
@classmethod
@ -301,7 +302,8 @@ class VAEDecodeTiled:
def decode(self, vae, samples, tile_size, overlap=64):
if tile_size < overlap * 4:
overlap = tile_size // 4
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
compression = vae.spacial_compression_decode()
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return (images, )
@ -382,6 +384,7 @@ class InpaintModelConditioning:
"vae": ("VAE", ),
"pixels": ("IMAGE", ),
"mask": ("MASK", ),
"noise_mask": ("BOOLEAN", {"default": True, "tooltip": "Add a noise mask to the latent so sampling will only happen within the mask. Might improve results or completely break things depending on the model."}),
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
@ -390,7 +393,7 @@ class InpaintModelConditioning:
CATEGORY = "conditioning/inpaint"
def encode(self, positive, negative, pixels, vae, mask):
def encode(self, positive, negative, pixels, vae, mask, noise_mask=True):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
@ -414,7 +417,8 @@ class InpaintModelConditioning:
out_latent = {}
out_latent["samples"] = orig_latent
out_latent["noise_mask"] = mask
if noise_mask:
out_latent["noise_mask"] = mask
out = []
for conditioning in [positive, negative]:
@ -640,9 +644,7 @@ class LoraLoader:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
else:
temp = self.loaded_lora
self.loaded_lora = None
del temp
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
@ -895,7 +897,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@ -913,6 +915,8 @@ class CLIPLoader:
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
elif type == "mochi":
clip_type = comfy.sd.CLIPType.MOCHI
elif type == "ltxv":
clip_type = comfy.sd.CLIPType.LTXV
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
@ -966,15 +970,19 @@ class CLIPVisionEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"image": ("IMAGE",)
"image": ("IMAGE",),
"crop": (["center", "none"],)
}}
RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
FUNCTION = "encode"
CATEGORY = "conditioning"
def encode(self, clip_vision, image):
output = clip_vision.encode_image(image)
def encode(self, clip_vision, image, crop):
crop_image = True
if crop != "center":
crop_image = False
output = clip_vision.encode_image(image, crop=crop_image)
return (output,)
class StyleModelLoader:
@ -999,14 +1007,19 @@ class StyleModelApply:
return {"required": {"conditioning": ("CONDITIONING", ),
"style_model": ("STYLE_MODEL", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
"strength_type": (["multiply"], ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_stylemodel"
CATEGORY = "conditioning/style_model"
def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength, strength_type):
cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
if strength_type == "multiply":
cond *= strength
c = []
for t in conditioning:
n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
@ -2134,6 +2147,9 @@ def init_builtin_extra_nodes():
"nodes_torch_compile.py",
"nodes_mochi.py",
"nodes_slg.py",
"nodes_mahiro.py",
"nodes_lt.py",
"nodes_hooks.py",
]
import_failed = []

Some files were not shown because too many files have changed in this diff Show More