mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
a0b35e60a3
@ -3,8 +3,8 @@ name: Python Linting
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
pylint:
|
||||
name: Run Pylint
|
||||
ruff:
|
||||
name: Run Ruff
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
@ -16,8 +16,8 @@ jobs:
|
||||
with:
|
||||
python-version: 3.x
|
||||
|
||||
- name: Install Pylint
|
||||
run: pip install pylint
|
||||
- name: Install Ruff
|
||||
run: pip install ruff
|
||||
|
||||
- name: Run Pylint
|
||||
run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")
|
||||
- name: Run Ruff
|
||||
run: ruff check .
|
53
.github/workflows/test-ci.yml
vendored
53
.github/workflows/test-ci.yml
vendored
@ -20,7 +20,8 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos, linux, windows]
|
||||
# os: [macos, linux, windows]
|
||||
os: [macos, linux]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["stable"]
|
||||
@ -31,9 +32,9 @@ jobs:
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
- os: windows
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
# - os: windows
|
||||
# runner_label: [self-hosted, Windows]
|
||||
# flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
- name: Test Workflows
|
||||
@ -45,28 +46,28 @@ jobs:
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
|
||||
test-win-nightly:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
matrix:
|
||||
os: [windows]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: windows
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
- name: Test Workflows
|
||||
uses: comfy-org/comfy-action@main
|
||||
with:
|
||||
os: ${{ matrix.os }}
|
||||
python_version: ${{ matrix.python_version }}
|
||||
torch_version: ${{ matrix.torch_version }}
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
# test-win-nightly:
|
||||
# strategy:
|
||||
# fail-fast: true
|
||||
# matrix:
|
||||
# os: [windows]
|
||||
# python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
# cuda_version: ["12.1"]
|
||||
# torch_version: ["nightly"]
|
||||
# include:
|
||||
# - os: windows
|
||||
# runner_label: [self-hosted, Windows]
|
||||
# flags: ""
|
||||
# runs-on: ${{ matrix.runner_label }}
|
||||
# steps:
|
||||
# - name: Test Workflows
|
||||
# uses: comfy-org/comfy-action@main
|
||||
# with:
|
||||
# os: ${{ matrix.os }}
|
||||
# python_version: ${{ matrix.python_version }}
|
||||
# torch_version: ${{ matrix.torch_version }}
|
||||
# google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
# comfyui_flags: ${{ matrix.flags }}
|
||||
|
||||
test-unix-nightly:
|
||||
strategy:
|
||||
|
2
.github/workflows/test-launch.yml
vendored
2
.github/workflows/test-launch.yml
vendored
@ -28,7 +28,7 @@ jobs:
|
||||
- name: Start ComfyUI server
|
||||
run: |
|
||||
python main.py --cpu 2>&1 | tee console_output.log &
|
||||
wait-for-it --service 127.0.0.1:8188 -t 600
|
||||
wait-for-it --service 127.0.0.1:8188 -t 30
|
||||
working-directory: ComfyUI
|
||||
- name: Check for unhandled exceptions in server log
|
||||
run: |
|
||||
|
23
CODEOWNERS
23
CODEOWNERS
@ -1 +1,22 @@
|
||||
* @comfyanonymous
|
||||
# Admins
|
||||
* @comfyanonymous
|
||||
|
||||
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
|
||||
# Inlined the team members for now.
|
||||
|
||||
# Maintainers
|
||||
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
|
||||
|
||||
# Python web server
|
||||
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||
|
||||
# Frontend assets
|
||||
/web/ @huchenlei @webfiltered @pythongosssss
|
||||
|
||||
# Extra nodes
|
||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
|
||||
|
71
README.md
71
README.md
@ -39,6 +39,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- Asynchronous Queue system
|
||||
@ -74,37 +75,39 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
| Keybind | Explanation |
|
||||
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
|
||||
| Ctrl + Enter | Queue up current graph for generation |
|
||||
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||
| Ctrl + Alt + Enter | Cancel current generation |
|
||||
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
||||
| Ctrl + S | Save workflow |
|
||||
| Ctrl + O | Load workflow |
|
||||
| Ctrl + A | Select all nodes |
|
||||
| Alt + C | Collapse/uncollapse selected nodes |
|
||||
| Ctrl + M | Mute/unmute selected nodes |
|
||||
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
||||
| Delete/Backspace | Delete selected nodes |
|
||||
| Ctrl + Backspace | Delete the current graph |
|
||||
| Space | Move the canvas around when held and moving the cursor |
|
||||
| Ctrl/Shift + Click | Add clicked node to selection |
|
||||
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
||||
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
||||
| Shift + Drag | Move multiple selected nodes at the same time |
|
||||
| Ctrl + D | Load default graph |
|
||||
| Alt + `+` | Canvas Zoom in |
|
||||
| Alt + `-` | Canvas Zoom out |
|
||||
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
||||
| P | Pin/Unpin selected nodes |
|
||||
| Ctrl + G | Group selected nodes |
|
||||
| Q | Toggle visibility of the queue |
|
||||
| H | Toggle visibility of history |
|
||||
| R | Refresh graph |
|
||||
| `Ctrl` + `Enter` | Queue up current graph for generation |
|
||||
| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
|
||||
| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
|
||||
| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
|
||||
| `Ctrl` + `S` | Save workflow |
|
||||
| `Ctrl` + `O` | Load workflow |
|
||||
| `Ctrl` + `A` | Select all nodes |
|
||||
| `Alt `+ `C` | Collapse/uncollapse selected nodes |
|
||||
| `Ctrl` + `M` | Mute/unmute selected nodes |
|
||||
| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
||||
| `Delete`/`Backspace` | Delete selected nodes |
|
||||
| `Ctrl` + `Backspace` | Delete the current graph |
|
||||
| `Space` | Move the canvas around when held and moving the cursor |
|
||||
| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
|
||||
| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
||||
| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
||||
| `Shift` + `Drag` | Move multiple selected nodes at the same time |
|
||||
| `Ctrl` + `D` | Load default graph |
|
||||
| `Alt` + `+` | Canvas Zoom in |
|
||||
| `Alt` + `-` | Canvas Zoom out |
|
||||
| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
|
||||
| `P` | Pin/Unpin selected nodes |
|
||||
| `Ctrl` + `G` | Group selected nodes |
|
||||
| `Q` | Toggle visibility of the queue |
|
||||
| `H` | Toggle visibility of history |
|
||||
| `R` | Refresh graph |
|
||||
| `F` | Show/Hide menu |
|
||||
| `.` | Fit view to selection (Whole graph when nothing is selected) |
|
||||
| Double-Click LMB | Open node quick search palette |
|
||||
| Shift + Drag | Move multiple wires at once |
|
||||
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
|
||||
| `Shift` + Drag | Move multiple wires at once |
|
||||
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
|
||||
|
||||
Ctrl can also be replaced with Cmd instead for macOS users
|
||||
`Ctrl` can also be replaced with `Cmd` instead for macOS users
|
||||
|
||||
# Installing
|
||||
|
||||
@ -140,7 +143,7 @@ Put your VAE in: models/vae
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
|
||||
|
||||
@ -212,6 +215,14 @@ For 6700, 6600 and maybe other RDNA2 or older: ```HSA_OVERRIDE_GFX_VERSION=10.3.
|
||||
|
||||
For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py```
|
||||
|
||||
### AMD ROCm Tips
|
||||
|
||||
You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
|
||||
|
||||
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
|
||||
|
||||
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
|
||||
|
||||
# Notes
|
||||
|
||||
Only parts of the graph that have an output with all the correct inputs will be executed.
|
||||
|
@ -10,7 +10,6 @@ class InternalRoutes:
|
||||
The top level web router for internal routes: /internal/*
|
||||
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
|
||||
Check README.md for more information.
|
||||
|
||||
'''
|
||||
|
||||
def __init__(self, prompt_server):
|
||||
|
@ -1,5 +1,6 @@
|
||||
from app.logger import on_flush
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class TerminalService:
|
||||
@ -10,15 +11,27 @@ class TerminalService:
|
||||
self.subscriptions = set()
|
||||
on_flush(self.send_messages)
|
||||
|
||||
def get_terminal_size(self):
|
||||
try:
|
||||
size = os.get_terminal_size()
|
||||
return (size.columns, size.lines)
|
||||
except OSError:
|
||||
try:
|
||||
size = shutil.get_terminal_size()
|
||||
return (size.columns, size.lines)
|
||||
except OSError:
|
||||
return (80, 24) # fallback to 80x24
|
||||
|
||||
def update_size(self):
|
||||
sz = os.get_terminal_size()
|
||||
columns, lines = self.get_terminal_size()
|
||||
changed = False
|
||||
if sz.columns != self.cols:
|
||||
self.cols = sz.columns
|
||||
|
||||
if columns != self.cols:
|
||||
self.cols = columns
|
||||
changed = True
|
||||
|
||||
if sz.lines != self.rows:
|
||||
self.rows = sz.lines
|
||||
if lines != self.rows:
|
||||
self.rows = lines
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
|
167
app/model_manager.py
Normal file
167
app/model_manager.py
Normal file
@ -0,0 +1,167 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import folder_paths
|
||||
import glob
|
||||
from aiohttp import web
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
|
||||
|
||||
|
||||
class ModelFileManager:
|
||||
def __init__(self) -> None:
|
||||
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
||||
|
||||
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
||||
return self.cache.get(key, default)
|
||||
|
||||
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
|
||||
self.cache[key] = value
|
||||
|
||||
def clear_cache(self):
|
||||
self.cache.clear()
|
||||
|
||||
def add_routes(self, routes):
|
||||
# NOTE: This is an experiment to replace `/models`
|
||||
@routes.get("/experiment/models")
|
||||
async def get_model_folders(request):
|
||||
model_types = list(folder_paths.folder_names_and_paths.keys())
|
||||
folder_black_list = ["configs", "custom_nodes"]
|
||||
output_folders: list[dict] = []
|
||||
for folder in model_types:
|
||||
if folder in folder_black_list:
|
||||
continue
|
||||
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
|
||||
return web.json_response(output_folders)
|
||||
|
||||
# NOTE: This is an experiment to replace `/models/{folder}`
|
||||
@routes.get("/experiment/models/{folder}")
|
||||
async def get_all_models(request):
|
||||
folder = request.match_info.get("folder", None)
|
||||
if not folder in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
files = self.get_model_file_list(folder)
|
||||
return web.json_response(files)
|
||||
|
||||
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
||||
async def get_model_preview(request):
|
||||
folder_name = request.match_info.get("folder", None)
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
filename = request.match_info.get("filename", None)
|
||||
|
||||
if not folder_name in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
folder = folders[0][path_index]
|
||||
full_filename = os.path.join(folder, filename)
|
||||
|
||||
preview_files = self.get_model_previews(full_filename)
|
||||
default_preview_file = preview_files[0] if len(preview_files) > 0 else None
|
||||
if default_preview_file is None or not os.path.isfile(default_preview_file):
|
||||
return web.Response(status=404)
|
||||
|
||||
try:
|
||||
with Image.open(default_preview_file) as img:
|
||||
img_bytes = BytesIO()
|
||||
img.save(img_bytes, format="WEBP")
|
||||
img_bytes.seek(0)
|
||||
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
|
||||
except:
|
||||
return web.Response(status=404)
|
||||
|
||||
def get_model_file_list(self, folder_name: str):
|
||||
folder_name = map_legacy(folder_name)
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
output_list: list[dict] = []
|
||||
|
||||
for index, folder in enumerate(folders[0]):
|
||||
if not os.path.isdir(folder):
|
||||
continue
|
||||
out = self.cache_model_file_list_(folder)
|
||||
if out is None:
|
||||
out = self.recursive_search_models_(folder, index)
|
||||
self.set_cache(folder, out)
|
||||
output_list.extend(out[0])
|
||||
|
||||
return output_list
|
||||
|
||||
def cache_model_file_list_(self, folder: str):
|
||||
model_file_list_cache = self.get_cache(folder)
|
||||
|
||||
if model_file_list_cache is None:
|
||||
return None
|
||||
if not os.path.isdir(folder):
|
||||
return None
|
||||
if os.path.getmtime(folder) != model_file_list_cache[1]:
|
||||
return None
|
||||
for x in model_file_list_cache[1]:
|
||||
time_modified = model_file_list_cache[1][x]
|
||||
folder = x
|
||||
if os.path.getmtime(folder) != time_modified:
|
||||
return None
|
||||
|
||||
return model_file_list_cache
|
||||
|
||||
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
|
||||
if not os.path.isdir(directory):
|
||||
return [], {}, time.perf_counter()
|
||||
|
||||
excluded_dir_names = [".git"]
|
||||
# TODO use settings
|
||||
include_hidden_files = False
|
||||
|
||||
result: list[str] = []
|
||||
dirs: dict[str, float] = {}
|
||||
|
||||
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
||||
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
||||
if not include_hidden_files:
|
||||
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
|
||||
filenames = [f for f in filenames if not f.startswith(".")]
|
||||
|
||||
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
|
||||
|
||||
for file_name in filenames:
|
||||
try:
|
||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||
result.append(relative_path)
|
||||
except:
|
||||
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||
continue
|
||||
|
||||
for d in subdirs:
|
||||
path: str = os.path.join(dirpath, d)
|
||||
try:
|
||||
dirs[path] = os.path.getmtime(path)
|
||||
except FileNotFoundError:
|
||||
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
|
||||
continue
|
||||
|
||||
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
|
||||
|
||||
def get_model_previews(self, filepath: str) -> list[str]:
|
||||
dirname = os.path.dirname(filepath)
|
||||
|
||||
if not os.path.exists(dirname):
|
||||
return []
|
||||
|
||||
basename = os.path.splitext(filepath)[0]
|
||||
match_files = glob.glob(f"{basename}.*", recursive=False)
|
||||
image_files = filter_files_content_types(match_files, "image")
|
||||
|
||||
result: list[str] = []
|
||||
|
||||
for filename in image_files:
|
||||
_basename = os.path.splitext(filename)[0]
|
||||
if _basename == basename:
|
||||
result.append(filename)
|
||||
if _basename == f"{basename}.preview":
|
||||
result.append(filename)
|
||||
return result
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.clear_cache()
|
@ -36,7 +36,7 @@ class UserManager():
|
||||
|
||||
self.settings = AppSettings(self)
|
||||
if not os.path.exists(user_directory):
|
||||
os.mkdir(user_directory)
|
||||
os.makedirs(user_directory, exist_ok=True)
|
||||
if not args.multi_user:
|
||||
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||
|
@ -2,11 +2,9 @@
|
||||
#and modified
|
||||
|
||||
import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from ..ldm.modules.diffusionmodules.util import (
|
||||
zero_module,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
|
120
comfy/cldm/dit_embedder.py
Normal file
120
comfy/cldm/dit_embedder.py
Normal file
@ -0,0 +1,120 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
|
||||
|
||||
|
||||
class ControlNetEmbedder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
attention_head_dim: int,
|
||||
num_attention_heads: int,
|
||||
adm_in_channels: int,
|
||||
num_layers: int,
|
||||
main_model_double: int,
|
||||
double_y_emb: bool,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
pos_embed_max_size: Optional[int] = None,
|
||||
operations = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.main_model_double = main_model_double
|
||||
self.dtype = dtype
|
||||
self.hidden_size = num_attention_heads * attention_head_dim
|
||||
self.patch_size = patch_size
|
||||
self.x_embedder = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=self.hidden_size,
|
||||
strict_img_size=pos_embed_max_size is None,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.double_y_emb = double_y_emb
|
||||
if self.double_y_emb:
|
||||
self.orig_y_embedder = VectorEmbedder(
|
||||
adm_in_channels, self.hidden_size, dtype, device, operations=operations
|
||||
)
|
||||
self.y_embedder = VectorEmbedder(
|
||||
self.hidden_size, self.hidden_size, dtype, device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.y_embedder = VectorEmbedder(
|
||||
adm_in_channels, self.hidden_size, dtype, device, operations=operations
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
DismantledBlock(
|
||||
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
)
|
||||
|
||||
# self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
|
||||
# TODO double check this logic when 8b
|
||||
self.use_y_embedder = True
|
||||
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
|
||||
self.pos_embed_input = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=self.hidden_size,
|
||||
strict_img_size=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
hint = None,
|
||||
) -> Tuple[Tensor, List[Tensor]]:
|
||||
x_shape = list(x.shape)
|
||||
x = self.x_embedder(x)
|
||||
if not self.double_y_emb:
|
||||
h = (x_shape[-2] + 1) // self.patch_size
|
||||
w = (x_shape[-1] + 1) // self.patch_size
|
||||
x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
|
||||
c = self.t_embedder(timesteps, dtype=x.dtype)
|
||||
if y is not None and self.y_embedder is not None:
|
||||
if self.double_y_emb:
|
||||
y = self.orig_y_embedder(y)
|
||||
y = self.y_embedder(y)
|
||||
c = c + y
|
||||
|
||||
x = x + self.pos_embed_input(hint)
|
||||
|
||||
block_out = ()
|
||||
|
||||
repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
|
||||
for i in range(len(self.transformer_blocks)):
|
||||
out = self.transformer_blocks[i](x, c)
|
||||
if not self.double_y_emb:
|
||||
x = out
|
||||
block_out += (self.controlnet_blocks[i](out),) * repeat
|
||||
|
||||
return {"output": block_out}
|
@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
|
||||
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||
|
@ -60,8 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
|
||||
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
|
||||
|
||||
fpunet_group = parser.add_mutually_exclusive_group()
|
||||
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
|
||||
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
|
||||
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
|
||||
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
|
||||
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
|
||||
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
|
||||
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
|
||||
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
|
||||
|
||||
|
@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
|
||||
|
||||
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
||||
"gelu": torch.nn.functional.gelu,
|
||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||
}
|
||||
|
||||
class CLIPMLP(torch.nn.Module):
|
||||
@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
|
||||
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
if model_type == "siglip_vision_model":
|
||||
self.class_embedding = None
|
||||
patch_bias = True
|
||||
else:
|
||||
num_patches = num_patches + 1
|
||||
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
||||
patch_bias = False
|
||||
|
||||
self.patch_embedding = operations.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False,
|
||||
bias=patch_bias,
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
num_positions = num_patches + 1
|
||||
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
|
||||
return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
|
||||
if self.class_embedding is not None:
|
||||
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
|
||||
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
|
||||
|
||||
|
||||
class CLIPVision(torch.nn.Module):
|
||||
@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.pre_layrnorm = lambda a: a
|
||||
self.output_layernorm = True
|
||||
else:
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.output_layernorm = False
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.post_layernorm = operations.LayerNorm(embed_dim)
|
||||
|
||||
@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
|
||||
x = self.pre_layrnorm(x)
|
||||
#TODO: attention_mask?
|
||||
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
if self.output_layernorm:
|
||||
x = self.post_layernorm(x)
|
||||
pooled_output = x
|
||||
else:
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
return x, i, pooled_output
|
||||
|
||||
class CLIPVisionModelProjection(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
|
||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||
if "projection_dim" in config_dict:
|
||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||
else:
|
||||
self.visual_projection = lambda a: a
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.vision_model(*args, **kwargs)
|
||||
|
@ -16,13 +16,18 @@ class Output:
|
||||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224):
|
||||
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
@ -35,6 +40,8 @@ class ClipVisionModel():
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
@ -49,9 +56,9 @@ class ClipVisionModel():
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def encode_image(self, image):
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
|
||||
outputs = Output()
|
||||
@ -94,7 +101,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
|
13
comfy/clip_vision_siglip_384.json
Normal file
13
comfy/clip_vision_siglip_384.json
Normal file
@ -0,0 +1,13 @@
|
||||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": 384,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 14,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
43
comfy/comfy_types/README.md
Normal file
43
comfy/comfy_types/README.md
Normal file
@ -0,0 +1,43 @@
|
||||
# Comfy Typing
|
||||
## Type hinting for ComfyUI Node development
|
||||
|
||||
This module provides type hinting and concrete convenience types for node developers.
|
||||
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
|
||||
|
||||
```python
|
||||
from comfy_types import IO, ComfyNodeABC, CheckLazyMixin
|
||||
|
||||
class ExampleNode(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {"required": {}}
|
||||
```
|
||||
|
||||
Full example is in [examples/example_nodes.py](examples/example_nodes.py).
|
||||
|
||||
# Types
|
||||
A few primary types are documented below. More complete information is available via the docstrings on each type.
|
||||
|
||||
## `IO`
|
||||
|
||||
A string enum of built-in and a few custom data types. Includes the following special types and their requisite plumbing:
|
||||
|
||||
- `ANY`: `"*"`
|
||||
- `NUMBER`: `"FLOAT,INT"`
|
||||
- `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"`
|
||||
|
||||
## `ComfyNodeABC`
|
||||
|
||||
An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings.
|
||||
|
||||
### Type hinting for `INPUT_TYPES`
|
||||
|
||||
![INPUT_TYPES auto-completion in Visual Studio Code](examples/input_types.png)
|
||||
|
||||
### `INPUT_TYPES` return dict
|
||||
|
||||
![INPUT_TYPES return value type hinting in Visual Studio Code](examples/required_hint.png)
|
||||
|
||||
### Options for individual inputs
|
||||
|
||||
![INPUT_TYPES return value option auto-completion in Visual Studio Code](examples/input_options.png)
|
@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from typing import Callable, Protocol, TypedDict, Optional, List
|
||||
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
|
||||
|
||||
|
||||
class UnetApplyFunction(Protocol):
|
||||
@ -30,3 +31,15 @@ class UnetParams(TypedDict):
|
||||
|
||||
|
||||
UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"UnetWrapperFunction",
|
||||
UnetApplyConds.__name__,
|
||||
UnetParams.__name__,
|
||||
UnetApplyFunction.__name__,
|
||||
IO.__name__,
|
||||
InputTypeDict.__name__,
|
||||
ComfyNodeABC.__name__,
|
||||
CheckLazyMixin.__name__,
|
||||
]
|
28
comfy/comfy_types/examples/example_nodes.py
Normal file
28
comfy/comfy_types/examples/example_nodes.py
Normal file
@ -0,0 +1,28 @@
|
||||
from comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
from inspect import cleandoc
|
||||
|
||||
|
||||
class ExampleNode(ComfyNodeABC):
|
||||
"""An example node that just adds 1 to an input integer.
|
||||
|
||||
* Requires an IDE configured with analysis paths etc to be worth looking at.
|
||||
* Not intended for use in ComfyUI.
|
||||
"""
|
||||
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
CATEGORY = "examples"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"input_int": (IO.INT, {"defaultInput": True}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.INT,)
|
||||
RETURN_NAMES = ("input_plus_one",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, input_int: int):
|
||||
return (input_int + 1,)
|
BIN
comfy/comfy_types/examples/input_options.png
Normal file
BIN
comfy/comfy_types/examples/input_options.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 19 KiB |
BIN
comfy/comfy_types/examples/input_types.png
Normal file
BIN
comfy/comfy_types/examples/input_types.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 16 KiB |
BIN
comfy/comfy_types/examples/required_hint.png
Normal file
BIN
comfy/comfy_types/examples/required_hint.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 19 KiB |
274
comfy/comfy_types/node_typing.py
Normal file
274
comfy/comfy_types/node_typing.py
Normal file
@ -0,0 +1,274 @@
|
||||
"""Comfy-specific type hinting"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Literal, TypedDict
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
"""Base class for string enums. Python's StrEnum is not available until 3.11."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class IO(StrEnum):
|
||||
"""Node input/output data types.
|
||||
|
||||
Includes functionality for ``"*"`` (`ANY`) and ``"MULTI,TYPES"``.
|
||||
"""
|
||||
|
||||
STRING = "STRING"
|
||||
IMAGE = "IMAGE"
|
||||
MASK = "MASK"
|
||||
LATENT = "LATENT"
|
||||
BOOLEAN = "BOOLEAN"
|
||||
INT = "INT"
|
||||
FLOAT = "FLOAT"
|
||||
CONDITIONING = "CONDITIONING"
|
||||
SAMPLER = "SAMPLER"
|
||||
SIGMAS = "SIGMAS"
|
||||
GUIDER = "GUIDER"
|
||||
NOISE = "NOISE"
|
||||
CLIP = "CLIP"
|
||||
CONTROL_NET = "CONTROL_NET"
|
||||
VAE = "VAE"
|
||||
MODEL = "MODEL"
|
||||
CLIP_VISION = "CLIP_VISION"
|
||||
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
|
||||
STYLE_MODEL = "STYLE_MODEL"
|
||||
GLIGEN = "GLIGEN"
|
||||
UPSCALE_MODEL = "UPSCALE_MODEL"
|
||||
AUDIO = "AUDIO"
|
||||
WEBCAM = "WEBCAM"
|
||||
POINT = "POINT"
|
||||
FACE_ANALYSIS = "FACE_ANALYSIS"
|
||||
BBOX = "BBOX"
|
||||
SEGS = "SEGS"
|
||||
|
||||
ANY = "*"
|
||||
"""Always matches any type, but at a price.
|
||||
|
||||
Causes some functionality issues (e.g. reroutes, link types), and should be avoided whenever possible.
|
||||
"""
|
||||
NUMBER = "FLOAT,INT"
|
||||
"""A float or an int - could be either"""
|
||||
PRIMITIVE = "STRING,FLOAT,INT,BOOLEAN"
|
||||
"""Could be any of: string, float, int, or bool"""
|
||||
|
||||
def __ne__(self, value: object) -> bool:
|
||||
if self == "*" or value == "*":
|
||||
return False
|
||||
if not isinstance(value, str):
|
||||
return True
|
||||
a = frozenset(self.split(","))
|
||||
b = frozenset(value.split(","))
|
||||
return not (b.issubset(a) or a.issubset(b))
|
||||
|
||||
|
||||
class InputTypeOptions(TypedDict):
|
||||
"""Provides type hinting for the return type of the INPUT_TYPES node function.
|
||||
|
||||
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
|
||||
"""
|
||||
|
||||
default: bool | str | float | int | list | tuple
|
||||
"""The default value of the widget"""
|
||||
defaultInput: bool
|
||||
"""Defaults to an input slot rather than a widget"""
|
||||
forceInput: bool
|
||||
"""`defaultInput` and also don't allow converting to a widget"""
|
||||
lazy: bool
|
||||
"""Declares that this input uses lazy evaluation"""
|
||||
rawLink: bool
|
||||
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
|
||||
tooltip: str
|
||||
"""Tooltip for the input (or widget), shown on pointer hover"""
|
||||
# class InputTypeNumber(InputTypeOptions):
|
||||
# default: float | int
|
||||
min: float
|
||||
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
|
||||
max: float
|
||||
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
|
||||
step: float
|
||||
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
|
||||
round: float
|
||||
"""Floats are rounded by this value (``FLOAT``)"""
|
||||
# class InputTypeBoolean(InputTypeOptions):
|
||||
# default: bool
|
||||
label_on: str
|
||||
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
|
||||
label_on: str
|
||||
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
|
||||
# class InputTypeString(InputTypeOptions):
|
||||
# default: str
|
||||
multiline: bool
|
||||
"""Use a multiline text box (``STRING``)"""
|
||||
placeholder: str
|
||||
"""Placeholder text to display in the UI when empty (``STRING``)"""
|
||||
# Deprecated:
|
||||
# defaultVal: str
|
||||
dynamicPrompts: bool
|
||||
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
|
||||
|
||||
node_id: Literal["UNIQUE_ID"]
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
unique_id: Literal["UNIQUE_ID"]
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
prompt: Literal["PROMPT"]
|
||||
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
|
||||
extra_pnginfo: Literal["EXTRA_PNGINFO"]
|
||||
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
|
||||
dynprompt: Literal["DYNPROMPT"]
|
||||
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
|
||||
|
||||
|
||||
class InputTypeDict(TypedDict):
|
||||
"""Provides type hinting for node INPUT_TYPES.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
|
||||
"""
|
||||
|
||||
required: dict[str, tuple[IO, InputTypeOptions]]
|
||||
"""Describes all inputs that must be connected for the node to execute."""
|
||||
optional: dict[str, tuple[IO, InputTypeOptions]]
|
||||
"""Describes inputs which do not need to be connected."""
|
||||
hidden: HiddenInputTypeDict
|
||||
"""Offers advanced functionality and server-client communication.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
|
||||
"""
|
||||
|
||||
|
||||
class ComfyNodeABC(ABC):
|
||||
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
|
||||
"""
|
||||
|
||||
DESCRIPTION: str
|
||||
"""Node description, shown as a tooltip when hovering over the node.
|
||||
|
||||
Usage::
|
||||
|
||||
# Explicitly define the description
|
||||
DESCRIPTION = "Example description here."
|
||||
|
||||
# Use the docstring of the node class.
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
"""
|
||||
CATEGORY: str
|
||||
"""The category of the node, as per the "Add Node" menu.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
|
||||
"""
|
||||
EXPERIMENTAL: bool
|
||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||
DEPRECATED: bool
|
||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
"""Defines node inputs.
|
||||
|
||||
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
|
||||
* The ``optional`` key can be added to describe inputs which do not need to be connected.
|
||||
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
|
||||
"""
|
||||
return {"required": {}}
|
||||
|
||||
OUTPUT_NODE: bool
|
||||
"""Flags this node as an output node, causing any inputs it requires to be executed.
|
||||
|
||||
If a node is not connected to any output nodes, that node will not be executed. Usage::
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
From the docs:
|
||||
|
||||
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
|
||||
"""
|
||||
INPUT_IS_LIST: bool
|
||||
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
|
||||
|
||||
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
|
||||
|
||||
From the docs:
|
||||
|
||||
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
|
||||
"""
|
||||
OUTPUT_IS_LIST: tuple[bool]
|
||||
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
|
||||
|
||||
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
|
||||
|
||||
A ``tuple[bool]``, where the items match those in `RETURN_TYPES`::
|
||||
|
||||
RETURN_TYPES = (IO.INT, IO.INT, IO.STRING)
|
||||
OUTPUT_IS_LIST = (True, True, False) # The string output will be handled normally
|
||||
|
||||
From the docs:
|
||||
|
||||
In order to tell Comfy that the list being returned should not be wrapped, but treated as a series of data for sequential processing,
|
||||
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
|
||||
specifying which outputs which should be so treated.
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
|
||||
"""
|
||||
|
||||
RETURN_TYPES: tuple[IO]
|
||||
"""A tuple representing the outputs of this node.
|
||||
|
||||
Usage::
|
||||
|
||||
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
|
||||
"""
|
||||
RETURN_NAMES: tuple[str]
|
||||
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
|
||||
"""
|
||||
OUTPUT_TOOLTIPS: tuple[str]
|
||||
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
|
||||
FUNCTION: str
|
||||
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
|
||||
"""
|
||||
|
||||
|
||||
class CheckLazyMixin:
|
||||
"""Provides a basic check_lazy_status implementation and type hinting for nodes that use lazy inputs."""
|
||||
|
||||
def check_lazy_status(self, **kwargs) -> list[str]:
|
||||
"""Returns a list of input names that should be evaluated.
|
||||
|
||||
This basic mixin impl. requires all inputs.
|
||||
|
||||
:kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \
|
||||
When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``.
|
||||
|
||||
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
|
||||
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
|
||||
"""
|
||||
|
||||
need = [name for name in kwargs if kwargs[name] is None]
|
||||
return need
|
@ -35,6 +35,10 @@ import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.hooks import HookGroup
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
@ -78,6 +82,8 @@ class ControlBase:
|
||||
self.concat_mask = False
|
||||
self.extra_concat_orig = []
|
||||
self.extra_concat = None
|
||||
self.extra_hooks: HookGroup = None
|
||||
self.preprocess_image = lambda a: a
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||
self.cond_hint_original = cond_hint
|
||||
@ -114,6 +120,14 @@ class ControlBase:
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
def get_extra_hooks(self):
|
||||
out = []
|
||||
if self.extra_hooks is not None:
|
||||
out.append(self.extra_hooks)
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_extra_hooks()
|
||||
return out
|
||||
|
||||
def copy_to(self, c):
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
@ -129,6 +143,8 @@ class ControlBase:
|
||||
c.strength_type = self.strength_type
|
||||
c.concat_mask = self.concat_mask
|
||||
c.extra_concat_orig = self.extra_concat_orig.copy()
|
||||
c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
|
||||
c.preprocess_image = self.preprocess_image
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
if self.previous_controlnet is not None:
|
||||
@ -181,7 +197,7 @@ class ControlBase:
|
||||
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
|
||||
super().__init__()
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
@ -196,11 +212,12 @@ class ControlNet(ControlBase):
|
||||
self.extra_conds += extra_conds
|
||||
self.strength_type = strength_type
|
||||
self.concat_mask = concat_mask
|
||||
self.preprocess_image = preprocess_image
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
control_prev = None
|
||||
if self.previous_controlnet is not None:
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
|
||||
if self.timestep_range is not None:
|
||||
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
||||
@ -224,6 +241,7 @@ class ControlNet(ControlBase):
|
||||
if self.latent_format is not None:
|
||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = self.preprocess_image(self.cond_hint)
|
||||
if self.vae is not None:
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
|
||||
@ -427,6 +445,7 @@ def controlnet_load_state_dict(control_model, sd):
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
return control_model
|
||||
|
||||
|
||||
def load_controlnet_mmdit(sd, model_options={}):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||
@ -448,6 +467,82 @@ def load_controlnet_mmdit(sd, model_options={}):
|
||||
return control
|
||||
|
||||
|
||||
class ControlNetSD35(ControlNet):
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
if self.control_model.double_y_emb:
|
||||
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
|
||||
else:
|
||||
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
|
||||
def copy(self):
|
||||
c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
c.control_model_wrapped = self.control_model_wrapped
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def load_controlnet_sd35(sd, model_options={}):
|
||||
control_type = -1
|
||||
if "control_type" in sd:
|
||||
control_type = round(sd.pop("control_type").item())
|
||||
|
||||
# blur_cnet = control_type == 0
|
||||
canny_cnet = control_type == 1
|
||||
depth_cnet = control_type == 2
|
||||
|
||||
new_sd = {}
|
||||
for k in comfy.utils.MMDIT_MAP_BASIC:
|
||||
if k[1] in sd:
|
||||
new_sd[k[0]] = sd.pop(k[1])
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
sd = new_sd
|
||||
|
||||
y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
|
||||
depth = y_emb_shape[0] // 64
|
||||
hidden_size = 64 * depth
|
||||
num_heads = depth
|
||||
head_dim = hidden_size // num_heads
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
offload_device = comfy.model_management.unet_offload_device()
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
|
||||
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
|
||||
patch_size=2,
|
||||
in_chans=16,
|
||||
num_layers=num_blocks,
|
||||
main_model_double=depth,
|
||||
double_y_emb=y_emb_shape[0] == y_emb_shape[1],
|
||||
attention_head_dim=head_dim,
|
||||
num_attention_heads=num_heads,
|
||||
adm_in_channels=2048,
|
||||
device=offload_device,
|
||||
dtype=unet_dtype,
|
||||
operations=operations)
|
||||
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
|
||||
latent_format = comfy.latent_formats.SD3()
|
||||
preprocess_image = lambda a: a
|
||||
if canny_cnet:
|
||||
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
|
||||
elif depth_cnet:
|
||||
preprocess_image = lambda a: 1.0 - a
|
||||
|
||||
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
|
||||
return control
|
||||
|
||||
|
||||
|
||||
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
|
||||
|
||||
@ -560,7 +655,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
|
||||
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||
if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
|
||||
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
||||
else:
|
||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
@ -674,10 +772,10 @@ class T2IAdapter(ControlBase):
|
||||
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
|
||||
return width, height
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
control_prev = None
|
||||
if self.previous_controlnet is not None:
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
|
||||
if self.timestep_range is not None:
|
||||
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
||||
|
@ -1,10 +1,9 @@
|
||||
#code taken from: https://github.com/wl-zhao/UniPC and modified
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
from tqdm.auto import trange, tqdm
|
||||
from tqdm.auto import trange
|
||||
|
||||
|
||||
class NoiseScheduleVP:
|
||||
|
690
comfy/hooks.py
Normal file
690
comfy/hooks.py
Normal file
@ -0,0 +1,690 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
import enum
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher, PatcherInjection
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.sd import CLIP
|
||||
import comfy.lora
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from node_helpers import conditioning_set_values
|
||||
|
||||
class EnumHookMode(enum.Enum):
|
||||
MinVram = "minvram"
|
||||
MaxSpeed = "maxspeed"
|
||||
|
||||
class EnumHookType(enum.Enum):
|
||||
Weight = "weight"
|
||||
Patch = "patch"
|
||||
ObjectPatch = "object_patch"
|
||||
AddModels = "add_models"
|
||||
Callbacks = "callbacks"
|
||||
Wrappers = "wrappers"
|
||||
SetInjections = "add_injections"
|
||||
|
||||
class EnumWeightTarget(enum.Enum):
|
||||
Model = "model"
|
||||
Clip = "clip"
|
||||
|
||||
class _HookRef:
|
||||
pass
|
||||
|
||||
# NOTE: this is an example of how the should_register function should look
|
||||
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
return True
|
||||
|
||||
|
||||
class Hook:
|
||||
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
||||
hook_keyframe: 'HookKeyframeGroup'=None):
|
||||
self.hook_type = hook_type
|
||||
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
||||
self.hook_id = hook_id
|
||||
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
||||
self.custom_should_register = default_should_register
|
||||
self.auto_apply_to_nonpositive = False
|
||||
|
||||
@property
|
||||
def strength(self):
|
||||
return self.hook_keyframe.strength
|
||||
|
||||
def initialize_timesteps(self, model: 'BaseModel'):
|
||||
self.reset()
|
||||
self.hook_keyframe.initialize_timesteps(model)
|
||||
|
||||
def reset(self):
|
||||
self.hook_keyframe.reset()
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: Hook = subtype()
|
||||
c.hook_type = self.hook_type
|
||||
c.hook_ref = self.hook_ref
|
||||
c.hook_id = self.hook_id
|
||||
c.hook_keyframe = self.hook_keyframe
|
||||
c.custom_should_register = self.custom_should_register
|
||||
# TODO: make this do something
|
||||
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
|
||||
return c
|
||||
|
||||
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
return self.custom_should_register(self, model, model_options, target, registered)
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
||||
|
||||
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def __eq__(self, other: 'Hook'):
|
||||
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.hook_ref)
|
||||
|
||||
class WeightHook(Hook):
|
||||
def __init__(self, strength_model=1.0, strength_clip=1.0):
|
||||
super().__init__(hook_type=EnumHookType.Weight)
|
||||
self.weights: dict = None
|
||||
self.weights_clip: dict = None
|
||||
self.need_weight_init = True
|
||||
self._strength_model = strength_model
|
||||
self._strength_clip = strength_clip
|
||||
|
||||
@property
|
||||
def strength_model(self):
|
||||
return self._strength_model * self.strength
|
||||
|
||||
@property
|
||||
def strength_clip(self):
|
||||
return self._strength_clip * self.strength
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
if not self.should_register(model, model_options, target, registered):
|
||||
return False
|
||||
weights = None
|
||||
if target == EnumWeightTarget.Model:
|
||||
strength = self._strength_model
|
||||
else:
|
||||
strength = self._strength_clip
|
||||
|
||||
if self.need_weight_init:
|
||||
key_map = {}
|
||||
if target == EnumWeightTarget.Model:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
else:
|
||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||
else:
|
||||
if target == EnumWeightTarget.Model:
|
||||
weights = self.weights
|
||||
else:
|
||||
weights = self.weights_clip
|
||||
k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
||||
registered.append(self)
|
||||
return True
|
||||
# TODO: add logs about any keys that were not applied
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: WeightHook = super().clone(subtype)
|
||||
c.weights = self.weights
|
||||
c.weights_clip = self.weights_clip
|
||||
c.need_weight_init = self.need_weight_init
|
||||
c._strength_model = self._strength_model
|
||||
c._strength_clip = self._strength_clip
|
||||
return c
|
||||
|
||||
class PatchHook(Hook):
|
||||
def __init__(self):
|
||||
super().__init__(hook_type=EnumHookType.Patch)
|
||||
self.patches: dict = None
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: PatchHook = super().clone(subtype)
|
||||
c.patches = self.patches
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class ObjectPatchHook(Hook):
|
||||
def __init__(self):
|
||||
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
||||
self.object_patches: dict = None
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: ObjectPatchHook = super().clone(subtype)
|
||||
c.object_patches = self.object_patches
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class AddModelsHook(Hook):
|
||||
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
|
||||
super().__init__(hook_type=EnumHookType.AddModels)
|
||||
self.key = key
|
||||
self.models = models
|
||||
self.append_when_same = True
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: AddModelsHook = super().clone(subtype)
|
||||
c.key = self.key
|
||||
c.models = self.models.copy() if self.models else self.models
|
||||
c.append_when_same = self.append_when_same
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class CallbackHook(Hook):
|
||||
def __init__(self, key: str=None, callback: Callable=None):
|
||||
super().__init__(hook_type=EnumHookType.Callbacks)
|
||||
self.key = key
|
||||
self.callback = callback
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: CallbackHook = super().clone(subtype)
|
||||
c.key = self.key
|
||||
c.callback = self.callback
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class WrapperHook(Hook):
|
||||
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
|
||||
super().__init__(hook_type=EnumHookType.Wrappers)
|
||||
self.wrappers_dict = wrappers_dict
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: WrapperHook = super().clone(subtype)
|
||||
c.wrappers_dict = self.wrappers_dict
|
||||
return c
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
if not self.should_register(model, model_options, target, registered):
|
||||
return False
|
||||
add_model_options = {"transformer_options": self.wrappers_dict}
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||
registered.append(self)
|
||||
return True
|
||||
|
||||
class SetInjectionsHook(Hook):
|
||||
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
|
||||
super().__init__(hook_type=EnumHookType.SetInjections)
|
||||
self.key = key
|
||||
self.injections = injections
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: SetInjectionsHook = super().clone(subtype)
|
||||
c.key = self.key
|
||||
c.injections = self.injections.copy() if self.injections else self.injections
|
||||
return c
|
||||
|
||||
def add_hook_injections(self, model: 'ModelPatcher'):
|
||||
# TODO: add functionality
|
||||
pass
|
||||
|
||||
class HookGroup:
|
||||
def __init__(self):
|
||||
self.hooks: list[Hook] = []
|
||||
|
||||
def add(self, hook: Hook):
|
||||
if hook not in self.hooks:
|
||||
self.hooks.append(hook)
|
||||
|
||||
def contains(self, hook: Hook):
|
||||
return hook in self.hooks
|
||||
|
||||
def clone(self):
|
||||
c = HookGroup()
|
||||
for hook in self.hooks:
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def clone_and_combine(self, other: 'HookGroup'):
|
||||
c = self.clone()
|
||||
if other is not None:
|
||||
for hook in other.hooks:
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
|
||||
if hook_kf is None:
|
||||
hook_kf = HookKeyframeGroup()
|
||||
else:
|
||||
hook_kf = hook_kf.clone()
|
||||
for hook in self.hooks:
|
||||
hook.hook_keyframe = hook_kf
|
||||
|
||||
def get_dict_repr(self):
|
||||
d: dict[EnumHookType, dict[Hook, None]] = {}
|
||||
for hook in self.hooks:
|
||||
with_type = d.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
return d
|
||||
|
||||
def get_hooks_for_clip_schedule(self):
|
||||
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
||||
for hook in self.hooks:
|
||||
# only care about WeightHooks, for now
|
||||
if hook.hook_type == EnumHookType.Weight:
|
||||
hook_schedule = []
|
||||
# if no hook keyframes, assign default value
|
||||
if len(hook.hook_keyframe.keyframes) == 0:
|
||||
hook_schedule.append(((0.0, 1.0), None))
|
||||
scheduled_hooks[hook] = hook_schedule
|
||||
continue
|
||||
# find ranges of values
|
||||
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
||||
for keyframe in hook.hook_keyframe.keyframes:
|
||||
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
||||
prev_keyframe = keyframe
|
||||
elif keyframe.start_percent == prev_keyframe.start_percent:
|
||||
prev_keyframe = keyframe
|
||||
# create final range, assuming last start_percent was not 1.0
|
||||
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
||||
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
||||
scheduled_hooks[hook] = hook_schedule
|
||||
# hooks should not have their schedules in a list of tuples
|
||||
all_ranges: list[tuple[float, float]] = []
|
||||
for range_kfs in scheduled_hooks.values():
|
||||
for t_range, keyframe in range_kfs:
|
||||
all_ranges.append(t_range)
|
||||
# turn list of ranges into boundaries
|
||||
boundaries_set = set(itertools.chain.from_iterable(all_ranges))
|
||||
boundaries_set.add(0.0)
|
||||
boundaries = sorted(boundaries_set)
|
||||
real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
|
||||
# with real ranges defined, give appropriate hooks w/ keyframes for each range
|
||||
scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
|
||||
for t_range in real_ranges:
|
||||
hooks_schedule = []
|
||||
for hook, val in scheduled_hooks.items():
|
||||
keyframe = None
|
||||
# check if is a keyframe that works for the current t_range
|
||||
for stored_range, stored_kf in val:
|
||||
# if stored start is less than current end, then fits - give it assigned keyframe
|
||||
if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
|
||||
keyframe = stored_kf
|
||||
break
|
||||
hooks_schedule.append((hook, keyframe))
|
||||
scheduled_keyframes.append((t_range, hooks_schedule))
|
||||
return scheduled_keyframes
|
||||
|
||||
def reset(self):
|
||||
for hook in self.hooks:
|
||||
hook.reset()
|
||||
|
||||
@staticmethod
|
||||
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
|
||||
actual: list[HookGroup] = []
|
||||
for group in hooks_list:
|
||||
if group is not None:
|
||||
actual.append(group)
|
||||
if len(actual) < require_count:
|
||||
raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
|
||||
# if no hooks, then return None
|
||||
if len(actual) == 0:
|
||||
return None
|
||||
# if only 1 hook, just return itself without cloning
|
||||
elif len(actual) == 1:
|
||||
return actual[0]
|
||||
final_hook: HookGroup = None
|
||||
for hook in actual:
|
||||
if final_hook is None:
|
||||
final_hook = hook.clone()
|
||||
else:
|
||||
final_hook = final_hook.clone_and_combine(hook)
|
||||
return final_hook
|
||||
|
||||
|
||||
class HookKeyframe:
|
||||
def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
|
||||
self.strength = strength
|
||||
# scheduling
|
||||
self.start_percent = float(start_percent)
|
||||
self.start_t = 999999999.9
|
||||
self.guarantee_steps = guarantee_steps
|
||||
|
||||
def clone(self):
|
||||
c = HookKeyframe(strength=self.strength,
|
||||
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
||||
c.start_t = self.start_t
|
||||
return c
|
||||
|
||||
class HookKeyframeGroup:
|
||||
def __init__(self):
|
||||
self.keyframes: list[HookKeyframe] = []
|
||||
self._current_keyframe: HookKeyframe = None
|
||||
self._current_used_steps = 0
|
||||
self._current_index = 0
|
||||
self._current_strength = None
|
||||
self._curr_t = -1.
|
||||
|
||||
# properties shadow those of HookWeightsKeyframe
|
||||
@property
|
||||
def strength(self):
|
||||
if self._current_keyframe is not None:
|
||||
return self._current_keyframe.strength
|
||||
return 1.0
|
||||
|
||||
def reset(self):
|
||||
self._current_keyframe = None
|
||||
self._current_used_steps = 0
|
||||
self._current_index = 0
|
||||
self._current_strength = None
|
||||
self.curr_t = -1.
|
||||
self._set_first_as_current()
|
||||
|
||||
def add(self, keyframe: HookKeyframe):
|
||||
# add to end of list, then sort
|
||||
self.keyframes.append(keyframe)
|
||||
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
|
||||
self._set_first_as_current()
|
||||
|
||||
def _set_first_as_current(self):
|
||||
if len(self.keyframes) > 0:
|
||||
self._current_keyframe = self.keyframes[0]
|
||||
else:
|
||||
self._current_keyframe = None
|
||||
|
||||
def has_index(self, index: int):
|
||||
return index >= 0 and index < len(self.keyframes)
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.keyframes) == 0
|
||||
|
||||
def clone(self):
|
||||
c = HookKeyframeGroup()
|
||||
for keyframe in self.keyframes:
|
||||
c.keyframes.append(keyframe.clone())
|
||||
c._set_first_as_current()
|
||||
return c
|
||||
|
||||
def initialize_timesteps(self, model: 'BaseModel'):
|
||||
for keyframe in self.keyframes:
|
||||
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
||||
|
||||
def prepare_current_keyframe(self, curr_t: float) -> bool:
|
||||
if self.is_empty():
|
||||
return False
|
||||
if curr_t == self._curr_t:
|
||||
return False
|
||||
prev_index = self._current_index
|
||||
prev_strength = self._current_strength
|
||||
# if met guaranteed steps, look for next keyframe in case need to switch
|
||||
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
|
||||
# if has next index, loop through and see if need to switch
|
||||
if self.has_index(self._current_index+1):
|
||||
for i in range(self._current_index+1, len(self.keyframes)):
|
||||
eval_c = self.keyframes[i]
|
||||
# check if start_t is greater or equal to curr_t
|
||||
# NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
|
||||
if eval_c.start_t >= curr_t:
|
||||
self._current_index = i
|
||||
self._current_strength = eval_c.strength
|
||||
self._current_keyframe = eval_c
|
||||
self._current_used_steps = 0
|
||||
# if guarantee_steps greater than zero, stop searching for other keyframes
|
||||
if self._current_keyframe.guarantee_steps > 0:
|
||||
break
|
||||
# if eval_c is outside the percent range, stop looking further
|
||||
else: break
|
||||
# update steps current context is used
|
||||
self._current_used_steps += 1
|
||||
# update current timestep this was performed on
|
||||
self._curr_t = curr_t
|
||||
# return True if keyframe changed, False if no change
|
||||
return prev_index != self._current_index and prev_strength != self._current_strength
|
||||
|
||||
|
||||
class InterpolationMethod:
|
||||
LINEAR = "linear"
|
||||
EASE_IN = "ease_in"
|
||||
EASE_OUT = "ease_out"
|
||||
EASE_IN_OUT = "ease_in_out"
|
||||
|
||||
_LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
|
||||
|
||||
@classmethod
|
||||
def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
|
||||
diff = num_to - num_from
|
||||
if method == cls.LINEAR:
|
||||
weights = torch.linspace(num_from, num_to, length)
|
||||
elif method == cls.EASE_IN:
|
||||
index = torch.linspace(0, 1, length)
|
||||
weights = diff * np.power(index, 2) + num_from
|
||||
elif method == cls.EASE_OUT:
|
||||
index = torch.linspace(0, 1, length)
|
||||
weights = diff * (1 - np.power(1 - index, 2)) + num_from
|
||||
elif method == cls.EASE_IN_OUT:
|
||||
index = torch.linspace(0, 1, length)
|
||||
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
|
||||
else:
|
||||
raise ValueError(f"Unrecognized interpolation method '{method}'.")
|
||||
if reverse:
|
||||
weights = weights.flip(dims=(0,))
|
||||
return weights
|
||||
|
||||
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
||||
if not objects:
|
||||
return objects
|
||||
elif len(objects) <= 1:
|
||||
return [x for x in objects]
|
||||
# now that we know we have to sort, do it following these rules:
|
||||
# a) if objects have same value of attribute, maintain their relative order
|
||||
# b) perform sorting of the groups of objects with same attributes
|
||||
unique_attrs = {}
|
||||
for o in objects:
|
||||
val_attr = getattr(o, attr)
|
||||
attr_list: list = unique_attrs.get(val_attr, list())
|
||||
attr_list.append(o)
|
||||
if val_attr not in unique_attrs:
|
||||
unique_attrs[val_attr] = attr_list
|
||||
# now that we have the unique attr values grouped together in relative order, sort them by key
|
||||
sorted_attrs = dict(sorted(unique_attrs.items()))
|
||||
# now flatten out the dict into a list to return
|
||||
sorted_list = []
|
||||
for object_list in sorted_attrs.values():
|
||||
sorted_list.extend(object_list)
|
||||
return sorted_list
|
||||
|
||||
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
||||
hook_group = HookGroup()
|
||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||
hook_group.add(hook)
|
||||
hook.weights = lora
|
||||
return hook_group
|
||||
|
||||
def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
|
||||
hook_group = HookGroup()
|
||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||
hook_group.add(hook)
|
||||
patches_model = None
|
||||
patches_clip = None
|
||||
if weights_model is not None:
|
||||
patches_model = {}
|
||||
for key in weights_model:
|
||||
patches_model[key] = ("model_as_lora", (weights_model[key],))
|
||||
if weights_clip is not None:
|
||||
patches_clip = {}
|
||||
for key in weights_clip:
|
||||
patches_clip[key] = ("model_as_lora", (weights_clip[key],))
|
||||
hook.weights = patches_model
|
||||
hook.weights_clip = patches_clip
|
||||
hook.need_weight_init = False
|
||||
return hook_group
|
||||
|
||||
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
|
||||
if model is None:
|
||||
return None
|
||||
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
||||
if discard_model_sampling:
|
||||
# do not include ANY model_sampling components of the model that should act as a patch
|
||||
for key in list(patches_model.keys()):
|
||||
if key.startswith("model_sampling"):
|
||||
patches_model.pop(key, None)
|
||||
return patches_model
|
||||
|
||||
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
||||
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
|
||||
strength_model: float, strength_clip: float):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
if clip is not None:
|
||||
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
|
||||
hook_group = HookGroup()
|
||||
hook = WeightHook()
|
||||
hook_group.add(hook)
|
||||
loaded: dict[str] = comfy.lora.load_lora(lora, key_map)
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)
|
||||
else:
|
||||
k = ()
|
||||
new_modelpatcher = None
|
||||
|
||||
if clip is not None:
|
||||
new_clip = clip.clone()
|
||||
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
|
||||
else:
|
||||
k1 = ()
|
||||
new_clip = None
|
||||
k = set(k)
|
||||
k1 = set(k1)
|
||||
for x in loaded:
|
||||
if (x not in k) and (x not in k1):
|
||||
print(f"NOT LOADED {x}")
|
||||
return (new_modelpatcher, new_clip, hook_group)
|
||||
|
||||
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
|
||||
hooks_key = 'hooks'
|
||||
# if hooks only exist in one dict, do what's needed so that it ends up in c_dict
|
||||
if hooks_key not in values:
|
||||
return
|
||||
if hooks_key not in c_dict:
|
||||
hooks_value = values.get(hooks_key, None)
|
||||
if hooks_value is not None:
|
||||
c_dict[hooks_key] = hooks_value
|
||||
return
|
||||
# otherwise, need to combine with minimum duplication via cache
|
||||
hooks_tuple = (c_dict[hooks_key], values[hooks_key])
|
||||
cached_hooks = cache.get(hooks_tuple, None)
|
||||
if cached_hooks is None:
|
||||
new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
|
||||
cache[hooks_tuple] = new_hooks
|
||||
c_dict[hooks_key] = new_hooks
|
||||
else:
|
||||
c_dict[hooks_key] = cache[hooks_tuple]
|
||||
|
||||
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
|
||||
c = []
|
||||
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
for k in values:
|
||||
if append_hooks and k == 'hooks':
|
||||
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
|
||||
else:
|
||||
n[1][k] = values[k]
|
||||
c.append(n)
|
||||
|
||||
return c
|
||||
|
||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
|
||||
if hooks is None:
|
||||
return cond
|
||||
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
|
||||
|
||||
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
||||
if timestep_range is None:
|
||||
return cond
|
||||
return conditioning_set_values(cond, {"start_percent": timestep_range[0],
|
||||
"end_percent": timestep_range[1]})
|
||||
|
||||
def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
|
||||
if mask is None:
|
||||
return cond
|
||||
set_area_to_bounds = False
|
||||
if set_cond_area != 'default':
|
||||
set_area_to_bounds = True
|
||||
if len(mask.shape) < 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
return conditioning_set_values(cond, {'mask': mask,
|
||||
'set_area_to_bounds': set_area_to_bounds,
|
||||
'mask_strength': strength})
|
||||
|
||||
def combine_conditioning(conds: list):
|
||||
combined_conds = []
|
||||
for cond in conds:
|
||||
combined_conds.extend(cond)
|
||||
return combined_conds
|
||||
|
||||
def combine_with_new_conds(conds: list, new_conds: list):
|
||||
combined_conds = []
|
||||
for c, new_c in zip(conds, new_conds):
|
||||
combined_conds.append(combine_conditioning([c, new_c]))
|
||||
return combined_conds
|
||||
|
||||
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
final_conds = []
|
||||
for c in conds:
|
||||
# first, apply lora_hook to conditioning, if provided
|
||||
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
|
||||
# next, apply mask to conditioning
|
||||
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
||||
# apply timesteps, if present
|
||||
c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range)
|
||||
# finally, apply mask to conditioning and store
|
||||
final_conds.append(c)
|
||||
return final_conds
|
||||
|
||||
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
combined_conds = []
|
||||
for c, masked_c in zip(conds, new_conds):
|
||||
# first, apply lora_hook to new conditioning, if provided
|
||||
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
|
||||
# next, apply mask to new conditioning, if provided
|
||||
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
||||
# apply timesteps, if present
|
||||
masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range)
|
||||
# finally, combine with existing conditioning and store
|
||||
combined_conds.append(combine_conditioning([c, masked_c]))
|
||||
return combined_conds
|
||||
|
||||
def set_default_conds_and_combine(conds: list, new_conds: list,
|
||||
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
combined_conds = []
|
||||
for c, new_c in zip(conds, new_conds):
|
||||
# first, apply lora_hook to new conditioning, if provided
|
||||
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
|
||||
# next, add default_cond key to cond so that during sampling, it can be identified
|
||||
new_c = conditioning_set_values(new_c, {'default': True})
|
||||
# apply timesteps, if present
|
||||
new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
|
||||
# finally, combine with existing conditioning and store
|
||||
combined_conds.append(combine_conditioning([c, new_c]))
|
||||
return combined_conds
|
@ -175,12 +175,14 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
|
||||
if sigma_down == 0:
|
||||
x = denoised
|
||||
else:
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
@ -192,19 +194,22 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i+1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i+1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
# Euler method
|
||||
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
|
||||
if sigmas[i + 1] > 0 and eta > 0:
|
||||
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i + 1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i + 1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
|
||||
# Euler method
|
||||
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
|
||||
if eta > 0:
|
||||
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
@ -280,6 +285,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
|
||||
return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@ -306,6 +314,38 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i+1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i+1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
if sigma_down == 0:
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver-2
|
||||
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
|
||||
dt_1 = sigma_mid - sigmas[i]
|
||||
dt_2 = sigma_down - sigmas[i]
|
||||
x_2 = x + d * dt_1
|
||||
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
return x
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j):
|
||||
if order - 1 > i:
|
||||
|
@ -216,3 +216,139 @@ class Mochi(LatentFormat):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
|
||||
[ 8.6031e-02, 6.5813e-02, 9.5409e-04],
|
||||
[-1.2576e-02, -7.5734e-03, -4.0528e-03],
|
||||
[ 9.4063e-03, -2.1688e-03, 2.6093e-03],
|
||||
[ 3.7636e-03, 1.2765e-02, 9.1548e-03],
|
||||
[ 2.1024e-02, -5.2973e-03, 3.4373e-03],
|
||||
[-8.8896e-03, -1.9703e-02, -1.8761e-02],
|
||||
[-1.3160e-02, -1.0523e-02, 1.9709e-03],
|
||||
[-1.5152e-03, -6.9891e-03, -7.5810e-03],
|
||||
[-1.7247e-03, 4.6560e-04, -3.3839e-03],
|
||||
[ 1.3617e-02, 4.7077e-03, -2.0045e-03],
|
||||
[ 1.0256e-02, 7.7318e-03, 1.3948e-02],
|
||||
[-1.6108e-02, -6.2151e-03, 1.1561e-03],
|
||||
[ 7.3407e-03, 1.5628e-02, 4.4865e-04],
|
||||
[ 9.5357e-04, -2.9518e-03, -1.4760e-02],
|
||||
[ 1.9143e-02, 1.0868e-02, 1.2264e-02],
|
||||
[ 4.4575e-03, 3.6682e-05, -6.8508e-03],
|
||||
[-4.5681e-04, 3.2570e-03, 7.7929e-03],
|
||||
[ 3.3902e-02, 3.3405e-02, 3.7454e-02],
|
||||
[-2.3001e-02, -2.4877e-03, -3.1033e-03],
|
||||
[ 5.0265e-02, 3.8841e-02, 3.3539e-02],
|
||||
[-4.1018e-03, -1.1095e-03, 1.5859e-03],
|
||||
[-1.2689e-01, -1.3107e-01, -2.1005e-01],
|
||||
[ 2.6276e-02, 1.4189e-02, -3.5963e-03],
|
||||
[-4.8679e-03, 8.8486e-03, 7.8029e-03],
|
||||
[-1.6610e-03, -4.8597e-03, -5.2060e-03],
|
||||
[-2.1010e-03, 2.3610e-03, 9.3796e-03],
|
||||
[-2.2482e-02, -2.1305e-02, -1.5087e-02],
|
||||
[-1.5753e-02, -1.0646e-02, -6.5083e-03],
|
||||
[-4.6975e-03, 5.0288e-03, -6.7390e-03],
|
||||
[ 1.1951e-02, 2.0712e-02, 1.6191e-02],
|
||||
[-6.3704e-03, -8.4827e-03, -9.5483e-03],
|
||||
[ 7.2610e-03, -9.9326e-03, -2.2978e-02],
|
||||
[-9.1904e-04, 6.2882e-03, 9.5720e-03],
|
||||
[-3.7178e-02, -3.7123e-02, -5.6713e-02],
|
||||
[-1.3373e-01, -1.0720e-01, -5.3801e-02],
|
||||
[-5.3702e-03, 8.1256e-03, 8.8397e-03],
|
||||
[-1.5247e-01, -2.1437e-01, -2.1843e-01],
|
||||
[ 3.1441e-02, 7.0335e-03, -9.7541e-03],
|
||||
[ 2.1528e-03, -8.9817e-03, -2.1023e-02],
|
||||
[ 3.8461e-03, -5.8957e-03, -1.5014e-02],
|
||||
[-4.3470e-03, -1.2940e-02, -1.5972e-02],
|
||||
[-5.4781e-03, -1.0842e-02, -3.0204e-03],
|
||||
[-6.5347e-03, 3.0806e-03, -1.0163e-02],
|
||||
[-5.0414e-03, -7.1503e-03, -8.9686e-04],
|
||||
[-8.5851e-03, -2.4351e-03, 1.0674e-03],
|
||||
[-9.0016e-03, -9.6493e-03, 1.5692e-03],
|
||||
[ 5.0914e-03, 1.2099e-02, 1.9968e-02],
|
||||
[ 1.3758e-02, 1.1669e-02, 8.1958e-03],
|
||||
[-1.0518e-02, -1.1575e-02, -4.1307e-03],
|
||||
[-2.8410e-02, -3.1266e-02, -2.2149e-02],
|
||||
[ 2.9336e-03, 3.6511e-02, 1.8717e-02],
|
||||
[-1.6703e-02, -1.6696e-02, -4.4529e-03],
|
||||
[ 4.8818e-02, 4.0063e-02, 8.7410e-03],
|
||||
[-1.5066e-02, -5.7328e-04, 2.9785e-03],
|
||||
[-1.7613e-02, -8.1034e-03, 1.3086e-02],
|
||||
[-9.2633e-03, 1.0803e-02, -6.3489e-03],
|
||||
[ 3.0851e-03, 4.7750e-04, 1.2347e-02],
|
||||
[-2.2785e-02, -2.3043e-02, -2.6005e-02],
|
||||
[-2.4787e-02, -1.5389e-02, -2.2104e-02],
|
||||
[-2.3572e-02, 1.0544e-03, 1.2361e-02],
|
||||
[-7.8915e-03, -1.2271e-03, -6.0968e-03],
|
||||
[-1.1478e-02, -1.2543e-03, 6.2679e-03],
|
||||
[-5.4229e-02, 2.6644e-02, 6.3394e-03],
|
||||
[ 4.4216e-03, -7.3338e-03, -1.0464e-02],
|
||||
[-4.5013e-03, 1.6082e-03, 1.4420e-02],
|
||||
[ 1.3673e-02, 8.8877e-03, 4.1253e-03],
|
||||
[-1.0145e-02, 9.0072e-03, 1.5695e-02],
|
||||
[-5.6234e-03, 1.1847e-03, 8.1261e-03],
|
||||
[-3.7171e-03, -5.3538e-03, 1.2590e-03],
|
||||
[ 2.9476e-02, 2.1424e-02, 3.0424e-02],
|
||||
[-3.4925e-02, -2.4340e-02, -2.5316e-02],
|
||||
[-3.4127e-02, -2.2406e-02, -1.0589e-02],
|
||||
[-1.7342e-02, -1.3249e-02, -1.0719e-02],
|
||||
[-2.1478e-03, -8.6051e-03, -2.9878e-03],
|
||||
[ 1.2089e-03, -4.2391e-03, -6.8569e-03],
|
||||
[ 9.0411e-04, -6.6886e-03, -6.7547e-05],
|
||||
[ 1.6048e-02, -1.0057e-02, -2.8929e-02],
|
||||
[ 1.2290e-03, 1.0163e-02, 1.8861e-02],
|
||||
[ 1.7264e-02, 2.7257e-04, 1.3785e-02],
|
||||
[-1.3482e-02, -3.6427e-03, 6.7481e-04],
|
||||
[ 4.6782e-03, -5.2423e-03, 2.4467e-03],
|
||||
[-5.9113e-03, -6.2244e-03, -1.8162e-03],
|
||||
[ 1.5496e-02, 1.4582e-02, 1.9514e-03],
|
||||
[ 7.4958e-03, 1.5886e-03, -8.2305e-03],
|
||||
[ 1.9086e-02, 1.6360e-03, -3.9674e-03],
|
||||
[-5.7021e-03, -2.7307e-03, -4.1066e-03],
|
||||
[ 1.7450e-03, 1.4602e-02, 2.5794e-02],
|
||||
[-8.2788e-04, 2.2902e-03, 4.5161e-03],
|
||||
[ 1.1632e-02, 8.9193e-03, -7.2813e-03],
|
||||
[ 7.5721e-03, 2.6784e-03, 1.1393e-02],
|
||||
[ 5.1939e-03, 3.6903e-03, 1.4049e-02],
|
||||
[-1.8383e-02, -2.2529e-02, -2.4477e-02],
|
||||
[ 5.8842e-04, -5.7874e-03, -1.4770e-02],
|
||||
[-1.6125e-02, -8.6101e-03, -1.4533e-02],
|
||||
[ 2.0540e-02, 2.0729e-02, 6.4338e-03],
|
||||
[ 3.3587e-03, -1.1226e-02, -1.6444e-02],
|
||||
[-1.4742e-03, -1.0489e-02, 1.7097e-03],
|
||||
[ 2.8130e-02, 2.3546e-02, 3.2791e-02],
|
||||
[-1.8532e-02, -1.2842e-02, -8.7756e-03],
|
||||
[-8.0533e-03, -1.0771e-02, -1.7536e-02],
|
||||
[-3.9009e-03, 1.6150e-02, 3.3359e-02],
|
||||
[-7.4554e-03, -1.4154e-02, -6.1910e-03],
|
||||
[ 3.4734e-03, -1.1370e-02, -1.0581e-02],
|
||||
[ 1.1476e-02, 3.9281e-03, 2.8231e-03],
|
||||
[ 7.1639e-03, -1.4741e-03, -3.8066e-03],
|
||||
[ 2.2250e-03, -8.7552e-03, -9.5719e-03],
|
||||
[ 2.4146e-02, 2.1696e-02, 2.8056e-02],
|
||||
[-5.4365e-03, -2.4291e-02, -1.7802e-02],
|
||||
[ 7.4263e-03, 1.0510e-02, 1.2705e-02],
|
||||
[ 6.2669e-03, 6.2658e-03, 1.9211e-02],
|
||||
[ 1.6378e-02, 9.4933e-03, 6.6971e-03],
|
||||
[ 1.7173e-02, 2.3601e-02, 2.3296e-02],
|
||||
[-1.4568e-02, -9.8279e-03, -1.1556e-02],
|
||||
[ 1.4431e-02, 1.4430e-02, 6.6362e-03],
|
||||
[-6.8230e-03, 1.8863e-02, 1.4555e-02],
|
||||
[ 6.1156e-03, 3.4700e-03, -2.6662e-03],
|
||||
[-2.6983e-03, -5.9402e-03, -9.2276e-03],
|
||||
[ 1.0235e-02, 7.4173e-03, -7.6243e-03],
|
||||
[-1.3255e-02, 1.9322e-02, -9.2153e-04],
|
||||
[ 2.4222e-03, -4.8039e-03, -1.5759e-02],
|
||||
[ 2.6244e-02, 2.5951e-02, 2.0249e-02],
|
||||
[ 1.5711e-02, 1.8498e-02, 2.7407e-03],
|
||||
[-2.1714e-03, 4.7214e-03, -2.2443e-02],
|
||||
[-7.4747e-03, 7.4166e-03, 1.4430e-02],
|
||||
[-8.3906e-03, -7.9776e-03, 9.7927e-03],
|
||||
[ 3.8321e-02, 9.6622e-03, -1.9268e-02],
|
||||
[-1.4605e-02, -6.7032e-03, 3.9675e-03]
|
||||
]
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Literal, Dict, Any
|
||||
from typing import Literal
|
||||
import math
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
@ -612,7 +612,9 @@ class ContinuousTransformer(nn.Module):
|
||||
return_info = False,
|
||||
**kwargs
|
||||
):
|
||||
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
|
||||
batch, seq, device = *x.shape[:2], x.device
|
||||
context = kwargs["context"]
|
||||
|
||||
info = {
|
||||
"hidden_states": [],
|
||||
@ -643,9 +645,19 @@ class ContinuousTransformer(nn.Module):
|
||||
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
||||
x = x + self.pos_emb(x)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
# Iterate over the transformer layers
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.layers):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
|
||||
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
|
||||
if return_info:
|
||||
@ -874,7 +886,6 @@ class AudioDiffusionTransformer(nn.Module):
|
||||
mask=None,
|
||||
return_info=False,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
**kwargs):
|
||||
return self._forward(
|
||||
x,
|
||||
|
@ -2,8 +2,8 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor, einsum
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
||||
from torch import Tensor
|
||||
from typing import List, Union
|
||||
from einops import rearrange
|
||||
import math
|
||||
import comfy.ops
|
||||
|
@ -16,7 +16,6 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from .common import LayerNorm2d_op
|
||||
|
@ -2,7 +2,7 @@ import torch
|
||||
import comfy.ops
|
||||
|
||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
padding_mode = "reflect"
|
||||
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
||||
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
||||
|
@ -6,9 +6,7 @@ import math
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
MLPEmbedder, SingleStreamBlock,
|
||||
timestep_embedding)
|
||||
from .layers import (timestep_embedding)
|
||||
|
||||
from .model import Flux
|
||||
import comfy.ldm.common_dit
|
||||
|
@ -20,6 +20,7 @@ import comfy.ldm.common_dit
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
@ -29,6 +30,7 @@ class FluxParams:
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
@ -43,8 +45,9 @@ class Flux(nn.Module):
|
||||
self.dtype = dtype
|
||||
params = FluxParams(**kwargs)
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels * 2 * 2
|
||||
self.out_channels = self.in_channels
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels * params.patch_size * params.patch_size
|
||||
self.out_channels = params.out_channels * params.patch_size * params.patch_size
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@ -165,7 +168,7 @@ class Flux(nn.Module):
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
25
comfy/ldm/flux/redux.py
Normal file
25
comfy/ldm/flux/redux.py
Normal file
@ -0,0 +1,25 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.manual_cast
|
||||
|
||||
class ReduxImageEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
redux_dim: int = 1152,
|
||||
txt_in_features: int = 4096,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.redux_dim = redux_dim
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
|
||||
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
|
||||
|
||||
def forward(self, sigclip_embeds) -> torch.Tensor:
|
||||
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
|
||||
return projected_x
|
@ -1,7 +1,7 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -1,7 +1,7 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from functools import partial
|
||||
import math
|
||||
|
||||
|
@ -1,24 +1,17 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.utils import checkpoint
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import (
|
||||
Mlp,
|
||||
TimestepEmbedder,
|
||||
PatchEmbed,
|
||||
RMSNorm,
|
||||
)
|
||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from .poolers import AttentionPool
|
||||
|
||||
import comfy.latent_formats
|
||||
from .models import HunYuanDiTBlock, calc_rope
|
||||
|
||||
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
|
||||
|
||||
|
||||
class HunYuanControlNet(nn.Module):
|
||||
|
@ -1,8 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
|
||||
@ -287,7 +285,7 @@ class HunYuanDiT(nn.Module):
|
||||
style=None,
|
||||
return_dict=False,
|
||||
control=None,
|
||||
transformer_options=None,
|
||||
transformer_options={},
|
||||
):
|
||||
"""
|
||||
Forward pass of the encoder.
|
||||
@ -315,8 +313,7 @@ class HunYuanDiT(nn.Module):
|
||||
return_dict: bool
|
||||
Whether to return a dictionary.
|
||||
"""
|
||||
#import pdb
|
||||
#pdb.set_trace()
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
encoder_hidden_states = context
|
||||
text_states = encoder_hidden_states # 2,77,1024
|
||||
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
|
||||
@ -364,6 +361,8 @@ class HunYuanDiT(nn.Module):
|
||||
# Concatenate all extra vectors
|
||||
c = t + self.extra_embedder(extra_vec) # [B, D]
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
controls = None
|
||||
if control:
|
||||
controls = control.get("output", None)
|
||||
@ -375,9 +374,20 @@ class HunYuanDiT(nn.Module):
|
||||
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
|
||||
else:
|
||||
skip = skips.pop()
|
||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||
else:
|
||||
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
|
||||
skip = None
|
||||
|
||||
if ("double_block", layer) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||
|
||||
|
||||
if layer < (self.depth // 2 - 1):
|
||||
skips.append(x)
|
||||
|
@ -1,6 +1,5 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ops
|
||||
|
||||
|
527
comfy/ldm/lightricks/model.py
Normal file
527
comfy/ldm/lightricks/model.py
Normal file
@ -0,0 +1,527 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.ldm.modules.attention
|
||||
from comfy.ldm.genmo.joint_model.layers import RMSNorm
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
Args
|
||||
timesteps (torch.Tensor):
|
||||
a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
||||
embedding_dim (int):
|
||||
the dimension of the output.
|
||||
flip_sin_to_cos (bool):
|
||||
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
|
||||
downscale_freq_shift (float):
|
||||
Controls the delta between frequencies between dimensions
|
||||
scale (float):
|
||||
Scaling factor applied to the embeddings.
|
||||
max_period (int):
|
||||
Controls the maximum frequency of the embeddings
|
||||
Returns
|
||||
torch.Tensor: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
emb = scale * emb
|
||||
|
||||
# concat sine and cosine embeddings
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
|
||||
# flip sine and cosine embeddings
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
|
||||
# zero pad
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
act_fn: str = "silu",
|
||||
out_dim: int = None,
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True,
|
||||
dtype=None, device=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
|
||||
else:
|
||||
self.cond_proj = None
|
||||
|
||||
self.act = nn.SiLU()
|
||||
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
# else:
|
||||
# self.post_act = get_activation(post_act_fn)
|
||||
|
||||
def forward(self, sample, condition=None):
|
||||
if condition is not None:
|
||||
sample = sample + self.cond_proj(condition)
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
|
||||
sample = self.linear_2(sample)
|
||||
|
||||
if self.post_act is not None:
|
||||
sample = self.post_act(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
scale=self.scale,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
"""
|
||||
For PixArt-Alpha.
|
||||
|
||||
Reference:
|
||||
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.outdim = size_emb_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm single (adaLN-single).
|
||||
|
||||
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
hidden_dtype: Optional[torch.dtype] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# No modulation happening here.
|
||||
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
|
||||
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
||||
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
"""
|
||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||
|
||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = hidden_size
|
||||
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act_1 = nn.SiLU()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GELU_approx(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.gelu(self.proj(x), approximate="tanh")
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
|
||||
cos_freqs = freqs_cis[0]
|
||||
sin_freqs = freqs_cis[1]
|
||||
|
||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||
|
||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, mask=None, pe=None):
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe)
|
||||
k = apply_rotary_emb(k, pe)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
|
||||
return x
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
|
||||
fractional_positions = torch.stack(
|
||||
[
|
||||
indices_grid[:, i] / max_pos[i]
|
||||
for i in range(3)
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return fractional_positions
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
dtype = torch.float32 #self.dtype
|
||||
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
|
||||
start = 1
|
||||
end = theta
|
||||
device = fractional_positions.device
|
||||
|
||||
indices = theta ** (
|
||||
torch.linspace(
|
||||
math.log(start, theta),
|
||||
math.log(end, theta),
|
||||
dim // 6,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
indices = indices.to(dtype=dtype)
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
|
||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
if dim % 6 != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
||||
|
||||
|
||||
class LTXVModel(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
|
||||
caption_channels=4096,
|
||||
num_layers=28,
|
||||
|
||||
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.generator = None
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
# attn_precision=attn_precision,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
||||
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
indices_grid = self.patchifier.get_grid(
|
||||
orig_num_frames=x.shape[2],
|
||||
orig_height=x.shape[3],
|
||||
orig_width=x.shape[4],
|
||||
batch_size=x.shape[0],
|
||||
scale_grid=((1 / frame_rate) * 8, 32, 32),
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
|
||||
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
|
||||
ts *= input_ts
|
||||
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
|
||||
timestep = self.patchifier.patchify(ts)
|
||||
input_x = x.clone()
|
||||
x[:, :, 0] = guiding_latent[:, :, 0]
|
||||
if guiding_latent_noise_scale > 0:
|
||||
if self.generator is None:
|
||||
self.generator = torch.Generator(device=x.device).manual_seed(42)
|
||||
elif self.generator.device != x.device:
|
||||
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
|
||||
|
||||
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
|
||||
scale = guiding_latent_noise_scale * (input_ts ** 2)
|
||||
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
|
||||
|
||||
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
|
||||
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
|
||||
x = self.patchifier.patchify(x)
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
timestep = timestep * 1000.0
|
||||
|
||||
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
||||
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
|
||||
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
|
||||
|
||||
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
batch_size = x.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=x.dtype,
|
||||
)
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(
|
||||
batch_size, -1, embedded_timestep.shape[-1]
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = x.shape[0]
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(
|
||||
batch_size, -1, x.shape[-1]
|
||||
)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
x,
|
||||
context=context,
|
||||
attention_mask=attention_mask,
|
||||
timestep=timestep,
|
||||
pe=pe
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
scale_shift_values = (
|
||||
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
||||
)
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
x = self.norm_out(x)
|
||||
# Modulation
|
||||
x = x * (1 + scale) + shift
|
||||
x = self.proj_out(x)
|
||||
|
||||
x = self.patchifier.unpatchify(
|
||||
latents=x,
|
||||
output_height=orig_shape[3],
|
||||
output_width=orig_shape[4],
|
||||
output_num_frames=orig_shape[2],
|
||||
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
|
||||
|
||||
# print("res", x)
|
||||
return x
|
105
comfy/ldm/lightricks/symmetric_patchifier.py
Normal file
105
comfy/ldm/lightricks/symmetric_patchifier.py
Normal file
@ -0,0 +1,105 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
||||
)
|
||||
elif dims_to_append == 0:
|
||||
return x
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
class Patchifier(ABC):
|
||||
def __init__(self, patch_size: int):
|
||||
super().__init__()
|
||||
self._patch_size = (1, patch_size, patch_size)
|
||||
|
||||
@abstractmethod
|
||||
def patchify(
|
||||
self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unpatchify(
|
||||
self,
|
||||
latents: Tensor,
|
||||
output_height: int,
|
||||
output_width: int,
|
||||
output_num_frames: int,
|
||||
out_channels: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
pass
|
||||
|
||||
@property
|
||||
def patch_size(self):
|
||||
return self._patch_size
|
||||
|
||||
def get_grid(
|
||||
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
|
||||
):
|
||||
f = orig_num_frames // self._patch_size[0]
|
||||
h = orig_height // self._patch_size[1]
|
||||
w = orig_width // self._patch_size[2]
|
||||
grid_h = torch.arange(h, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(w, dtype=torch.float32, device=device)
|
||||
grid_f = torch.arange(f, dtype=torch.float32, device=device)
|
||||
grid = torch.meshgrid(grid_f, grid_h, grid_w)
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
if scale_grid is not None:
|
||||
for i in range(3):
|
||||
if isinstance(scale_grid[i], Tensor):
|
||||
scale = append_dims(scale_grid[i], grid.ndim - 1)
|
||||
else:
|
||||
scale = scale_grid[i]
|
||||
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
|
||||
|
||||
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
|
||||
return grid
|
||||
|
||||
|
||||
class SymmetricPatchifier(Patchifier):
|
||||
def patchify(
|
||||
self,
|
||||
latents: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
latents = rearrange(
|
||||
latents,
|
||||
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
|
||||
p1=self._patch_size[0],
|
||||
p2=self._patch_size[1],
|
||||
p3=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
latents: Tensor,
|
||||
output_height: int,
|
||||
output_width: int,
|
||||
output_num_frames: int,
|
||||
out_channels: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
output_height = output_height // self._patch_size[1]
|
||||
output_width = output_width // self._patch_size[2]
|
||||
latents = rearrange(
|
||||
latents,
|
||||
"b (f h w) (c p q) -> b c f (h p) (w q) ",
|
||||
f=output_num_frames,
|
||||
h=output_height,
|
||||
w=output_width,
|
||||
p=self._patch_size[1],
|
||||
q=self._patch_size[2],
|
||||
)
|
||||
return latents
|
64
comfy/ldm/lightricks/vae/causal_conv3d.py
Normal file
64
comfy/ldm/lightricks/vae/causal_conv3d.py
Normal file
@ -0,0 +1,64 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size: int = 3,
|
||||
stride: Union[int, Tuple[int]] = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
self.time_kernel_size = kernel_size[0]
|
||||
|
||||
dilation = (dilation, 1, 1)
|
||||
|
||||
height_pad = kernel_size[1] // 2
|
||||
width_pad = kernel_size[2] // 2
|
||||
padding = (0, height_pad, width_pad)
|
||||
|
||||
self.conv = ops.Conv3d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding,
|
||||
padding_mode="zeros",
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if causal:
|
||||
first_frame_pad = x[:, :, :1, :, :].repeat(
|
||||
(1, 1, self.time_kernel_size - 1, 1, 1)
|
||||
)
|
||||
x = torch.concatenate((first_frame_pad, x), dim=2)
|
||||
else:
|
||||
first_frame_pad = x[:, :, :1, :, :].repeat(
|
||||
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
||||
)
|
||||
last_frame_pad = x[:, :, -1:, :, :].repeat(
|
||||
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
||||
)
|
||||
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.conv.weight
|
698
comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Normal file
698
comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Normal file
@ -0,0 +1,698 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from functools import partial
|
||||
import math
|
||||
from einops import rearrange
|
||||
from typing import Optional, Tuple, Union
|
||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||
from .pixel_norm import PixelNorm
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
r"""
|
||||
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
||||
|
||||
Args:
|
||||
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
|
||||
The number of dimensions to use in convolutions.
|
||||
in_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
|
||||
The blocks to use. Each block is a tuple of the block name and the number of layers.
|
||||
base_channels (`int`, *optional*, defaults to 128):
|
||||
The number of output channels for the first convolutional layer.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
patch_size (`int`, *optional*, defaults to 1):
|
||||
The patch size to use. Should be a power of 2.
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
latent_log_var (`str`, *optional*, defaults to `per_channel`):
|
||||
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Union[int, Tuple[int, int]] = 3,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
norm_num_groups: int = 32,
|
||||
patch_size: Union[int, Tuple[int]] = 1,
|
||||
norm_layer: str = "group_norm", # group_norm, pixel_norm
|
||||
latent_log_var: str = "per_channel",
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.norm_layer = norm_layer
|
||||
self.latent_channels = out_channels
|
||||
self.latent_log_var = latent_log_var
|
||||
self.blocks_desc = blocks
|
||||
|
||||
in_channels = in_channels * patch_size**2
|
||||
output_channel = base_channels
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
for block_name, block_params in blocks:
|
||||
input_channel = output_channel
|
||||
if isinstance(block_params, int):
|
||||
block_params = {"num_layers": block_params}
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(2, 1, 1),
|
||||
causal=True,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(1, 2, 2),
|
||||
causal=True,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
)
|
||||
elif block_name == "compress_all_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown block: {block_name}")
|
||||
|
||||
self.down_blocks.append(block)
|
||||
|
||||
# out
|
||||
if norm_layer == "group_norm":
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.conv_norm_out = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
conv_out_channels = out_channels
|
||||
if latent_log_var == "per_channel":
|
||||
conv_out_channels *= 2
|
||||
elif latent_log_var == "uniform":
|
||||
conv_out_channels += 1
|
||||
elif latent_log_var != "none":
|
||||
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
checkpoint_fn = (
|
||||
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
||||
if self.gradient_checkpointing and self.training
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
sample = checkpoint_fn(down_block)(sample)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if self.latent_log_var == "uniform":
|
||||
last_channel = sample[:, -1:, ...]
|
||||
num_dims = sample.dim()
|
||||
|
||||
if num_dims == 4:
|
||||
# For shape (B, C, H, W)
|
||||
repeated_last_channel = last_channel.repeat(
|
||||
1, sample.shape[1] - 2, 1, 1
|
||||
)
|
||||
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
||||
elif num_dims == 5:
|
||||
# For shape (B, C, F, H, W)
|
||||
repeated_last_channel = last_channel.repeat(
|
||||
1, sample.shape[1] - 2, 1, 1, 1
|
||||
)
|
||||
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {sample.shape}")
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""
|
||||
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
||||
|
||||
Args:
|
||||
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
|
||||
The number of dimensions to use in convolutions.
|
||||
in_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
|
||||
The blocks to use. Each block is a tuple of the block name and the number of layers.
|
||||
base_channels (`int`, *optional*, defaults to 128):
|
||||
The number of output channels for the first convolutional layer.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
patch_size (`int`, *optional*, defaults to 1):
|
||||
The patch size to use. Should be a power of 2.
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
causal (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use causal convolutions or not.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
patch_size: int = 1,
|
||||
norm_layer: str = "group_norm",
|
||||
causal: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.layers_per_block = layers_per_block
|
||||
out_channels = out_channels * patch_size**2
|
||||
self.causal = causal
|
||||
self.blocks_desc = blocks
|
||||
|
||||
# Compute output channel to be product of all channel-multiplier blocks
|
||||
output_channel = base_channels
|
||||
for block_name, block_params in list(reversed(blocks)):
|
||||
block_params = block_params if isinstance(block_params, dict) else {}
|
||||
if block_name == "res_x_y":
|
||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims,
|
||||
in_channels,
|
||||
output_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
for block_name, block_params in list(reversed(blocks)):
|
||||
input_channel = output_channel
|
||||
if isinstance(block_params, int):
|
||||
block_params = {"num_layers": block_params}
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = output_channel // block_params.get("multiplier", 2)
|
||||
block = ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 2, 2),
|
||||
residual=block_params.get("residual", False),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown layer: {block_name}")
|
||||
|
||||
self.up_blocks.append(block)
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.conv_norm_out = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, out_channels, 3, padding=1, causal=True
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
# assert target_shape is not None, "target_shape must be provided"
|
||||
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
|
||||
checkpoint_fn = (
|
||||
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
||||
if self.gradient_checkpointing and self.training
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
|
||||
"""
|
||||
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): The number of input channels.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
||||
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
||||
resnet_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use in the group normalization layers of the resnet blocks.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
||||
in_channels, height, width)`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Union[int, Tuple[int, int]],
|
||||
in_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_groups: int = 32,
|
||||
norm_layer: str = "group_norm",
|
||||
):
|
||||
super().__init__()
|
||||
resnet_groups = (
|
||||
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
)
|
||||
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, causal: bool = True
|
||||
) -> torch.FloatTensor:
|
||||
for resnet in self.res_blocks:
|
||||
hidden_states = resnet(hidden_states, causal=causal)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
def __init__(self, dims, in_channels, stride, residual=False):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.out_channels = math.prod(stride) * in_channels
|
||||
self.conv = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causal=True,
|
||||
)
|
||||
self.residual = residual
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if self.residual:
|
||||
# Reshape and duplicate the input to match the output shape
|
||||
x_in = rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1)
|
||||
if self.stride[0] == 2:
|
||||
x_in = x_in[:, :, 1:, :, :]
|
||||
x = self.conv(x, causal=causal)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
if self.stride[0] == 2:
|
||||
x = x[:, :, 1:, :, :]
|
||||
if self.residual:
|
||||
x = x + x_in
|
||||
return x
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
def forward(self, x):
|
||||
x = rearrange(x, "b c d h w -> b d h w c")
|
||||
x = self.norm(x)
|
||||
x = rearrange(x, "b d h w c -> b c d h w")
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
r"""
|
||||
A Resnet block.
|
||||
|
||||
Parameters:
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
out_channels (`int`, *optional*, default to be `None`):
|
||||
The number of output channels for the first conv layer. If None, same as `in_channels`.
|
||||
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
||||
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
||||
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Union[int, Tuple[int, int]],
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
groups: int = 32,
|
||||
eps: float = 1e-6,
|
||||
norm_layer: str = "group_norm",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.norm1 = nn.GroupNorm(
|
||||
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.norm1 = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
self.conv1 = make_conv_nd(
|
||||
dims,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.norm2 = nn.GroupNorm(
|
||||
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.norm2 = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
|
||||
self.conv2 = make_conv_nd(
|
||||
dims,
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
self.conv_shortcut = (
|
||||
make_linear_nd(
|
||||
dims=dims, in_channels=in_channels, out_channels=out_channels
|
||||
)
|
||||
if in_channels != out_channels
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
self.norm3 = (
|
||||
LayerNorm(in_channels, eps=eps, elementwise_affine=True)
|
||||
if in_channels != out_channels
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_tensor: torch.FloatTensor,
|
||||
causal: bool = True,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states, causal=causal)
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = self.conv2(hidden_states, causal=causal)
|
||||
|
||||
input_tensor = self.norm3(input_tensor)
|
||||
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = input_tensor + hidden_states
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def patchify(x, patch_size_hw, patch_size_t=1):
|
||||
if patch_size_hw == 1 and patch_size_t == 1:
|
||||
return x
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
|
||||
)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
|
||||
p=patch_size_t,
|
||||
q=patch_size_hw,
|
||||
r=patch_size_hw,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def unpatchify(x, patch_size_hw, patch_size_t=1):
|
||||
if patch_size_hw == 1 and patch_size_t == 1:
|
||||
return x
|
||||
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
|
||||
)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
|
||||
p=patch_size_t,
|
||||
q=patch_size_hw,
|
||||
r=patch_size_hw,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
class processor(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
|
||||
self.register_buffer("channel", torch.empty(128))
|
||||
|
||||
def un_normalize(self, x):
|
||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
def normalize(self, x):
|
||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"blocks": [
|
||||
["res_x", 4],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x", 3],
|
||||
["res_x", 4],
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
}
|
||||
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
"latent_log_var", "per_channel" if double_z else "none"
|
||||
)
|
||||
|
||||
self.encoder = Encoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config.get("in_channels", 3),
|
||||
out_channels=config["latent_channels"],
|
||||
blocks=config.get("encoder_blocks", config.get("blocks")),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("blocks")),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def encode(self, x):
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode(self, x):
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x))
|
||||
|
82
comfy/ldm/lightricks/vae/conv_nd_factory.py
Normal file
82
comfy/ldm/lightricks/vae/conv_nd_factory.py
Normal file
@ -0,0 +1,82 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
|
||||
from .dual_conv3d import DualConv3d
|
||||
from .causal_conv3d import CausalConv3d
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def make_conv_nd(
|
||||
dims: Union[int, Tuple[int, int]],
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
causal=False,
|
||||
):
|
||||
if dims == 2:
|
||||
return ops.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
elif dims == 3:
|
||||
if causal:
|
||||
return CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
return ops.Conv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
elif dims == (2, 1):
|
||||
return DualConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def make_linear_nd(
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
bias=True,
|
||||
):
|
||||
if dims == 2:
|
||||
return ops.Conv2d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
||||
)
|
||||
elif dims == 3 or dims == (2, 1):
|
||||
return ops.Conv3d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
195
comfy/ldm/lightricks/vae/dual_conv3d.py
Normal file
195
comfy/ldm/lightricks/vae/dual_conv3d.py
Normal file
@ -0,0 +1,195 @@
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class DualConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
):
|
||||
super(DualConv3d, self).__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if kernel_size == (1, 1, 1):
|
||||
raise ValueError(
|
||||
"kernel_size must be greater than 1. Use make_linear_nd instead."
|
||||
)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding)
|
||||
if isinstance(dilation, int):
|
||||
dilation = (dilation, dilation, dilation)
|
||||
|
||||
# Set parameters for convolutions
|
||||
self.groups = groups
|
||||
self.bias = bias
|
||||
|
||||
# Define the size of the channels after the first convolution
|
||||
intermediate_channels = (
|
||||
out_channels if in_channels < out_channels else in_channels
|
||||
)
|
||||
|
||||
# Define parameters for the first convolution
|
||||
self.weight1 = nn.Parameter(
|
||||
torch.Tensor(
|
||||
intermediate_channels,
|
||||
in_channels // groups,
|
||||
1,
|
||||
kernel_size[1],
|
||||
kernel_size[2],
|
||||
)
|
||||
)
|
||||
self.stride1 = (1, stride[1], stride[2])
|
||||
self.padding1 = (0, padding[1], padding[2])
|
||||
self.dilation1 = (1, dilation[1], dilation[2])
|
||||
if bias:
|
||||
self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
|
||||
else:
|
||||
self.register_parameter("bias1", None)
|
||||
|
||||
# Define parameters for the second convolution
|
||||
self.weight2 = nn.Parameter(
|
||||
torch.Tensor(
|
||||
out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
|
||||
)
|
||||
)
|
||||
self.stride2 = (stride[0], 1, 1)
|
||||
self.padding2 = (padding[0], 0, 0)
|
||||
self.dilation2 = (dilation[0], 1, 1)
|
||||
if bias:
|
||||
self.bias2 = nn.Parameter(torch.Tensor(out_channels))
|
||||
else:
|
||||
self.register_parameter("bias2", None)
|
||||
|
||||
# Initialize weights and biases
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
|
||||
if self.bias:
|
||||
fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
|
||||
bound1 = 1 / math.sqrt(fan_in1)
|
||||
nn.init.uniform_(self.bias1, -bound1, bound1)
|
||||
fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
|
||||
bound2 = 1 / math.sqrt(fan_in2)
|
||||
nn.init.uniform_(self.bias2, -bound2, bound2)
|
||||
|
||||
def forward(self, x, use_conv3d=False, skip_time_conv=False):
|
||||
if use_conv3d:
|
||||
return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
|
||||
else:
|
||||
return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
|
||||
|
||||
def forward_with_3d(self, x, skip_time_conv):
|
||||
# First convolution
|
||||
x = F.conv3d(
|
||||
x,
|
||||
self.weight1,
|
||||
self.bias1,
|
||||
self.stride1,
|
||||
self.padding1,
|
||||
self.dilation1,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
if skip_time_conv:
|
||||
return x
|
||||
|
||||
# Second convolution
|
||||
x = F.conv3d(
|
||||
x,
|
||||
self.weight2,
|
||||
self.bias2,
|
||||
self.stride2,
|
||||
self.padding2,
|
||||
self.dilation2,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
def forward_with_2d(self, x, skip_time_conv):
|
||||
b, c, d, h, w = x.shape
|
||||
|
||||
# First 2D convolution
|
||||
x = rearrange(x, "b c d h w -> (b d) c h w")
|
||||
# Squeeze the depth dimension out of weight1 since it's 1
|
||||
weight1 = self.weight1.squeeze(2)
|
||||
# Select stride, padding, and dilation for the 2D convolution
|
||||
stride1 = (self.stride1[1], self.stride1[2])
|
||||
padding1 = (self.padding1[1], self.padding1[2])
|
||||
dilation1 = (self.dilation1[1], self.dilation1[2])
|
||||
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
|
||||
|
||||
_, _, h, w = x.shape
|
||||
|
||||
if skip_time_conv:
|
||||
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
|
||||
return x
|
||||
|
||||
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
|
||||
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
|
||||
|
||||
# Reshape weight2 to match the expected dimensions for conv1d
|
||||
weight2 = self.weight2.squeeze(-1).squeeze(-1)
|
||||
# Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
|
||||
stride2 = self.stride2[0]
|
||||
padding2 = self.padding2[0]
|
||||
dilation2 = self.dilation2[0]
|
||||
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
|
||||
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
|
||||
|
||||
return x
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.weight2
|
||||
|
||||
|
||||
def test_dual_conv3d_consistency():
|
||||
# Initialize parameters
|
||||
in_channels = 3
|
||||
out_channels = 5
|
||||
kernel_size = (3, 3, 3)
|
||||
stride = (2, 2, 2)
|
||||
padding = (1, 1, 1)
|
||||
|
||||
# Create an instance of the DualConv3d class
|
||||
dual_conv3d = DualConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
# Example input tensor
|
||||
test_input = torch.randn(1, 3, 10, 10, 10)
|
||||
|
||||
# Perform forward passes with both 3D and 2D settings
|
||||
output_conv3d = dual_conv3d(test_input, use_conv3d=True)
|
||||
output_2d = dual_conv3d(test_input, use_conv3d=False)
|
||||
|
||||
# Assert that the outputs from both methods are sufficiently close
|
||||
assert torch.allclose(
|
||||
output_conv3d, output_2d, atol=1e-6
|
||||
), "Outputs are not consistent between 3D and 2D convolutions."
|
12
comfy/ldm/lightricks/vae/pixel_norm.py
Normal file
12
comfy/ldm/lightricks/vae/pixel_norm.py
Normal file
@ -0,0 +1,12 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
|
||||
def __init__(self, dim=1, eps=1e-8):
|
||||
super(PixelNorm, self).__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
|
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
|
@ -299,7 +299,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
if len(mask.shape) == 2:
|
||||
s1 += mask[i:end]
|
||||
else:
|
||||
s1 += mask[:, i:end]
|
||||
if mask.shape[1] == 1:
|
||||
s1 += mask
|
||||
else:
|
||||
s1 += mask[:, i:end]
|
||||
|
||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||
del s1
|
||||
@ -372,10 +375,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
pad = 8 - q.shape[1] % 8
|
||||
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
|
||||
mask_out[:, :, :mask.shape[-1]] = mask
|
||||
mask = mask_out[:, :, :mask.shape[-1]]
|
||||
pad = 8 - mask.shape[-1] % 8
|
||||
mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
|
||||
mask_out[..., :mask.shape[-1]] = mask
|
||||
mask = mask_out[..., :mask.shape[-1]]
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
|
||||
|
@ -1,5 +1,3 @@
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
import numpy as np
|
||||
|
@ -3,7 +3,6 @@ import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from typing import Optional, Any
|
||||
import logging
|
||||
|
||||
from comfy import model_management
|
||||
|
@ -9,12 +9,12 @@ import logging
|
||||
from .util import (
|
||||
checkpoint,
|
||||
avg_pool_nd,
|
||||
zero_module,
|
||||
timestep_embedding,
|
||||
AlphaBlender,
|
||||
)
|
||||
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
|
||||
from comfy.ldm.util import exists
|
||||
import comfy.patcher_extension
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
@ -47,6 +47,15 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||
found_patched = False
|
||||
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||
if isinstance(layer, class_type):
|
||||
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||
found_patched = True
|
||||
break
|
||||
if found_patched:
|
||||
continue
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@ -819,6 +828,13 @@ class UNetModel(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timesteps, context, y, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
|
@ -4,7 +4,6 @@ import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from .util import extract_into_tensor, make_beta_schedule
|
||||
from comfy.ldm.util import default
|
||||
|
||||
|
||||
class AbstractLowScaleModel(nn.Module):
|
||||
|
@ -8,7 +8,6 @@
|
||||
# thanks!
|
||||
|
||||
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -234,6 +234,8 @@ def efficient_dot_product_attention(
|
||||
def get_mask_chunk(chunk_idx: int) -> Tensor:
|
||||
if mask is None:
|
||||
return None
|
||||
if mask.shape[1] == 1:
|
||||
return mask
|
||||
chunk = min(query_chunk_size, q_tokens)
|
||||
return mask[:,chunk_idx:chunk_idx + chunk]
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
import functools
|
||||
from typing import Callable, Iterable, Union
|
||||
from typing import Iterable, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
@ -33,7 +33,7 @@ LORA_CLIP_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def load_lora(lora, to_load):
|
||||
def load_lora(lora, to_load, log_missing=True):
|
||||
patch_dict = {}
|
||||
loaded_keys = set()
|
||||
for x in to_load:
|
||||
@ -49,10 +49,20 @@ def load_lora(lora, to_load):
|
||||
dora_scale = lora[dora_scale_name]
|
||||
loaded_keys.add(dora_scale_name)
|
||||
|
||||
reshape_name = "{}.reshape_weight".format(x)
|
||||
reshape = None
|
||||
if reshape_name in lora.keys():
|
||||
try:
|
||||
reshape = lora[reshape_name].tolist()
|
||||
loaded_keys.add(reshape_name)
|
||||
except:
|
||||
pass
|
||||
|
||||
regular_lora = "{}.lora_up.weight".format(x)
|
||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||
diffusers3_lora = "{}.lora.up.weight".format(x)
|
||||
mochi_lora = "{}.lora_B".format(x)
|
||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||
A_name = None
|
||||
|
||||
@ -72,6 +82,10 @@ def load_lora(lora, to_load):
|
||||
A_name = diffusers3_lora
|
||||
B_name = "{}.lora.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif mochi_lora in lora.keys():
|
||||
A_name = mochi_lora
|
||||
B_name = "{}.lora_A".format(x)
|
||||
mid_name = None
|
||||
elif transformers_lora in lora.keys():
|
||||
A_name = transformers_lora
|
||||
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
||||
@ -82,7 +96,7 @@ def load_lora(lora, to_load):
|
||||
if mid_name is not None and mid_name in lora.keys():
|
||||
mid = lora[mid_name]
|
||||
loaded_keys.add(mid_name)
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
|
||||
@ -193,9 +207,16 @@ def load_lora(lora, to_load):
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
logging.warning("lora key not loaded: {}".format(x))
|
||||
set_weight_name = "{}.set_weight".format(x)
|
||||
set_weight = lora.get(set_weight_name, None)
|
||||
if set_weight is not None:
|
||||
patch_dict[to_load[x]] = ("set", (set_weight,))
|
||||
loaded_keys.add(set_weight_name)
|
||||
|
||||
if log_missing:
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
logging.warning("lora key not loaded: {}".format(x))
|
||||
|
||||
return patch_dict
|
||||
|
||||
@ -282,11 +303,14 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
sdk = sd.keys()
|
||||
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
if k.startswith("diffusion_model."):
|
||||
if k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
else:
|
||||
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
|
||||
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
|
||||
for k in diffusers_keys:
|
||||
@ -344,6 +368,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||
|
||||
if isinstance(model, comfy.model_base.GenmoMochi):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
@ -400,7 +430,7 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
|
||||
|
||||
return padded_tensor
|
||||
|
||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
|
||||
for p in patches:
|
||||
strength = p[0]
|
||||
v = p[1]
|
||||
@ -440,10 +470,22 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
|
||||
else:
|
||||
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
|
||||
elif patch_type == "set":
|
||||
weight.copy_(v[0])
|
||||
elif patch_type == "model_as_lora":
|
||||
target_weight: torch.Tensor = v[0]
|
||||
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
|
||||
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
|
||||
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
|
||||
elif patch_type == "lora": #lora/locon
|
||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
||||
dora_scale = v[4]
|
||||
reshape = v[5]
|
||||
|
||||
if reshape is not None:
|
||||
weight = pad_tensor_to_shape(weight, reshape)
|
||||
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / mat2.shape[0]
|
||||
else:
|
||||
|
17
comfy/lora_convert.py
Normal file
17
comfy/lora_convert.py
Normal file
@ -0,0 +1,17 @@
|
||||
import torch
|
||||
|
||||
|
||||
def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||
sd_out = {}
|
||||
for k in sd:
|
||||
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
|
||||
sd_out[k_to] = sd[k]
|
||||
|
||||
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
|
||||
return sd_out
|
||||
|
||||
|
||||
def convert_lora(sd):
|
||||
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
|
||||
return convert_lora_bfl_control(sd)
|
||||
return sd
|
@ -30,14 +30,19 @@ import comfy.ldm.hydit.models
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.ldm.flux.model
|
||||
import comfy.ldm.lightricks.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.conds
|
||||
import comfy.ops
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
import comfy.latent_formats
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
class ModelType(Enum):
|
||||
EPS = 1
|
||||
@ -94,6 +99,7 @@ class BaseModel(torch.nn.Module):
|
||||
self.model_config = model_config
|
||||
self.manual_cast_dtype = model_config.manual_cast_dtype
|
||||
self.device = device
|
||||
self.current_patcher: 'ModelPatcher' = None
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
@ -119,6 +125,13 @@ class BaseModel(torch.nn.Module):
|
||||
self.memory_usage_factor = model_config.memory_usage_factor
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._apply_model,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options)
|
||||
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||
|
||||
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
xc = self.model_sampling.calculate_input(sigma, x)
|
||||
if c_concat is not None:
|
||||
@ -153,8 +166,7 @@ class BaseModel(torch.nn.Module):
|
||||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
def concat_cond(self, **kwargs):
|
||||
if len(self.concat_keys) > 0:
|
||||
cond_concat = []
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
@ -193,7 +205,14 @@ class BaseModel(torch.nn.Module):
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
||||
return data
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
if concat_cond is not None:
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(concat_cond)
|
||||
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
@ -523,9 +542,7 @@ class SD_X4Upscaler(BaseModel):
|
||||
return out
|
||||
|
||||
class IP2P:
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
@ -537,18 +554,15 @@ class IP2P:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
return self.process_ip2p_image_in(image)
|
||||
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = comfy.conds.CONDRegular(adm)
|
||||
return out
|
||||
|
||||
class SD15_instructpix2pix(IP2P, BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.process_ip2p_image_in = lambda image: image
|
||||
|
||||
|
||||
class SDXL_instructpix2pix(IP2P, SDXL):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
@ -709,6 +723,44 @@ class Flux(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
try:
|
||||
#Handle Flux control loras dynamically changing the img_in weight.
|
||||
num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
|
||||
except:
|
||||
#Some cases like tensorrt might not have the weights accessible
|
||||
num_channels = self.model_config.unet_config["in_channels"]
|
||||
|
||||
out_channels = self.model_config.unet_config["out_channels"]
|
||||
|
||||
if num_channels <= out_channels:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros_like(noise)
|
||||
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
image = self.process_latent_in(image)
|
||||
if num_channels <= out_channels * 2:
|
||||
return image
|
||||
|
||||
#inpaint model
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.ones_like(noise)[:, :1]
|
||||
|
||||
mask = torch.mean(mask, dim=1, keepdim=True)
|
||||
print(mask.shape)
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
|
||||
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
@ -734,3 +786,27 @@ class GenmoMochi(BaseModel):
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class LTXV(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
guiding_latent = kwargs.get("guiding_latent", None)
|
||||
if guiding_latent is not None:
|
||||
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
|
||||
|
||||
guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
|
||||
if guiding_latent_noise_scale is not None:
|
||||
out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale)
|
||||
|
||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
return out
|
||||
|
@ -137,6 +137,12 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["in_channels"] = 16
|
||||
patch_size = 2
|
||||
dit_config["patch_size"] = patch_size
|
||||
in_key = "{}img_in.weight".format(key_prefix)
|
||||
if in_key in state_dict_keys:
|
||||
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["vec_in_dim"] = 768
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
@ -177,6 +183,10 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
return dit_config
|
||||
|
||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ltxv"
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
@ -23,6 +23,8 @@ from comfy.cli_args import args
|
||||
import torch
|
||||
import sys
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@ -287,11 +289,27 @@ def module_size(module):
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
self.weights_loaded = False
|
||||
self.real_model = None
|
||||
self.currently_used = True
|
||||
self.model_finalizer = None
|
||||
self._patcher_finalizer = None
|
||||
|
||||
def _set_model(self, model):
|
||||
self._model = weakref.ref(model)
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
|
||||
|
||||
def _switch_parent(self):
|
||||
model = self._parent_model()
|
||||
if model is not None:
|
||||
self._set_model(model)
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self._model()
|
||||
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
@ -306,32 +324,23 @@ class LoadedModel:
|
||||
return self.model_memory()
|
||||
|
||||
def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
|
||||
patch_model_to = self.device
|
||||
|
||||
self.model.model_patches_to(self.device)
|
||||
self.model.model_patches_to(self.model.model_dtype())
|
||||
|
||||
load_weights = not self.weights_loaded
|
||||
# if self.model.loaded_size() > 0:
|
||||
use_more_vram = lowvram_model_memory
|
||||
if use_more_vram == 0:
|
||||
use_more_vram = 1e32
|
||||
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
|
||||
real_model = self.model.model
|
||||
|
||||
if self.model.loaded_size() > 0:
|
||||
use_more_vram = lowvram_model_memory
|
||||
if use_more_vram == 0:
|
||||
use_more_vram = 1e32
|
||||
self.model_use_more_vram(use_more_vram)
|
||||
else:
|
||||
try:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
|
||||
except Exception as e:
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model_unload()
|
||||
raise e
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
|
||||
with torch.no_grad():
|
||||
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||
real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||
|
||||
self.weights_loaded = True
|
||||
return self.real_model
|
||||
self.real_model = weakref.ref(real_model)
|
||||
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
||||
return real_model
|
||||
|
||||
def should_reload_model(self, force_patch_weights=False):
|
||||
if force_patch_weights and self.model.lowvram_patch_counter() > 0:
|
||||
@ -344,18 +353,26 @@ class LoadedModel:
|
||||
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
||||
self.model.model_patches_to(self.model.offload_device)
|
||||
self.weights_loaded = self.weights_loaded and not unpatch_weights
|
||||
self.model.detach(unpatch_weights)
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
self.real_model = None
|
||||
return True
|
||||
|
||||
def model_use_more_vram(self, extra_memory):
|
||||
return self.model.partially_load(self.device, extra_memory)
|
||||
def model_use_more_vram(self, extra_memory, force_patch_weights=False):
|
||||
return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.model is other.model
|
||||
|
||||
def __del__(self):
|
||||
if self._patcher_finalizer is not None:
|
||||
self._patcher_finalizer.detach()
|
||||
|
||||
def is_dead(self):
|
||||
return self.real_model() is not None and self.model is None
|
||||
|
||||
|
||||
def use_more_memory(extra_memory, loaded_models, device):
|
||||
for m in loaded_models:
|
||||
if m.device == device:
|
||||
@ -386,38 +403,8 @@ def extra_reserved_memory():
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||
|
||||
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
||||
to_unload = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if model.is_clone(current_loaded_models[i].model):
|
||||
to_unload = [i] + to_unload
|
||||
|
||||
if len(to_unload) == 0:
|
||||
return True
|
||||
|
||||
same_weights = 0
|
||||
for i in to_unload:
|
||||
if model.clone_has_same_weights(current_loaded_models[i].model):
|
||||
same_weights += 1
|
||||
|
||||
if same_weights == len(to_unload):
|
||||
unload_weight = False
|
||||
else:
|
||||
unload_weight = True
|
||||
|
||||
if not force_unload:
|
||||
if unload_weights_only and unload_weight == False:
|
||||
return None
|
||||
else:
|
||||
unload_weight = True
|
||||
|
||||
for i in to_unload:
|
||||
logging.debug("unload clone {} {}".format(i, unload_weight))
|
||||
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
|
||||
|
||||
return unload_weight
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
cleanup_models_gc()
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
@ -425,7 +412,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model.device == device:
|
||||
if shift_model not in keep_loaded:
|
||||
if shift_model not in keep_loaded and not shift_model.is_dead():
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
@ -454,6 +441,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
return unloaded_models
|
||||
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
cleanup_models_gc()
|
||||
global vram_state
|
||||
|
||||
inference_memory = minimum_inference_memory()
|
||||
@ -466,11 +454,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
models = set(models)
|
||||
|
||||
models_to_load = []
|
||||
models_already_loaded = []
|
||||
|
||||
for x in models:
|
||||
loaded_model = LoadedModel(x)
|
||||
loaded = None
|
||||
|
||||
try:
|
||||
loaded_model_index = current_loaded_models.index(loaded_model)
|
||||
except:
|
||||
@ -478,51 +464,35 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
|
||||
if loaded_model_index is not None:
|
||||
loaded = current_loaded_models[loaded_model_index]
|
||||
if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
|
||||
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
|
||||
loaded = None
|
||||
else:
|
||||
loaded.currently_used = True
|
||||
models_already_loaded.append(loaded)
|
||||
|
||||
if loaded is None:
|
||||
loaded.currently_used = True
|
||||
models_to_load.append(loaded)
|
||||
else:
|
||||
if hasattr(x, "model"):
|
||||
logging.info(f"Requested to load {x.model.__class__.__name__}")
|
||||
models_to_load.append(loaded_model)
|
||||
|
||||
if len(models_to_load) == 0:
|
||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||
for d in devs:
|
||||
if d != torch.device("cpu"):
|
||||
free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
|
||||
free_mem = get_free_memory(d)
|
||||
if free_mem < minimum_memory_required:
|
||||
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
|
||||
models_to_load = free_memory(minimum_memory_required, d)
|
||||
logging.info("{} models unloaded.".format(len(models_to_load)))
|
||||
else:
|
||||
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
|
||||
if len(models_to_load) == 0:
|
||||
return
|
||||
|
||||
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
|
||||
for loaded_model in models_to_load:
|
||||
to_unload = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if loaded_model.model.is_clone(current_loaded_models[i].model):
|
||||
to_unload = [i] + to_unload
|
||||
for i in to_unload:
|
||||
current_loaded_models.pop(i).model.detach(unpatch_all=False)
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for loaded_model in models_already_loaded:
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
|
||||
if weights_unloaded is not None:
|
||||
loaded_model.weights_loaded = not weights_unloaded
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_mem = get_free_memory(device)
|
||||
if free_mem < minimum_memory_required:
|
||||
models_l = free_memory(minimum_memory_required, device)
|
||||
logging.info("{} models unloaded.".format(len(models_l)))
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
model = loaded_model.model
|
||||
@ -544,17 +514,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
|
||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
|
||||
|
||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||
for d in devs:
|
||||
if d != torch.device("cpu"):
|
||||
free_mem = get_free_memory(d)
|
||||
if free_mem > minimum_memory_required:
|
||||
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
|
||||
return
|
||||
|
||||
|
||||
def load_model_gpu(model):
|
||||
return load_models_gpu([model])
|
||||
|
||||
@ -568,21 +529,35 @@ def loaded_models(only_currently_used=False):
|
||||
output.append(m.model)
|
||||
return output
|
||||
|
||||
def cleanup_models(keep_clone_weights_loaded=False):
|
||||
|
||||
def cleanup_models_gc():
|
||||
do_gc = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
|
||||
do_gc = True
|
||||
break
|
||||
|
||||
if do_gc:
|
||||
gc.collect()
|
||||
soft_empty_cache()
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
||||
|
||||
|
||||
|
||||
def cleanup_models():
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
#TODO: very fragile function needs improvement
|
||||
num_refs = sys.getrefcount(current_loaded_models[i].model)
|
||||
if num_refs <= 2:
|
||||
if not keep_clone_weights_loaded:
|
||||
to_delete = [i] + to_delete
|
||||
#TODO: find a less fragile way to do this.
|
||||
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
|
||||
to_delete = [i] + to_delete
|
||||
if current_loaded_models[i].real_model() is None:
|
||||
to_delete = [i] + to_delete
|
||||
|
||||
for i in to_delete:
|
||||
x = current_loaded_models.pop(i)
|
||||
x.model_unload()
|
||||
del x
|
||||
|
||||
def dtype_size(dtype):
|
||||
@ -628,6 +603,10 @@ def maximum_vram_for_weights(device=None):
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if model_params < 0:
|
||||
model_params = 1000000000000000000000
|
||||
if args.fp32_unet:
|
||||
return torch.float32
|
||||
if args.fp64_unet:
|
||||
return torch.float64
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
if args.fp16_unet:
|
||||
@ -674,7 +653,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
||||
|
||||
# None means no manual cast
|
||||
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if weight_dtype == torch.float32:
|
||||
if weight_dtype == torch.float32 or weight_dtype == torch.float64:
|
||||
return None
|
||||
|
||||
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -243,7 +243,7 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
return 1.0
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
return 1.0 - percent
|
||||
return time_snr_shift(self.shift, 1.0 - percent)
|
||||
|
||||
class StableCascadeSampling(ModelSamplingDiscrete):
|
||||
def __init__(self, model_config=None):
|
||||
@ -336,4 +336,4 @@ class ModelSamplingFlux(torch.nn.Module):
|
||||
return 1.0
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
return 1.0 - percent
|
||||
return flux_time_shift(self.shift, 1.0, 1.0 - percent)
|
||||
|
@ -269,7 +269,7 @@ def fp8_linear(self, input):
|
||||
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||
inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype)
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
|
||||
|
156
comfy/patcher_extension.py
Normal file
156
comfy/patcher_extension.py
Normal file
@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
from typing import Callable
|
||||
|
||||
class CallbacksMP:
|
||||
ON_CLONE = "on_clone"
|
||||
ON_LOAD = "on_load_after"
|
||||
ON_DETACH = "on_detach_after"
|
||||
ON_CLEANUP = "on_cleanup"
|
||||
ON_PRE_RUN = "on_pre_run"
|
||||
ON_PREPARE_STATE = "on_prepare_state"
|
||||
ON_APPLY_HOOKS = "on_apply_hooks"
|
||||
ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
|
||||
ON_INJECT_MODEL = "on_inject_model"
|
||||
ON_EJECT_MODEL = "on_eject_model"
|
||||
|
||||
# callbacks dict is in the format:
|
||||
# {"call_type": {"key": [Callable1, Callable2, ...]} }
|
||||
@classmethod
|
||||
def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
|
||||
return {}
|
||||
|
||||
def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
|
||||
add_callback_with_key(call_type, None, callback, transformer_options, is_model_options)
|
||||
|
||||
def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.setdefault("transformer_options", {})
|
||||
callbacks: dict[str, dict[str, list]] = transformer_options.setdefault("callbacks", {})
|
||||
c = callbacks.setdefault(call_type, {}).setdefault(key, [])
|
||||
c.append(callback)
|
||||
|
||||
def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
c_list = []
|
||||
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
|
||||
c_list.extend(callbacks.get(call_type, {}).get(key, []))
|
||||
return c_list
|
||||
|
||||
def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
c_list = []
|
||||
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
|
||||
for c in callbacks.get(call_type, {}).values():
|
||||
c_list.extend(c)
|
||||
return c_list
|
||||
|
||||
class WrappersMP:
|
||||
OUTER_SAMPLE = "outer_sample"
|
||||
SAMPLER_SAMPLE = "sampler_sample"
|
||||
CALC_COND_BATCH = "calc_cond_batch"
|
||||
APPLY_MODEL = "apply_model"
|
||||
DIFFUSION_MODEL = "diffusion_model"
|
||||
|
||||
# wrappers dict is in the format:
|
||||
# {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
|
||||
@classmethod
|
||||
def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
|
||||
return {}
|
||||
|
||||
def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
|
||||
add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options)
|
||||
|
||||
def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.setdefault("transformer_options", {})
|
||||
wrappers: dict[str, dict[str, list]] = transformer_options.setdefault("wrappers", {})
|
||||
w = wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
||||
w.append(wrapper)
|
||||
|
||||
def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
w_list = []
|
||||
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
|
||||
w_list.extend(wrappers.get(wrapper_type, {}).get(key, []))
|
||||
return w_list
|
||||
|
||||
def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):
|
||||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
w_list = []
|
||||
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
|
||||
for w in wrappers.get(wrapper_type, {}).values():
|
||||
w_list.extend(w)
|
||||
return w_list
|
||||
|
||||
class WrapperExecutor:
|
||||
"""Handles call stack of wrappers around a function in an ordered manner."""
|
||||
def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
|
||||
# NOTE: class_obj exists so that wrappers surrounding a class method can access
|
||||
# the class instance at runtime via executor.class_obj
|
||||
self.original = original
|
||||
self.class_obj = class_obj
|
||||
self.wrappers = wrappers.copy()
|
||||
self.idx = idx
|
||||
self.is_last = idx == len(wrappers)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Calls the next wrapper or original function, whichever is appropriate."""
|
||||
new_executor = self._create_next_executor()
|
||||
return new_executor.execute(*args, **kwargs)
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
|
||||
args = list(args)
|
||||
kwargs = dict(kwargs)
|
||||
if self.is_last:
|
||||
return self.original(*args, **kwargs)
|
||||
return self.wrappers[self.idx](self, *args, **kwargs)
|
||||
|
||||
def _create_next_executor(self) -> 'WrapperExecutor':
|
||||
new_idx = self.idx + 1
|
||||
if new_idx > len(self.wrappers):
|
||||
raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
|
||||
if self.class_obj is None:
|
||||
return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
|
||||
return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)
|
||||
|
||||
@classmethod
|
||||
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
|
||||
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
|
||||
|
||||
@classmethod
|
||||
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
|
||||
return cls(original, class_obj, wrappers, idx=idx)
|
||||
|
||||
class PatcherInjection:
|
||||
def __init__(self, inject: Callable, eject: Callable):
|
||||
self.inject = inject
|
||||
self.eject = eject
|
||||
|
||||
def copy_nested_dicts(input_dict: dict):
|
||||
new_dict = input_dict.copy()
|
||||
for key, value in input_dict.items():
|
||||
if isinstance(value, dict):
|
||||
new_dict[key] = copy_nested_dicts(value)
|
||||
elif isinstance(value, list):
|
||||
new_dict[key] = value.copy()
|
||||
return new_dict
|
||||
|
||||
def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
|
||||
if copy_dict1:
|
||||
merged_dict = copy_nested_dicts(dict1)
|
||||
else:
|
||||
merged_dict = dict1
|
||||
for key, value in dict2.items():
|
||||
if isinstance(value, dict):
|
||||
curr_value = merged_dict.setdefault(key, {})
|
||||
merged_dict[key] = merge_nested_dicts(value, curr_value)
|
||||
elif isinstance(value, list):
|
||||
merged_dict.setdefault(key, []).extend(value)
|
||||
else:
|
||||
merged_dict[key] = value
|
||||
return merged_dict
|
@ -1,7 +1,15 @@
|
||||
import torch
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.utils
|
||||
import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
return comfy.utils.reshape_mask(noise_mask, shape).to(device)
|
||||
@ -10,9 +18,43 @@ def get_models_from_cond(cond, model_type):
|
||||
models = []
|
||||
for c in cond:
|
||||
if model_type in c:
|
||||
models += [c[model_type]]
|
||||
if isinstance(c[model_type], list):
|
||||
models += c[model_type]
|
||||
else:
|
||||
models += [c[model_type]]
|
||||
return models
|
||||
|
||||
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
|
||||
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
|
||||
cnets: list[ControlBase] = []
|
||||
for c in cond:
|
||||
if 'hooks' in c:
|
||||
for hook in c['hooks'].hooks:
|
||||
hook: comfy.hooks.Hook
|
||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
if 'control' in c:
|
||||
cnets.append(c['control'])
|
||||
|
||||
def get_extra_hooks_from_cnet(cnet: ControlBase, _list: list):
|
||||
if cnet.extra_hooks is not None:
|
||||
_list.append(cnet.extra_hooks)
|
||||
if cnet.previous_controlnet is None:
|
||||
return _list
|
||||
return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list)
|
||||
|
||||
hooks_list = []
|
||||
cnets = set(cnets)
|
||||
for base_cnet in cnets:
|
||||
get_extra_hooks_from_cnet(base_cnet, hooks_list)
|
||||
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
|
||||
if extra_hooks is not None:
|
||||
for hook in extra_hooks.hooks:
|
||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
|
||||
return hooks_dict
|
||||
|
||||
def convert_cond(cond):
|
||||
out = []
|
||||
for c in cond:
|
||||
@ -22,17 +64,22 @@ def convert_cond(cond):
|
||||
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
|
||||
temp["cross_attn"] = c[0]
|
||||
temp["model_conds"] = model_conds
|
||||
temp["uuid"] = uuid.uuid4()
|
||||
out.append(temp)
|
||||
return out
|
||||
|
||||
def get_additional_models(conds, dtype):
|
||||
"""loads additional models in conditioning"""
|
||||
cnets = []
|
||||
cnets: list[ControlBase] = []
|
||||
gligen = []
|
||||
add_models = []
|
||||
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
|
||||
|
||||
for k in conds:
|
||||
cnets += get_models_from_cond(conds[k], "control")
|
||||
gligen += get_models_from_cond(conds[k], "gligen")
|
||||
add_models += get_models_from_cond(conds[k], "additional_models")
|
||||
get_hooks_from_cond(conds[k], hooks)
|
||||
|
||||
control_nets = set(cnets)
|
||||
|
||||
@ -43,7 +90,9 @@ def get_additional_models(conds, dtype):
|
||||
inference_memory += m.inference_memory_requirements(dtype)
|
||||
|
||||
gligen = [x[1] for x in gligen]
|
||||
models = control_models + gligen
|
||||
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
|
||||
models = control_models + gligen + add_models + hook_models
|
||||
|
||||
return models, inference_memory
|
||||
|
||||
def cleanup_additional_models(models):
|
||||
@ -53,10 +102,11 @@ def cleanup_additional_models(models):
|
||||
m.cleanup()
|
||||
|
||||
|
||||
def prepare_sampling(model, noise_shape, conds):
|
||||
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
|
||||
device = model.load_device
|
||||
real_model = None
|
||||
real_model: 'BaseModel' = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
|
||||
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
|
||||
@ -72,3 +122,14 @@ def cleanup_models(conds, models):
|
||||
control_cleanup += get_models_from_cond(conds[k], "control")
|
||||
|
||||
cleanup_additional_models(set(control_cleanup))
|
||||
|
||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
# check for hooks in conds - if not registered, see if can be applied
|
||||
hooks = {}
|
||||
for k in conds:
|
||||
get_hooks_from_cond(conds[k], hooks)
|
||||
# add wrappers and callbacks from ModelPatcher to transformer_options
|
||||
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
|
||||
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
|
||||
# register hooks on model/model_options
|
||||
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
|
||||
|
@ -1,11 +1,21 @@
|
||||
from __future__ import annotations
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
import torch
|
||||
import collections
|
||||
from comfy import model_management
|
||||
import math
|
||||
import logging
|
||||
import comfy.samplers
|
||||
import comfy.sampler_helpers
|
||||
import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
@ -70,6 +80,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
for c in model_conds:
|
||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||
|
||||
hooks = conds.get('hooks', None)
|
||||
control = conds.get('control', None)
|
||||
|
||||
patches = None
|
||||
@ -85,8 +96,8 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
|
||||
patches['middle_patch'] = [gligen_patch]
|
||||
|
||||
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
|
||||
return cond_obj(input_x, mult, conditioning, area, control, patches)
|
||||
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks'])
|
||||
return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks)
|
||||
|
||||
def cond_equal_size(c1, c2):
|
||||
if c1 is c2:
|
||||
@ -138,110 +149,184 @@ def cond_cat(c_list):
|
||||
|
||||
return out
|
||||
|
||||
def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
||||
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
|
||||
# need to figure out remaining unmasked area for conds
|
||||
default_mults = []
|
||||
for _ in default_conds:
|
||||
default_mults.append(torch.ones_like(x_in))
|
||||
# look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond
|
||||
for lora_hooks, to_run in hooked_to_run.items():
|
||||
for cond_obj, i in to_run:
|
||||
# if no default_cond for cond_type, do nothing
|
||||
if len(default_conds[i]) == 0:
|
||||
continue
|
||||
area: list[int] = cond_obj.area
|
||||
if area is not None:
|
||||
curr_default_mult: torch.Tensor = default_mults[i]
|
||||
dims = len(area) // 2
|
||||
for i in range(dims):
|
||||
curr_default_mult = curr_default_mult.narrow(i + 2, area[i + dims], area[i])
|
||||
curr_default_mult -= cond_obj.mult
|
||||
else:
|
||||
default_mults[i] -= cond_obj.mult
|
||||
# for each default_mult, ReLU to make negatives=0, and then check for any nonzeros
|
||||
for i, mult in enumerate(default_mults):
|
||||
# if no default_cond for cond type, do nothing
|
||||
if len(default_conds[i]) == 0:
|
||||
continue
|
||||
torch.nn.functional.relu(mult, inplace=True)
|
||||
# if mult is all zeros, then don't add default_cond
|
||||
if torch.max(mult) == 0.0:
|
||||
continue
|
||||
|
||||
cond = default_conds[i]
|
||||
for x in cond:
|
||||
# do get_area_and_mult to get all the expected values
|
||||
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
# replace p's mult with calculated mult
|
||||
p = p._replace(mult=mult)
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
|
||||
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_calc_cond_batch,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
to_run = []
|
||||
# separate conds by matching hooks
|
||||
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||
default_conds = []
|
||||
has_default_conds = False
|
||||
|
||||
for i in range(len(conds)):
|
||||
out_conds.append(torch.zeros_like(x_in))
|
||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||
|
||||
cond = conds[i]
|
||||
default_c = []
|
||||
if cond is not None:
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if 'default' in x:
|
||||
default_c.append(x)
|
||||
has_default_conds = True
|
||||
continue
|
||||
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
default_conds.append(default_c)
|
||||
|
||||
to_run += [(p, i)]
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)
|
||||
|
||||
while len(to_run) > 0:
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]):
|
||||
to_batch_temp += [x]
|
||||
model.current_patcher.prepare_state(timestep)
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
while len(to_run) > 0:
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]):
|
||||
to_batch_temp += [x]
|
||||
|
||||
free_memory = model_management.get_free_memory(x_in.device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
area = []
|
||||
control = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = to_run.pop(x)
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
free_memory = model_management.get_free_memory(x_in.device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
uuids = []
|
||||
area = []
|
||||
control = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = to_run.pop(x)
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
uuids.append(p.uuid)
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
|
||||
if control is not None:
|
||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
|
||||
transformer_options = {}
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = model_options['transformer_options'].copy()
|
||||
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||
model_options['transformer_options'],
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
transformer_options["patches"] = cur_patches
|
||||
if patches is not None:
|
||||
# TODO: replace with merge_nested_dicts function
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
transformer_options["patches"] = cur_patches
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if control is not None:
|
||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
@ -261,7 +346,7 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
|
||||
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||
|
||||
for fn in model_options.get("sampler_post_cfg_function", []):
|
||||
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
||||
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "cond_scale": cond_scale, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
||||
"sigma": timestep, "model_options": model_options, "input": x}
|
||||
cfg_result = fn(args)
|
||||
|
||||
@ -500,10 +585,15 @@ def calculate_start_end_timesteps(model, conds):
|
||||
|
||||
timestep_start = None
|
||||
timestep_end = None
|
||||
if 'start_percent' in x:
|
||||
timestep_start = s.percent_to_sigma(x['start_percent'])
|
||||
if 'end_percent' in x:
|
||||
timestep_end = s.percent_to_sigma(x['end_percent'])
|
||||
# handle clip hook schedule, if needed
|
||||
if 'clip_start_percent' in x:
|
||||
timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0)))
|
||||
timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0)))
|
||||
else:
|
||||
if 'start_percent' in x:
|
||||
timestep_start = s.percent_to_sigma(x['start_percent'])
|
||||
if 'end_percent' in x:
|
||||
timestep_end = s.percent_to_sigma(x['end_percent'])
|
||||
|
||||
if (timestep_start is not None) or (timestep_end is not None):
|
||||
n = x.copy()
|
||||
@ -673,6 +763,12 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
||||
if k != kk:
|
||||
create_cond_with_same_area_if_none(conds[kk], c)
|
||||
|
||||
for k in conds:
|
||||
for c in conds[k]:
|
||||
if 'hooks' in c:
|
||||
for hook in c['hooks'].hooks:
|
||||
hook.initialize_timesteps(model)
|
||||
|
||||
for k in conds:
|
||||
pre_run_control(model, conds[k])
|
||||
|
||||
@ -685,9 +781,46 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
||||
|
||||
return conds
|
||||
|
||||
|
||||
def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
|
||||
# determine which ControlNets have extra_hooks that should be combined with normal hooks
|
||||
hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {}
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
if 'control' in kk:
|
||||
control: 'ControlBase' = kk['control']
|
||||
extra_hooks = control.get_extra_hooks()
|
||||
if len(extra_hooks) > 0:
|
||||
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
|
||||
to_replace = hook_replacement.setdefault((control, hooks), [])
|
||||
to_replace.append(kk)
|
||||
# if nothing to replace, do nothing
|
||||
if len(hook_replacement) == 0:
|
||||
return
|
||||
|
||||
# for optimal sampling performance, common ControlNets + hook combos should have identical hooks
|
||||
# on the cond dicts
|
||||
for key, conds_to_modify in hook_replacement.items():
|
||||
control = key[0]
|
||||
hooks = key[1]
|
||||
hooks = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks])
|
||||
# if combined hooks are not None, set as new hooks for all relevant conds
|
||||
if hooks is not None:
|
||||
for cond in conds_to_modify:
|
||||
cond['hooks'] = hooks
|
||||
|
||||
|
||||
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
||||
hooks_set = set()
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
hooks_set.add(kk.get('hooks', None))
|
||||
return len(hooks_set)
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher):
|
||||
self.model_patcher = model_patcher
|
||||
self.model_patcher: 'ModelPatcher' = model_patcher
|
||||
self.model_options = model_patcher.model_options
|
||||
self.original_conds = {}
|
||||
self.cfg = 1.0
|
||||
@ -714,19 +847,17 @@ class CFGGuider:
|
||||
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
||||
|
||||
extra_args = {"model_options": self.model_options, "seed":seed}
|
||||
extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
|
||||
|
||||
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
sampler.sample,
|
||||
sampler,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
|
||||
)
|
||||
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
if sigmas.shape[-1] == 0:
|
||||
return latent_image
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
|
||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
@ -737,14 +868,48 @@ class CFGGuider:
|
||||
latent_image = latent_image.to(device)
|
||||
sigmas = sigmas.to(device)
|
||||
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
del self.conds
|
||||
del self.loaded_models
|
||||
return output
|
||||
|
||||
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
if sigmas.shape[-1] == 0:
|
||||
return latent_image
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
preprocess_conds_hooks(self.conds)
|
||||
|
||||
try:
|
||||
orig_model_options = self.model_options
|
||||
self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||
# if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step)
|
||||
orig_hook_mode = self.model_patcher.hook_mode
|
||||
if get_total_hook_groups_in_conds(self.conds) <= 1:
|
||||
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self.outer_sample,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
|
||||
)
|
||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
self.model_options = orig_model_options
|
||||
self.model_patcher.hook_mode = orig_hook_mode
|
||||
self.model_patcher.restore_hook_patches()
|
||||
|
||||
del self.conds
|
||||
return output
|
||||
|
||||
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
cfg_guider = CFGGuider(model)
|
||||
|
103
comfy/sd.py
103
comfy/sd.py
@ -1,13 +1,16 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
from comfy import model_management
|
||||
from comfy.utils import ProgressBar
|
||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
from .ldm.cascade.stage_a import StageA
|
||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||
import comfy.ldm.genmo.vae.model
|
||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import yaml
|
||||
|
||||
import comfy.utils
|
||||
@ -27,12 +30,17 @@ import comfy.text_encoders.hydit
|
||||
import comfy.text_encoders.flux
|
||||
import comfy.text_encoders.long_clipl
|
||||
import comfy.text_encoders.genmo
|
||||
import comfy.text_encoders.lt
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
import comfy.lora_convert
|
||||
import comfy.hooks
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.taesd.taesd
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
@ -40,6 +48,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
if clip is not None:
|
||||
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
|
||||
lora = comfy.lora_convert.convert_lora(lora)
|
||||
loaded = comfy.lora.load_lora(lora, key_map)
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
@ -92,9 +101,13 @@ class CLIP:
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
self.patcher.is_clip = True
|
||||
self.apply_hooks_to_conds = None
|
||||
if params['device'] == load_device:
|
||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||
self.layer_idx = None
|
||||
self.use_clip_schedule = False
|
||||
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
|
||||
|
||||
def clone(self):
|
||||
@ -103,6 +116,8 @@ class CLIP:
|
||||
n.cond_stage_model = self.cond_stage_model
|
||||
n.tokenizer = self.tokenizer
|
||||
n.layer_idx = self.layer_idx
|
||||
n.use_clip_schedule = self.use_clip_schedule
|
||||
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||
return n
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
@ -114,6 +129,69 @@ class CLIP:
|
||||
def tokenize(self, text, return_word_ids=False):
|
||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
||||
|
||||
def add_hooks_to_dict(self, pooled_dict: dict[str]):
|
||||
if self.apply_hooks_to_conds:
|
||||
pooled_dict["hooks"] = self.apply_hooks_to_conds
|
||||
return pooled_dict
|
||||
|
||||
def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]={}, show_pbar=True):
|
||||
all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
|
||||
all_hooks = self.patcher.forced_hooks
|
||||
if all_hooks is None or not self.use_clip_schedule:
|
||||
# if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict
|
||||
return_pooled = "unprojected" if unprojected else True
|
||||
pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
|
||||
cond = pooled_dict.pop("cond")
|
||||
# add/update any keys with the provided add_dict
|
||||
pooled_dict.update(add_dict)
|
||||
all_cond_pooled.append([cond, pooled_dict])
|
||||
else:
|
||||
scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()
|
||||
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
if self.layer_idx is not None:
|
||||
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
|
||||
if unprojected:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
all_hooks.reset()
|
||||
self.patcher.patch_hooks(None)
|
||||
if show_pbar:
|
||||
pbar = ProgressBar(len(scheduled_keyframes))
|
||||
|
||||
for scheduled_opts in scheduled_keyframes:
|
||||
t_range = scheduled_opts[0]
|
||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||
if "start_percent" in add_dict:
|
||||
if t_range[1] < add_dict["start_percent"]:
|
||||
continue
|
||||
if "end_percent" in add_dict:
|
||||
if t_range[0] > add_dict["end_percent"]:
|
||||
continue
|
||||
hooks_keyframes = scheduled_opts[1]
|
||||
for hook, keyframe in hooks_keyframes:
|
||||
hook.hook_keyframe._current_keyframe = keyframe
|
||||
# apply appropriate hooks with values that match new hook_keyframe
|
||||
self.patcher.patch_hooks(all_hooks)
|
||||
# perform encoding as normal
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
pooled_dict = {"pooled_output": pooled}
|
||||
# add clip_start_percent and clip_end_percent in pooled
|
||||
pooled_dict["clip_start_percent"] = t_range[0]
|
||||
pooled_dict["clip_end_percent"] = t_range[1]
|
||||
# add/update any keys with the provided add_dict
|
||||
pooled_dict.update(add_dict)
|
||||
# add hooks stored on clip
|
||||
self.add_hooks_to_dict(pooled_dict)
|
||||
all_cond_pooled.append([cond, pooled_dict])
|
||||
if show_pbar:
|
||||
pbar.update(1)
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
all_hooks.reset()
|
||||
return all_cond_pooled
|
||||
|
||||
def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
@ -131,6 +209,7 @@ class CLIP:
|
||||
if len(o) > 2:
|
||||
for k in o[2]:
|
||||
out[k] = o[2][k]
|
||||
self.add_hooks_to_dict(out)
|
||||
return out
|
||||
|
||||
if return_pooled:
|
||||
@ -257,6 +336,14 @@ class VAE:
|
||||
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
||||
self.working_dtypes = [torch.float16, torch.float32]
|
||||
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
|
||||
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
|
||||
self.latent_channels = 128
|
||||
self.latent_dim = 3
|
||||
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@ -356,7 +443,9 @@ class VAE:
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
pixel_samples = self.decode_tiled_3d(samples_in)
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
@ -420,6 +509,12 @@ class VAE:
|
||||
def get_sd(self):
|
||||
return self.first_stage_model.state_dict()
|
||||
|
||||
def spacial_compression_decode(self):
|
||||
try:
|
||||
return self.upscale_ratio[-1]
|
||||
except:
|
||||
return self.upscale_ratio
|
||||
|
||||
class StyleModel:
|
||||
def __init__(self, model, device="cpu"):
|
||||
self.model = model
|
||||
@ -433,6 +528,8 @@ def load_style_model(ckpt_path):
|
||||
keys = model_data.keys()
|
||||
if "style_embedding" in keys:
|
||||
model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
|
||||
elif "redux_down.weight" in keys:
|
||||
model = comfy.ldm.flux.redux.ReduxImageEncoder()
|
||||
else:
|
||||
raise Exception("invalid style model {}".format(ckpt_path))
|
||||
model.load_state_dict(model_data)
|
||||
@ -446,6 +543,7 @@ class CLIPType(Enum):
|
||||
HUNYUAN_DIT = 5
|
||||
FLUX = 6
|
||||
MOCHI = 7
|
||||
LTXV = 8
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = []
|
||||
@ -524,6 +622,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.LTXV:
|
||||
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||
|
@ -90,8 +90,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
if textmodel_json_config is None:
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||
|
||||
with open(textmodel_json_config) as f:
|
||||
config = json.load(f)
|
||||
if isinstance(textmodel_json_config, dict):
|
||||
config = textmodel_json_config
|
||||
else:
|
||||
with open(textmodel_json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
scaled_fp8 = None
|
||||
@ -196,11 +199,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
attention_mask = None
|
||||
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
|
||||
attention_mask = torch.zeros_like(tokens)
|
||||
end_token = self.special_tokens.get("end", -1)
|
||||
end_token = self.special_tokens.get("end", None)
|
||||
if end_token is None:
|
||||
cmp_token = self.special_tokens.get("pad", -1)
|
||||
else:
|
||||
cmp_token = end_token
|
||||
|
||||
for x in range(attention_mask.shape[0]):
|
||||
for y in range(attention_mask.shape[1]):
|
||||
attention_mask[x, y] = 1
|
||||
if tokens[x, y] == end_token:
|
||||
if tokens[x, y] == cmp_token:
|
||||
if end_token is None:
|
||||
attention_mask[x, y] = 0
|
||||
break
|
||||
|
||||
attention_mask_model = None
|
||||
@ -411,22 +421,25 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
|
||||
self.max_length = max_length
|
||||
self.min_length = min_length
|
||||
self.end_token = None
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
if has_start_token:
|
||||
self.tokens_start = 1
|
||||
self.start_token = empty[0]
|
||||
self.end_token = empty[1]
|
||||
if has_end_token:
|
||||
self.end_token = empty[1]
|
||||
else:
|
||||
self.tokens_start = 0
|
||||
self.start_token = None
|
||||
self.end_token = empty[0]
|
||||
if has_end_token:
|
||||
self.end_token = empty[0]
|
||||
|
||||
if pad_token is not None:
|
||||
self.pad_token = pad_token
|
||||
@ -451,13 +464,16 @@ class SDTokenizer:
|
||||
Takes a potential embedding name and tries to retrieve it.
|
||||
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
||||
'''
|
||||
split_embed = embedding_name.split(' ')
|
||||
embedding_name = split_embed[0]
|
||||
leftover = ' '.join(split_embed[1:])
|
||||
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||
if embed is None:
|
||||
stripped = embedding_name.strip(',')
|
||||
if len(stripped) < len(embedding_name):
|
||||
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||
return (embed, embedding_name[len(stripped):])
|
||||
return (embed, "")
|
||||
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
|
||||
return (embed, leftover)
|
||||
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
@ -474,7 +490,12 @@ class SDTokenizer:
|
||||
#tokenize words
|
||||
tokens = []
|
||||
for weighted_segment, weight in parsed_weights:
|
||||
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
|
||||
to_tokenize = unescape_important(weighted_segment).replace("\n", " ")
|
||||
split = to_tokenize.split(' {}'.format(self.embedding_identifier))
|
||||
to_tokenize = [split[0]]
|
||||
for i in range(1, len(split)):
|
||||
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
|
||||
|
||||
to_tokenize = [x for x in to_tokenize if x != ""]
|
||||
for word in to_tokenize:
|
||||
#if we find an embedding, deal with the embedding
|
||||
@ -493,8 +514,11 @@ class SDTokenizer:
|
||||
word = leftover
|
||||
else:
|
||||
continue
|
||||
end = 999999999999
|
||||
if self.end_token is not None:
|
||||
end = -1
|
||||
#parse word
|
||||
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
|
||||
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])
|
||||
|
||||
#reshape token array to CLIP input size
|
||||
batched_tokens = []
|
||||
@ -505,18 +529,24 @@ class SDTokenizer:
|
||||
for i, t_group in enumerate(tokens):
|
||||
#determine if we're going to try and keep the tokens in a single batch
|
||||
is_large = len(t_group) >= self.max_word_length
|
||||
if self.end_token is not None:
|
||||
has_end_token = 1
|
||||
else:
|
||||
has_end_token = 0
|
||||
|
||||
while len(t_group) > 0:
|
||||
if len(t_group) + len(batch) > self.max_length - 1:
|
||||
remaining_length = self.max_length - len(batch) - 1
|
||||
if len(t_group) + len(batch) > self.max_length - has_end_token:
|
||||
remaining_length = self.max_length - len(batch) - has_end_token
|
||||
#break word in two and add end token
|
||||
if is_large:
|
||||
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
t_group = t_group[remaining_length:]
|
||||
#add end token and pad
|
||||
else:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
|
||||
#start new batch
|
||||
@ -529,7 +559,8 @@ class SDTokenizer:
|
||||
t_group = []
|
||||
|
||||
#fill last batch
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
if self.min_length is not None and len(batch) < self.min_length:
|
||||
|
@ -11,6 +11,7 @@ import comfy.text_encoders.aura_t5
|
||||
import comfy.text_encoders.hydit
|
||||
import comfy.text_encoders.flux
|
||||
import comfy.text_encoders.genmo
|
||||
import comfy.text_encoders.lt
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -658,6 +659,15 @@ class Flux(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
|
||||
|
||||
class FluxInpaint(Flux):
|
||||
unet_config = {
|
||||
"image_model": "flux",
|
||||
"guidance_embed": True,
|
||||
"in_channels": 96,
|
||||
}
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
class FluxSchnell(Flux):
|
||||
unet_config = {
|
||||
"image_model": "flux",
|
||||
@ -702,7 +712,34 @@ class GenmoMochi(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
|
||||
|
||||
class LTXV(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "ltxv",
|
||||
}
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi]
|
||||
sampling_settings = {
|
||||
"shift": 2.37,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.LTXV
|
||||
|
||||
memory_usage_factor = 2.7
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.LTXV(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
18
comfy/text_encoders/lt.py
Normal file
18
comfy/text_encoders/lt.py
Normal file
@ -0,0 +1,18 @@
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
from transformers import T5TokenizerFast
|
||||
import comfy.text_encoders.genmo
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
|
||||
|
||||
|
||||
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def ltxv_te(*args, **kwargs):
|
||||
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
|
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
class SPieceTokenizer:
|
||||
|
@ -209,6 +209,11 @@ class T5Stack(torch.nn.Module):
|
||||
intermediate = None
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
|
||||
past_bias = None
|
||||
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.block) + intermediate_output
|
||||
|
||||
for i, l in enumerate(self.block):
|
||||
x, past_bias = l(x, mask, past_bias, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
|
@ -46,7 +46,13 @@ def load_torch_file(ckpt, safe_load=False, device=None):
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
else:
|
||||
sd = pl_sd
|
||||
if len(pl_sd) == 1:
|
||||
key = list(pl_sd.keys())[0]
|
||||
sd = pl_sd[key]
|
||||
if not isinstance(sd, dict):
|
||||
sd = pl_sd
|
||||
else:
|
||||
sd = pl_sd
|
||||
return sd
|
||||
|
||||
def save_torch_file(sd, ckpt, metadata=None):
|
||||
|
39
comfy_execution/validation.py
Normal file
39
comfy_execution/validation.py
Normal file
@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def validate_node_input(
|
||||
received_type: str, input_type: str, strict: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
received_type and input_type are both strings of the form "T1,T2,...".
|
||||
|
||||
If strict is True, the input_type must contain the received_type.
|
||||
For example, if received_type is "STRING" and input_type is "STRING,INT",
|
||||
this will return True. But if received_type is "STRING,INT" and input_type is
|
||||
"INT", this will return False.
|
||||
|
||||
If strict is False, the input_type must have overlap with the received_type.
|
||||
For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT",
|
||||
this will return True.
|
||||
|
||||
Supports pre-union type extension behaviour of ``__ne__`` overrides.
|
||||
"""
|
||||
# If the types are exactly the same, we can return immediately
|
||||
# Use pre-union behaviour: inverse of `__ne__`
|
||||
if not received_type != input_type:
|
||||
return True
|
||||
|
||||
# Not equal, and not strings
|
||||
if not isinstance(received_type, str) or not isinstance(input_type, str):
|
||||
return False
|
||||
|
||||
# Split the type strings into sets for comparison
|
||||
received_types = set(t.strip() for t in received_type.split(","))
|
||||
input_types = set(t.strip() for t in input_type.split(","))
|
||||
|
||||
if strict:
|
||||
# In strict mode, all received types must be in the input types
|
||||
return received_types.issubset(input_types)
|
||||
else:
|
||||
# In non-strict mode, there must be at least one type in common
|
||||
return len(received_types.intersection(input_types)) > 0
|
@ -2,8 +2,7 @@ import comfy.samplers
|
||||
import comfy.utils
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm.auto import trange, tqdm
|
||||
import math
|
||||
from tqdm.auto import trange
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
@ -1,4 +1,3 @@
|
||||
import torch
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
class CLIPTextEncodeSDXLRefiner:
|
||||
@ -17,8 +16,7 @@ class CLIPTextEncodeSDXLRefiner:
|
||||
|
||||
def encode(self, clip, ascore, width, height, text):
|
||||
tokens = clip.tokenize(text)
|
||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||
return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], )
|
||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), )
|
||||
|
||||
class CLIPTextEncodeSDXL:
|
||||
@classmethod
|
||||
@ -47,8 +45,7 @@ class CLIPTextEncodeSDXL:
|
||||
tokens["l"] += empty["l"]
|
||||
while len(tokens["l"]) > len(tokens["g"]):
|
||||
tokens["g"] += empty["g"]
|
||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||
return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], )
|
||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,
|
||||
|
@ -1,4 +1,3 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import comfy.utils
|
||||
from enum import Enum
|
||||
|
@ -18,10 +18,7 @@ class CLIPTextEncodeFlux:
|
||||
tokens = clip.tokenize(clip_l)
|
||||
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
||||
|
||||
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
|
||||
cond = output.pop("cond")
|
||||
output["guidance"] = guidance
|
||||
return ([[cond, output]], )
|
||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )
|
||||
|
||||
class FluxGuidance:
|
||||
@classmethod
|
||||
|
744
comfy_extras/nodes_hooks.py
Normal file
744
comfy_extras/nodes_hooks.py
Normal file
@ -0,0 +1,744 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Union
|
||||
import torch
|
||||
from collections.abc import Iterable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.sd import CLIP
|
||||
|
||||
import comfy.hooks
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
|
||||
###########################################
|
||||
# Mask, Combine, and Hook Conditioning
|
||||
#------------------------------------------
|
||||
class PairConditioningSetProperties:
|
||||
NodeId = 'PairConditioningSetProperties'
|
||||
NodeName = 'Cond Pair Set Props'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"positive_NEW": ("CONDITIONING", ),
|
||||
"negative_NEW": ("CONDITIONING", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"set_cond_area": (["default", "mask bounds"],),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("MASK", ),
|
||||
"hooks": ("HOOKS",),
|
||||
"timesteps": ("TIMESTEPS_RANGE",),
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
CATEGORY = "advanced/hooks/cond pair"
|
||||
FUNCTION = "set_properties"
|
||||
|
||||
def set_properties(self, positive_NEW, negative_NEW,
|
||||
strength: float, set_cond_area: str,
|
||||
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
|
||||
final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_NEW, negative_NEW],
|
||||
strength=strength, set_cond_area=set_cond_area,
|
||||
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
||||
return (final_positive, final_negative)
|
||||
|
||||
class PairConditioningSetPropertiesAndCombine:
|
||||
NodeId = 'PairConditioningSetPropertiesAndCombine'
|
||||
NodeName = 'Cond Pair Set Props Combine'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"positive_NEW": ("CONDITIONING", ),
|
||||
"negative_NEW": ("CONDITIONING", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"set_cond_area": (["default", "mask bounds"],),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("MASK", ),
|
||||
"hooks": ("HOOKS",),
|
||||
"timesteps": ("TIMESTEPS_RANGE",),
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
CATEGORY = "advanced/hooks/cond pair"
|
||||
FUNCTION = "set_properties"
|
||||
|
||||
def set_properties(self, positive, negative, positive_NEW, negative_NEW,
|
||||
strength: float, set_cond_area: str,
|
||||
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
|
||||
final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
|
||||
strength=strength, set_cond_area=set_cond_area,
|
||||
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
||||
return (final_positive, final_negative)
|
||||
|
||||
class ConditioningSetProperties:
|
||||
NodeId = 'ConditioningSetProperties'
|
||||
NodeName = 'Cond Set Props'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"cond_NEW": ("CONDITIONING", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"set_cond_area": (["default", "mask bounds"],),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("MASK", ),
|
||||
"hooks": ("HOOKS",),
|
||||
"timesteps": ("TIMESTEPS_RANGE",),
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
CATEGORY = "advanced/hooks/cond single"
|
||||
FUNCTION = "set_properties"
|
||||
|
||||
def set_properties(self, cond_NEW,
|
||||
strength: float, set_cond_area: str,
|
||||
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
|
||||
(final_cond,) = comfy.hooks.set_conds_props(conds=[cond_NEW],
|
||||
strength=strength, set_cond_area=set_cond_area,
|
||||
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
||||
return (final_cond,)
|
||||
|
||||
class ConditioningSetPropertiesAndCombine:
|
||||
NodeId = 'ConditioningSetPropertiesAndCombine'
|
||||
NodeName = 'Cond Set Props Combine'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"cond": ("CONDITIONING", ),
|
||||
"cond_NEW": ("CONDITIONING", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"set_cond_area": (["default", "mask bounds"],),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("MASK", ),
|
||||
"hooks": ("HOOKS",),
|
||||
"timesteps": ("TIMESTEPS_RANGE",),
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
CATEGORY = "advanced/hooks/cond single"
|
||||
FUNCTION = "set_properties"
|
||||
|
||||
def set_properties(self, cond, cond_NEW,
|
||||
strength: float, set_cond_area: str,
|
||||
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
|
||||
(final_cond,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_NEW],
|
||||
strength=strength, set_cond_area=set_cond_area,
|
||||
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
||||
return (final_cond,)
|
||||
|
||||
class PairConditioningCombine:
|
||||
NodeId = 'PairConditioningCombine'
|
||||
NodeName = 'Cond Pair Combine'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"positive_A": ("CONDITIONING",),
|
||||
"negative_A": ("CONDITIONING",),
|
||||
"positive_B": ("CONDITIONING",),
|
||||
"negative_B": ("CONDITIONING",),
|
||||
},
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
CATEGORY = "advanced/hooks/cond pair"
|
||||
FUNCTION = "combine"
|
||||
|
||||
def combine(self, positive_A, negative_A, positive_B, negative_B):
|
||||
final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],)
|
||||
return (final_positive, final_negative,)
|
||||
|
||||
class PairConditioningSetDefaultAndCombine:
|
||||
NodeId = 'PairConditioningSetDefaultCombine'
|
||||
NodeName = 'Cond Pair Set Default Combine'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"positive_DEFAULT": ("CONDITIONING",),
|
||||
"negative_DEFAULT": ("CONDITIONING",),
|
||||
},
|
||||
"optional": {
|
||||
"hooks": ("HOOKS",),
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
CATEGORY = "advanced/hooks/cond pair"
|
||||
FUNCTION = "set_default_and_combine"
|
||||
|
||||
def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT,
|
||||
hooks: comfy.hooks.HookGroup=None):
|
||||
final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
|
||||
hooks=hooks)
|
||||
return (final_positive, final_negative)
|
||||
|
||||
class ConditioningSetDefaultAndCombine:
|
||||
NodeId = 'ConditioningSetDefaultCombine'
|
||||
NodeName = 'Cond Set Default Combine'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"cond": ("CONDITIONING",),
|
||||
"cond_DEFAULT": ("CONDITIONING",),
|
||||
},
|
||||
"optional": {
|
||||
"hooks": ("HOOKS",),
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
CATEGORY = "advanced/hooks/cond single"
|
||||
FUNCTION = "set_default_and_combine"
|
||||
|
||||
def set_default_and_combine(self, cond, cond_DEFAULT,
|
||||
hooks: comfy.hooks.HookGroup=None):
|
||||
(final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
|
||||
hooks=hooks)
|
||||
return (final_conditioning,)
|
||||
|
||||
class SetClipHooks:
|
||||
NodeId = 'SetClipHooks'
|
||||
NodeName = 'Set CLIP Hooks'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"clip": ("CLIP",),
|
||||
"apply_to_conds": ("BOOLEAN", {"default": True}),
|
||||
"schedule_clip": ("BOOLEAN", {"default": False})
|
||||
},
|
||||
"optional": {
|
||||
"hooks": ("HOOKS",)
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
CATEGORY = "advanced/hooks/clip"
|
||||
FUNCTION = "apply_hooks"
|
||||
|
||||
def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
||||
if hooks is not None:
|
||||
clip = clip.clone()
|
||||
if apply_to_conds:
|
||||
clip.apply_hooks_to_conds = hooks
|
||||
clip.patcher.forced_hooks = hooks.clone()
|
||||
clip.use_clip_schedule = schedule_clip
|
||||
if not clip.use_clip_schedule:
|
||||
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
|
||||
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
|
||||
return (clip,)
|
||||
|
||||
class ConditioningTimestepsRange:
|
||||
NodeId = 'ConditioningTimestepsRange'
|
||||
NodeName = 'Timesteps Range'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
},
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
|
||||
RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
|
||||
CATEGORY = "advanced/hooks"
|
||||
FUNCTION = "create_range"
|
||||
|
||||
def create_range(self, start_percent: float, end_percent: float):
|
||||
return ((start_percent, end_percent), (0.0, start_percent), (end_percent, 1.0))
|
||||
#------------------------------------------
|
||||
###########################################
|
||||
|
||||
|
||||
###########################################
|
||||
# Create Hooks
|
||||
#------------------------------------------
|
||||
class CreateHookLora:
|
||||
NodeId = 'CreateHookLora'
|
||||
NodeName = 'Create Hook LoRA'
|
||||
def __init__(self):
|
||||
self.loaded_lora = None
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
||||
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||
"strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {
|
||||
"prev_hooks": ("HOOKS",)
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/create"
|
||||
FUNCTION = "create_hook"
|
||||
|
||||
def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None):
|
||||
if prev_hooks is None:
|
||||
prev_hooks = comfy.hooks.HookGroup()
|
||||
prev_hooks.clone()
|
||||
|
||||
if strength_model == 0 and strength_clip == 0:
|
||||
return (prev_hooks,)
|
||||
|
||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||
lora = None
|
||||
if self.loaded_lora is not None:
|
||||
if self.loaded_lora[0] == lora_path:
|
||||
lora = self.loaded_lora[1]
|
||||
else:
|
||||
temp = self.loaded_lora
|
||||
self.loaded_lora = None
|
||||
del temp
|
||||
|
||||
if lora is None:
|
||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||
self.loaded_lora = (lora_path, lora)
|
||||
|
||||
hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=strength_model, strength_clip=strength_clip)
|
||||
return (prev_hooks.clone_and_combine(hooks),)
|
||||
|
||||
class CreateHookLoraModelOnly(CreateHookLora):
|
||||
NodeId = 'CreateHookLoraModelOnly'
|
||||
NodeName = 'Create Hook LoRA (MO)'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
||||
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {
|
||||
"prev_hooks": ("HOOKS",)
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/create"
|
||||
FUNCTION = "create_hook_model_only"
|
||||
|
||||
def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None):
|
||||
return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks)
|
||||
|
||||
class CreateHookModelAsLora:
|
||||
NodeId = 'CreateHookModelAsLora'
|
||||
NodeName = 'Create Hook Model as LoRA'
|
||||
|
||||
def __init__(self):
|
||||
# when not None, will be in following format:
|
||||
# (ckpt_path: str, weights_model: dict, weights_clip: dict)
|
||||
self.loaded_weights = None
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
||||
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||
"strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {
|
||||
"prev_hooks": ("HOOKS",)
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/create"
|
||||
FUNCTION = "create_hook"
|
||||
|
||||
def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float,
|
||||
prev_hooks: comfy.hooks.HookGroup=None):
|
||||
if prev_hooks is None:
|
||||
prev_hooks = comfy.hooks.HookGroup()
|
||||
prev_hooks.clone()
|
||||
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
weights_model = None
|
||||
weights_clip = None
|
||||
if self.loaded_weights is not None:
|
||||
if self.loaded_weights[0] == ckpt_path:
|
||||
weights_model = self.loaded_weights[1]
|
||||
weights_clip = self.loaded_weights[2]
|
||||
else:
|
||||
temp = self.loaded_weights
|
||||
self.loaded_weights = None
|
||||
del temp
|
||||
|
||||
if weights_model is None:
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
|
||||
weights_clip = comfy.hooks.get_patch_weights_from_model(out[1].patcher if out[1] else out[1])
|
||||
self.loaded_weights = (ckpt_path, weights_model, weights_clip)
|
||||
|
||||
hooks = comfy.hooks.create_hook_model_as_lora(weights_model=weights_model, weights_clip=weights_clip,
|
||||
strength_model=strength_model, strength_clip=strength_clip)
|
||||
return (prev_hooks.clone_and_combine(hooks),)
|
||||
|
||||
class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
|
||||
NodeId = 'CreateHookModelAsLoraModelOnly'
|
||||
NodeName = 'Create Hook Model as LoRA (MO)'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
||||
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {
|
||||
"prev_hooks": ("HOOKS",)
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/create"
|
||||
FUNCTION = "create_hook_model_only"
|
||||
|
||||
def create_hook_model_only(self, ckpt_name: str, strength_model: float,
|
||||
prev_hooks: comfy.hooks.HookGroup=None):
|
||||
return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks)
|
||||
#------------------------------------------
|
||||
###########################################
|
||||
|
||||
|
||||
###########################################
|
||||
# Schedule Hooks
|
||||
#------------------------------------------
|
||||
class SetHookKeyframes:
|
||||
NodeId = 'SetHookKeyframes'
|
||||
NodeName = 'Set Hook Keyframes'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"hooks": ("HOOKS",),
|
||||
},
|
||||
"optional": {
|
||||
"hook_kf": ("HOOK_KEYFRAMES",),
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/scheduling"
|
||||
FUNCTION = "set_hook_keyframes"
|
||||
|
||||
def set_hook_keyframes(self, hooks: comfy.hooks.HookGroup, hook_kf: comfy.hooks.HookKeyframeGroup=None):
|
||||
if hook_kf is not None:
|
||||
hooks = hooks.clone()
|
||||
hooks.set_keyframes_on_hooks(hook_kf=hook_kf)
|
||||
return (hooks,)
|
||||
|
||||
class CreateHookKeyframe:
|
||||
NodeId = 'CreateHookKeyframe'
|
||||
NodeName = 'Create Hook Keyframe'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"strength_mult": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
},
|
||||
"optional": {
|
||||
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
||||
RETURN_NAMES = ("HOOK_KF",)
|
||||
CATEGORY = "advanced/hooks/scheduling"
|
||||
FUNCTION = "create_hook_keyframe"
|
||||
|
||||
def create_hook_keyframe(self, strength_mult: float, start_percent: float, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
|
||||
if prev_hook_kf is None:
|
||||
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
|
||||
prev_hook_kf = prev_hook_kf.clone()
|
||||
keyframe = comfy.hooks.HookKeyframe(strength=strength_mult, start_percent=start_percent)
|
||||
prev_hook_kf.add(keyframe)
|
||||
return (prev_hook_kf,)
|
||||
|
||||
class CreateHookKeyframesInterpolated:
|
||||
NodeId = 'CreateHookKeyframesInterpolated'
|
||||
NodeName = 'Create Hook Keyframes Interp.'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"strength_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
||||
"strength_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
||||
"interpolation": (comfy.hooks.InterpolationMethod._LIST, ),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"keyframes_count": ("INT", {"default": 5, "min": 2, "max": 100, "step": 1}),
|
||||
"print_keyframes": ("BOOLEAN", {"default": False}),
|
||||
},
|
||||
"optional": {
|
||||
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
||||
},
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
||||
RETURN_NAMES = ("HOOK_KF",)
|
||||
CATEGORY = "advanced/hooks/scheduling"
|
||||
FUNCTION = "create_hook_keyframes"
|
||||
|
||||
def create_hook_keyframes(self, strength_start: float, strength_end: float, interpolation: str,
|
||||
start_percent: float, end_percent: float, keyframes_count: int,
|
||||
print_keyframes=False, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
|
||||
if prev_hook_kf is None:
|
||||
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
|
||||
prev_hook_kf = prev_hook_kf.clone()
|
||||
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=keyframes_count,
|
||||
method=comfy.hooks.InterpolationMethod.LINEAR)
|
||||
strengths = comfy.hooks.InterpolationMethod.get_weights(num_from=strength_start, num_to=strength_end, length=keyframes_count, method=interpolation)
|
||||
|
||||
is_first = True
|
||||
for percent, strength in zip(percents, strengths):
|
||||
guarantee_steps = 0
|
||||
if is_first:
|
||||
guarantee_steps = 1
|
||||
is_first = False
|
||||
prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
|
||||
if print_keyframes:
|
||||
print(f"Hook Keyframe - start_percent:{percent} = {strength}")
|
||||
return (prev_hook_kf,)
|
||||
|
||||
class CreateHookKeyframesFromFloats:
|
||||
NodeId = 'CreateHookKeyframesFromFloats'
|
||||
NodeName = 'Create Hook Keyframes From Floats'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"print_keyframes": ("BOOLEAN", {"default": False}),
|
||||
},
|
||||
"optional": {
|
||||
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
||||
RETURN_NAMES = ("HOOK_KF",)
|
||||
CATEGORY = "advanced/hooks/scheduling"
|
||||
FUNCTION = "create_hook_keyframes"
|
||||
|
||||
def create_hook_keyframes(self, floats_strength: Union[float, list[float]],
|
||||
start_percent: float, end_percent: float,
|
||||
prev_hook_kf: comfy.hooks.HookKeyframeGroup=None, print_keyframes=False):
|
||||
if prev_hook_kf is None:
|
||||
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
|
||||
prev_hook_kf = prev_hook_kf.clone()
|
||||
if type(floats_strength) in (float, int):
|
||||
floats_strength = [float(floats_strength)]
|
||||
elif isinstance(floats_strength, Iterable):
|
||||
pass
|
||||
else:
|
||||
raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
|
||||
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
|
||||
method=comfy.hooks.InterpolationMethod.LINEAR)
|
||||
|
||||
is_first = True
|
||||
for percent, strength in zip(percents, floats_strength):
|
||||
guarantee_steps = 0
|
||||
if is_first:
|
||||
guarantee_steps = 1
|
||||
is_first = False
|
||||
prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
|
||||
if print_keyframes:
|
||||
print(f"Hook Keyframe - start_percent:{percent} = {strength}")
|
||||
return (prev_hook_kf,)
|
||||
#------------------------------------------
|
||||
###########################################
|
||||
|
||||
|
||||
class SetModelHooksOnCond:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"conditioning": ("CONDITIONING",),
|
||||
"hooks": ("HOOKS",),
|
||||
},
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
CATEGORY = "advanced/hooks/manual"
|
||||
FUNCTION = "attach_hook"
|
||||
|
||||
def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup):
|
||||
return (comfy.hooks.set_hooks_for_conditioning(conditioning, hooks),)
|
||||
|
||||
|
||||
###########################################
|
||||
# Combine Hooks
|
||||
#------------------------------------------
|
||||
class CombineHooks:
|
||||
NodeId = 'CombineHooks2'
|
||||
NodeName = 'Combine Hooks [2]'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
},
|
||||
"optional": {
|
||||
"hooks_A": ("HOOKS",),
|
||||
"hooks_B": ("HOOKS",),
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/combine"
|
||||
FUNCTION = "combine_hooks"
|
||||
|
||||
def combine_hooks(self,
|
||||
hooks_A: comfy.hooks.HookGroup=None,
|
||||
hooks_B: comfy.hooks.HookGroup=None):
|
||||
candidates = [hooks_A, hooks_B]
|
||||
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
|
||||
|
||||
class CombineHooksFour:
|
||||
NodeId = 'CombineHooks4'
|
||||
NodeName = 'Combine Hooks [4]'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
},
|
||||
"optional": {
|
||||
"hooks_A": ("HOOKS",),
|
||||
"hooks_B": ("HOOKS",),
|
||||
"hooks_C": ("HOOKS",),
|
||||
"hooks_D": ("HOOKS",),
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/combine"
|
||||
FUNCTION = "combine_hooks"
|
||||
|
||||
def combine_hooks(self,
|
||||
hooks_A: comfy.hooks.HookGroup=None,
|
||||
hooks_B: comfy.hooks.HookGroup=None,
|
||||
hooks_C: comfy.hooks.HookGroup=None,
|
||||
hooks_D: comfy.hooks.HookGroup=None):
|
||||
candidates = [hooks_A, hooks_B, hooks_C, hooks_D]
|
||||
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
|
||||
|
||||
class CombineHooksEight:
|
||||
NodeId = 'CombineHooks8'
|
||||
NodeName = 'Combine Hooks [8]'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
},
|
||||
"optional": {
|
||||
"hooks_A": ("HOOKS",),
|
||||
"hooks_B": ("HOOKS",),
|
||||
"hooks_C": ("HOOKS",),
|
||||
"hooks_D": ("HOOKS",),
|
||||
"hooks_E": ("HOOKS",),
|
||||
"hooks_F": ("HOOKS",),
|
||||
"hooks_G": ("HOOKS",),
|
||||
"hooks_H": ("HOOKS",),
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOKS",)
|
||||
CATEGORY = "advanced/hooks/combine"
|
||||
FUNCTION = "combine_hooks"
|
||||
|
||||
def combine_hooks(self,
|
||||
hooks_A: comfy.hooks.HookGroup=None,
|
||||
hooks_B: comfy.hooks.HookGroup=None,
|
||||
hooks_C: comfy.hooks.HookGroup=None,
|
||||
hooks_D: comfy.hooks.HookGroup=None,
|
||||
hooks_E: comfy.hooks.HookGroup=None,
|
||||
hooks_F: comfy.hooks.HookGroup=None,
|
||||
hooks_G: comfy.hooks.HookGroup=None,
|
||||
hooks_H: comfy.hooks.HookGroup=None):
|
||||
candidates = [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H]
|
||||
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
|
||||
#------------------------------------------
|
||||
###########################################
|
||||
|
||||
node_list = [
|
||||
# Create
|
||||
CreateHookLora,
|
||||
CreateHookLoraModelOnly,
|
||||
CreateHookModelAsLora,
|
||||
CreateHookModelAsLoraModelOnly,
|
||||
# Scheduling
|
||||
SetHookKeyframes,
|
||||
CreateHookKeyframe,
|
||||
CreateHookKeyframesInterpolated,
|
||||
CreateHookKeyframesFromFloats,
|
||||
# Combine
|
||||
CombineHooks,
|
||||
CombineHooksFour,
|
||||
CombineHooksEight,
|
||||
# Attach
|
||||
ConditioningSetProperties,
|
||||
ConditioningSetPropertiesAndCombine,
|
||||
PairConditioningSetProperties,
|
||||
PairConditioningSetPropertiesAndCombine,
|
||||
ConditioningSetDefaultAndCombine,
|
||||
PairConditioningSetDefaultAndCombine,
|
||||
PairConditioningCombine,
|
||||
SetClipHooks,
|
||||
# Other
|
||||
ConditioningTimestepsRange,
|
||||
]
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
|
||||
for node in node_list:
|
||||
NODE_CLASS_MAPPINGS[node.NodeId] = node
|
||||
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
|
@ -15,9 +15,7 @@ class CLIPTextEncodeHunyuanDiT:
|
||||
tokens = clip.tokenize(bert)
|
||||
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
|
||||
|
||||
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
|
||||
cond = output.pop("cond")
|
||||
return ([[cond, output]], )
|
||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
184
comfy_extras/nodes_lt.py
Normal file
184
comfy_extras/nodes_lt.py
Normal file
@ -0,0 +1,184 @@
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.model_sampling
|
||||
import math
|
||||
|
||||
class EmptyLTXVLatentVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/video/ltxv"
|
||||
|
||||
def generate(self, width, height, length, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples": latent}, )
|
||||
|
||||
|
||||
class LTXVImgToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE",),
|
||||
"image": ("IMAGE",),
|
||||
"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 1.0, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
FUNCTION = "generate"
|
||||
|
||||
def generate(self, positive, negative, image, vae, width, height, length, batch_size, image_noise_scale):
|
||||
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
|
||||
|
||||
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
|
||||
latent[:, :, :t.shape[2]] = t
|
||||
return (positive, negative, {"samples": latent}, )
|
||||
|
||||
|
||||
class LTXVConditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
FUNCTION = "append"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def append(self, positive, negative, frame_rate):
|
||||
positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
|
||||
return (positive, negative)
|
||||
|
||||
|
||||
class ModelSamplingLTXV:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
},
|
||||
"optional": {"latent": ("LATENT",), }
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, max_shift, base_shift, latent=None):
|
||||
m = model.clone()
|
||||
|
||||
if latent is None:
|
||||
tokens = 4096
|
||||
else:
|
||||
tokens = math.prod(latent["samples"].shape[2:])
|
||||
|
||||
x1 = 1024
|
||||
x2 = 4096
|
||||
mm = (max_shift - base_shift) / (x2 - x1)
|
||||
b = base_shift - mm * x1
|
||||
shift = (tokens) * mm + b
|
||||
|
||||
sampling_base = comfy.model_sampling.ModelSamplingFlux
|
||||
sampling_type = comfy.model_sampling.CONST
|
||||
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
model_sampling.set_parameters(shift=shift)
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
|
||||
return (m, )
|
||||
|
||||
|
||||
class LTXVScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
"stretch": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
|
||||
}),
|
||||
"terminal": (
|
||||
"FLOAT",
|
||||
{
|
||||
"default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
|
||||
"tooltip": "The terminal value of the sigmas after stretching."
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {"latent": ("LATENT",), }
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
|
||||
if latent is None:
|
||||
tokens = 4096
|
||||
else:
|
||||
tokens = math.prod(latent["samples"].shape[2:])
|
||||
|
||||
sigmas = torch.linspace(1.0, 0.0, steps + 1)
|
||||
|
||||
x1 = 1024
|
||||
x2 = 4096
|
||||
mm = (max_shift - base_shift) / (x2 - x1)
|
||||
b = base_shift - mm * x1
|
||||
sigma_shift = (tokens) * mm + b
|
||||
|
||||
power = 1
|
||||
sigmas = torch.where(
|
||||
sigmas != 0,
|
||||
math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
|
||||
0,
|
||||
)
|
||||
|
||||
# Stretch sigmas so that its final value matches the given terminal value.
|
||||
if stretch:
|
||||
non_zero_mask = sigmas != 0
|
||||
non_zero_sigmas = sigmas[non_zero_mask]
|
||||
one_minus_z = 1.0 - non_zero_sigmas
|
||||
scale_factor = one_minus_z[-1] / (1.0 - terminal)
|
||||
stretched = 1.0 - (one_minus_z / scale_factor)
|
||||
sigmas[non_zero_mask] = stretched
|
||||
|
||||
return (sigmas,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
|
||||
"LTXVImgToVideo": LTXVImgToVideo,
|
||||
"ModelSamplingLTXV": ModelSamplingLTXV,
|
||||
"LTXVConditioning": LTXVConditioning,
|
||||
"LTXVScheduler": LTXVScheduler,
|
||||
}
|
41
comfy_extras/nodes_mahiro.py
Normal file
41
comfy_extras/nodes_mahiro.py
Normal file
@ -0,0 +1,41 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Mahiro:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL",),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("patched_model",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "_for_testing"
|
||||
DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt."
|
||||
def patch(self, model):
|
||||
m = model.clone()
|
||||
def mahiro_normd(args):
|
||||
scale: float = args['cond_scale']
|
||||
cond_p: torch.Tensor = args['cond_denoised']
|
||||
uncond_p: torch.Tensor = args['uncond_denoised']
|
||||
#naive leap
|
||||
leap = cond_p * scale
|
||||
#sim with uncond leap
|
||||
u_leap = uncond_p * scale
|
||||
cfg = args["denoised"]
|
||||
merge = (leap + cfg) / 2
|
||||
normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
|
||||
normm = torch.sqrt(merge.abs()) * merge.sign()
|
||||
sim = F.cosine_similarity(normu, normm).mean()
|
||||
simsc = 2 * (sim+1)
|
||||
wm = (simsc*cfg + (4-simsc)*leap) / 4
|
||||
return wm
|
||||
m.set_model_sampler_post_cfg_function(mahiro_normd)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"Mahiro": Mahiro
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
|
||||
}
|
@ -1,4 +1,3 @@
|
||||
import folder_paths
|
||||
import comfy.sd
|
||||
import comfy.model_sampling
|
||||
import comfy.latent_formats
|
||||
|
@ -1,4 +1,3 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
class PatchModelAddDownscale:
|
||||
|
@ -174,6 +174,28 @@ class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["patchify_proj."] = argument
|
||||
arg_dict["adaln_single."] = argument
|
||||
arg_dict["caption_projection."] = argument
|
||||
|
||||
for i in range(28):
|
||||
arg_dict["transformer_blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["scale_shift_table"] = argument
|
||||
arg_dict["proj_out."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSD1": ModelMergeSD1,
|
||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||
@ -183,4 +205,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeFlux1": ModelMergeFlux1,
|
||||
"ModelMergeSD35_Large": ModelMergeSD35_Large,
|
||||
"ModelMergeMochiPreview": ModelMergeMochiPreview,
|
||||
"ModelMergeLTXV": ModelMergeLTXV,
|
||||
}
|
||||
|
@ -82,8 +82,7 @@ class CLIPTextEncodeSD3:
|
||||
tokens["l"] += empty["l"]
|
||||
while len(tokens["l"]) > len(tokens["g"]):
|
||||
tokens["g"] += empty["g"]
|
||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||
return ([[cond, {"pooled_output": pooled}]], )
|
||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||
|
||||
|
||||
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
||||
|
@ -16,7 +16,8 @@ class SkipLayerGuidanceDiT:
|
||||
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
||||
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
||||
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "skip_guidance"
|
||||
@ -26,7 +27,7 @@ class SkipLayerGuidanceDiT:
|
||||
|
||||
CATEGORY = "advanced/guidance"
|
||||
|
||||
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers=""):
|
||||
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
|
||||
# check if layer is comma separated integers
|
||||
def skip(args, extra_args):
|
||||
return args
|
||||
@ -65,6 +66,11 @@ class SkipLayerGuidanceDiT:
|
||||
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
|
||||
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
|
||||
cfg_result = cfg_result + (cond_pred - slg) * scale
|
||||
if rescaling_scale != 0:
|
||||
factor = cond_pred.std() / cfg_result.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
cfg_result *= factor
|
||||
|
||||
return cfg_result
|
||||
|
||||
m = model.clone()
|
||||
|
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import logging
|
||||
from spandrel import ModelLoader, ImageModelDescriptor
|
||||
from comfy import model_management
|
||||
|
@ -1,7 +1,5 @@
|
||||
from PIL import Image, ImageOps
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import struct
|
||||
import comfy.utils
|
||||
import time
|
||||
|
||||
|
@ -16,7 +16,7 @@ import comfy.model_management
|
||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
||||
from comfy.cli_args import args
|
||||
from comfy_execution.validation import validate_node_input
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
SUCCESS = 0
|
||||
@ -480,7 +480,7 @@ class PromptExecutor:
|
||||
if self.caches.outputs.get(node_id) is not None:
|
||||
cached_nodes.append(node_id)
|
||||
|
||||
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
self.add_message("execution_cached",
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
@ -527,7 +527,6 @@ class PromptExecutor:
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
|
||||
|
||||
def validate_inputs(prompt, item, validated):
|
||||
unique_id = item
|
||||
if unique_id in validated:
|
||||
@ -589,8 +588,8 @@ def validate_inputs(prompt, item, validated):
|
||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||
received_type = r[val[1]]
|
||||
received_types[x] = received_type
|
||||
if 'input_types' not in validate_function_inputs and received_type != type_input:
|
||||
details = f"{x}, {received_type} != {type_input}"
|
||||
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
|
||||
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
|
||||
error = {
|
||||
"type": "return_type_mismatch",
|
||||
"message": "Return type mismatch between linked nodes",
|
||||
|
36
fix_torch.py
36
fix_torch.py
@ -5,20 +5,24 @@ import ctypes
|
||||
import logging
|
||||
|
||||
|
||||
torch_spec = importlib.util.find_spec("torch")
|
||||
for folder in torch_spec.submodule_search_locations:
|
||||
lib_folder = os.path.join(folder, "lib")
|
||||
test_file = os.path.join(lib_folder, "fbgemm.dll")
|
||||
dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
|
||||
if os.path.exists(dest):
|
||||
break
|
||||
|
||||
with open(test_file, 'rb') as f:
|
||||
contents = f.read()
|
||||
if b"libomp140.x86_64.dll" not in contents:
|
||||
def fix_pytorch_libomp():
|
||||
"""
|
||||
Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.
|
||||
"""
|
||||
torch_spec = importlib.util.find_spec("torch")
|
||||
for folder in torch_spec.submodule_search_locations:
|
||||
lib_folder = os.path.join(folder, "lib")
|
||||
test_file = os.path.join(lib_folder, "fbgemm.dll")
|
||||
dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
|
||||
if os.path.exists(dest):
|
||||
break
|
||||
try:
|
||||
mydll = ctypes.cdll.LoadLibrary(test_file)
|
||||
except FileNotFoundError as e:
|
||||
logging.warning("Detected pytorch version with libomp issue, patching.")
|
||||
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
|
||||
|
||||
with open(test_file, "rb") as f:
|
||||
contents = f.read()
|
||||
if b"libomp140.x86_64.dll" not in contents:
|
||||
break
|
||||
try:
|
||||
mydll = ctypes.cdll.LoadLibrary(test_file)
|
||||
except FileNotFoundError as e:
|
||||
logging.warning("Detected pytorch version with libomp issue, patching.")
|
||||
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
|
||||
|
@ -4,7 +4,7 @@ import os
|
||||
import time
|
||||
import mimetypes
|
||||
import logging
|
||||
from typing import Set, List, Dict, Tuple, Literal
|
||||
from typing import Literal
|
||||
from collections.abc import Collection
|
||||
|
||||
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||
@ -133,7 +133,7 @@ def get_directory_by_type(type_name: str) -> str | None:
|
||||
return get_input_directory()
|
||||
return None
|
||||
|
||||
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
|
||||
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
|
||||
"""
|
||||
Example:
|
||||
files = os.listdir(folder_paths.get_input_directory())
|
||||
|
@ -1,7 +1,5 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
import struct
|
||||
import numpy as np
|
||||
from comfy.cli_args import args, LatentPreviewMethod
|
||||
from comfy.taesd.taesd import TAESD
|
||||
import comfy.model_management
|
||||
|
9
main.py
9
main.py
@ -8,6 +8,11 @@ import time
|
||||
from comfy.cli_args import args
|
||||
from app.logger import setup_logger
|
||||
|
||||
if __name__ == "__main__":
|
||||
#NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes.
|
||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||
os.environ['DO_NOT_TRACK'] = '1'
|
||||
|
||||
|
||||
setup_logger(log_level=args.verbose)
|
||||
|
||||
@ -82,7 +87,8 @@ if __name__ == "__main__":
|
||||
|
||||
if args.windows_standalone_build:
|
||||
try:
|
||||
import fix_torch
|
||||
from fix_torch import fix_pytorch_libomp
|
||||
fix_pytorch_libomp()
|
||||
except:
|
||||
pass
|
||||
|
||||
@ -154,7 +160,6 @@ def prompt_worker(q, server):
|
||||
if need_gc:
|
||||
current_time = time.perf_counter()
|
||||
if (current_time - last_gc_collect) > gc_collect_interval:
|
||||
comfy.model_management.cleanup_models()
|
||||
gc.collect()
|
||||
comfy.model_management.soft_empty_cache()
|
||||
last_gc_collect = current_time
|
||||
|
@ -1,2 +0,0 @@
|
||||
# model_manager/__init__.py
|
||||
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename
|
@ -1,234 +0,0 @@
|
||||
#NOTE: This was an experiment and WILL BE REMOVED
|
||||
from __future__ import annotations
|
||||
import aiohttp
|
||||
import os
|
||||
import traceback
|
||||
import logging
|
||||
from folder_paths import folder_names_and_paths, get_folder_paths
|
||||
import re
|
||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||
from enum import Enum
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class DownloadStatusType(Enum):
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadModelStatus():
|
||||
status: str
|
||||
progress_percentage: float
|
||||
message: str
|
||||
already_existed: bool = False
|
||||
|
||||
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
|
||||
self.status = status.value # Store the string value of the Enum
|
||||
self.progress_percentage = progress_percentage
|
||||
self.message = message
|
||||
self.already_existed = already_existed
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"status": self.status,
|
||||
"progress_percentage": self.progress_percentage,
|
||||
"message": self.message,
|
||||
"already_existed": self.already_existed
|
||||
}
|
||||
|
||||
|
||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||
model_name: str,
|
||||
model_url: str,
|
||||
model_directory: str,
|
||||
folder_path: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||
"""
|
||||
Download a model file from a given URL into the models directory.
|
||||
|
||||
Args:
|
||||
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
||||
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
||||
model_name (str):
|
||||
The name of the model file to be downloaded. This will be the filename on disk.
|
||||
model_url (str):
|
||||
The URL from which to download the model.
|
||||
model_directory (str):
|
||||
The subdirectory within the main models directory where the model
|
||||
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
||||
An asynchronous function to call with progress updates.
|
||||
folder_path (str);
|
||||
Path to which model folder should be used as the root.
|
||||
|
||||
Returns:
|
||||
DownloadModelStatus: The result of the download operation.
|
||||
"""
|
||||
if not validate_filename(model_name):
|
||||
return DownloadModelStatus(
|
||||
DownloadStatusType.ERROR,
|
||||
0,
|
||||
"Invalid model name",
|
||||
False
|
||||
)
|
||||
|
||||
if not model_directory in folder_names_and_paths:
|
||||
return DownloadModelStatus(
|
||||
DownloadStatusType.ERROR,
|
||||
0,
|
||||
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
|
||||
False
|
||||
)
|
||||
|
||||
if not folder_path in get_folder_paths(model_directory):
|
||||
return DownloadModelStatus(
|
||||
DownloadStatusType.ERROR,
|
||||
0,
|
||||
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
|
||||
False
|
||||
)
|
||||
|
||||
file_path = create_model_path(model_name, folder_path)
|
||||
existing_file = await check_file_exists(file_path, model_name, progress_callback)
|
||||
if existing_file:
|
||||
return existing_file
|
||||
|
||||
try:
|
||||
logging.info(f"Downloading {model_name} from {model_url}")
|
||||
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
||||
await progress_callback(model_name, status)
|
||||
|
||||
response = await model_download_request(model_url)
|
||||
if response.status != 200:
|
||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||
logging.error(error_message)
|
||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||
await progress_callback(model_name, status)
|
||||
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||
|
||||
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in downloading model: {e}")
|
||||
return await handle_download_error(e, model_name, progress_callback)
|
||||
|
||||
|
||||
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
file_path = os.path.join(folder_path, model_name)
|
||||
|
||||
# Ensure the resulting path is still within the base directory
|
||||
abs_file_path = os.path.abspath(file_path)
|
||||
abs_base_dir = os.path.abspath(folder_path)
|
||||
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
||||
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
|
||||
|
||||
return file_path
|
||||
|
||||
|
||||
async def check_file_exists(file_path: str,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
|
||||
) -> Optional[DownloadModelStatus]:
|
||||
if os.path.exists(file_path):
|
||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
||||
await progress_callback(model_name, status)
|
||||
return status
|
||||
return None
|
||||
|
||||
|
||||
async def track_download_progress(response: aiohttp.ClientResponse,
|
||||
file_path: str,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
interval: float = 1.0) -> DownloadModelStatus:
|
||||
try:
|
||||
total_size = int(response.headers.get('Content-Length', 0))
|
||||
downloaded = 0
|
||||
last_update_time = time.time()
|
||||
|
||||
async def update_progress():
|
||||
nonlocal last_update_time
|
||||
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
||||
await progress_callback(model_name, status)
|
||||
last_update_time = time.time()
|
||||
|
||||
temp_file_path = file_path + '.tmp'
|
||||
with open(temp_file_path, 'wb') as f:
|
||||
chunk_iterator = response.content.iter_chunked(8192)
|
||||
while True:
|
||||
try:
|
||||
chunk = await chunk_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
|
||||
if time.time() - last_update_time >= interval:
|
||||
await update_progress()
|
||||
|
||||
os.rename(temp_file_path, file_path)
|
||||
|
||||
await update_progress()
|
||||
|
||||
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
||||
await progress_callback(model_name, status)
|
||||
|
||||
return status
|
||||
except Exception as e:
|
||||
logging.error(f"Error in track_download_progress: {e}")
|
||||
logging.error(traceback.format_exc())
|
||||
return await handle_download_error(e, model_name, progress_callback)
|
||||
|
||||
|
||||
async def handle_download_error(e: Exception,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Any]
|
||||
) -> DownloadModelStatus:
|
||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||
await progress_callback(model_name, status)
|
||||
return status
|
||||
|
||||
|
||||
def validate_filename(filename: str)-> bool:
|
||||
"""
|
||||
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
|
||||
|
||||
Args:
|
||||
filename (str): The filename to validate
|
||||
|
||||
Returns:
|
||||
bool: True if the filename is valid, False otherwise
|
||||
"""
|
||||
if not filename.lower().endswith(('.sft', '.safetensors')):
|
||||
return False
|
||||
|
||||
# Check if the filename is empty, None, or just whitespace
|
||||
if not filename or not filename.strip():
|
||||
return False
|
||||
|
||||
# Check for any directory traversal attempts or invalid characters
|
||||
if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
|
||||
return False
|
||||
|
||||
# Check if the filename starts with a dot (hidden file)
|
||||
if filename.startswith('.'):
|
||||
return False
|
||||
|
||||
# Use a whitelist of allowed characters
|
||||
if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename):
|
||||
return False
|
||||
|
||||
# Ensure the filename isn't too long
|
||||
if len(filename) > 255:
|
||||
return False
|
||||
|
||||
return True
|
54
nodes.py
54
nodes.py
@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
|
||||
import os
|
||||
@ -10,7 +11,7 @@ import time
|
||||
import random
|
||||
import logging
|
||||
|
||||
from PIL import Image, ImageOps, ImageSequence, ImageFile
|
||||
from PIL import Image, ImageOps, ImageSequence
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import numpy as np
|
||||
@ -24,6 +25,7 @@ import comfy.sample
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.controlnet
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
|
||||
import comfy.clip_vision
|
||||
|
||||
@ -44,16 +46,16 @@ def interrupt_processing(value=True):
|
||||
|
||||
MAX_RESOLUTION=16384
|
||||
|
||||
class CLIPTextEncode:
|
||||
class CLIPTextEncode(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"text": ("STRING", {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
|
||||
"clip": ("CLIP", {"tooltip": "The CLIP model used for encoding the text."})
|
||||
"text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
|
||||
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
RETURN_TYPES = (IO.CONDITIONING,)
|
||||
OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
@ -62,9 +64,8 @@ class CLIPTextEncode:
|
||||
|
||||
def encode(self, clip, text):
|
||||
tokens = clip.tokenize(text)
|
||||
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
|
||||
cond = output.pop("cond")
|
||||
return ([[cond, output]], )
|
||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||
|
||||
|
||||
class ConditioningCombine:
|
||||
@classmethod
|
||||
@ -301,7 +302,8 @@ class VAEDecodeTiled:
|
||||
def decode(self, vae, samples, tile_size, overlap=64):
|
||||
if tile_size < overlap * 4:
|
||||
overlap = tile_size // 4
|
||||
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
|
||||
compression = vae.spacial_compression_decode()
|
||||
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
|
||||
if len(images.shape) == 5: #Combine batches
|
||||
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
|
||||
return (images, )
|
||||
@ -382,6 +384,7 @@ class InpaintModelConditioning:
|
||||
"vae": ("VAE", ),
|
||||
"pixels": ("IMAGE", ),
|
||||
"mask": ("MASK", ),
|
||||
"noise_mask": ("BOOLEAN", {"default": True, "tooltip": "Add a noise mask to the latent so sampling will only happen within the mask. Might improve results or completely break things depending on the model."}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
|
||||
@ -390,7 +393,7 @@ class InpaintModelConditioning:
|
||||
|
||||
CATEGORY = "conditioning/inpaint"
|
||||
|
||||
def encode(self, positive, negative, pixels, vae, mask):
|
||||
def encode(self, positive, negative, pixels, vae, mask, noise_mask=True):
|
||||
x = (pixels.shape[1] // 8) * 8
|
||||
y = (pixels.shape[2] // 8) * 8
|
||||
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
|
||||
@ -414,7 +417,8 @@ class InpaintModelConditioning:
|
||||
out_latent = {}
|
||||
|
||||
out_latent["samples"] = orig_latent
|
||||
out_latent["noise_mask"] = mask
|
||||
if noise_mask:
|
||||
out_latent["noise_mask"] = mask
|
||||
|
||||
out = []
|
||||
for conditioning in [positive, negative]:
|
||||
@ -640,9 +644,7 @@ class LoraLoader:
|
||||
if self.loaded_lora[0] == lora_path:
|
||||
lora = self.loaded_lora[1]
|
||||
else:
|
||||
temp = self.loaded_lora
|
||||
self.loaded_lora = None
|
||||
del temp
|
||||
|
||||
if lora is None:
|
||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||
@ -895,7 +897,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
@ -913,6 +915,8 @@ class CLIPLoader:
|
||||
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
|
||||
elif type == "mochi":
|
||||
clip_type = comfy.sd.CLIPType.MOCHI
|
||||
elif type == "ltxv":
|
||||
clip_type = comfy.sd.CLIPType.LTXV
|
||||
else:
|
||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||
|
||||
@ -966,15 +970,19 @@ class CLIPVisionEncode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_vision": ("CLIP_VISION",),
|
||||
"image": ("IMAGE",)
|
||||
"image": ("IMAGE",),
|
||||
"crop": (["center", "none"],)
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def encode(self, clip_vision, image):
|
||||
output = clip_vision.encode_image(image)
|
||||
def encode(self, clip_vision, image, crop):
|
||||
crop_image = True
|
||||
if crop != "center":
|
||||
crop_image = False
|
||||
output = clip_vision.encode_image(image, crop=crop_image)
|
||||
return (output,)
|
||||
|
||||
class StyleModelLoader:
|
||||
@ -999,14 +1007,19 @@ class StyleModelApply:
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
"style_model": ("STYLE_MODEL", ),
|
||||
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
|
||||
"strength_type": (["multiply"], ),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "apply_stylemodel"
|
||||
|
||||
CATEGORY = "conditioning/style_model"
|
||||
|
||||
def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
|
||||
def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength, strength_type):
|
||||
cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
|
||||
if strength_type == "multiply":
|
||||
cond *= strength
|
||||
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
|
||||
@ -2134,6 +2147,9 @@ def init_builtin_extra_nodes():
|
||||
"nodes_torch_compile.py",
|
||||
"nodes_mochi.py",
|
||||
"nodes_slg.py",
|
||||
"nodes_mahiro.py",
|
||||
"nodes_lt.py",
|
||||
"nodes_hooks.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user