Merge branch 'master' into upgrade

This commit is contained in:
Muhammad Hassan 2025-03-18 20:27:35 +05:00 committed by GitHub
commit b1270e56c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
235 changed files with 9914 additions and 230179 deletions

View File

@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
pause

View File

@ -12,7 +12,7 @@ on:
description: 'CUDA version'
required: true
type: string
default: "124"
default: "126"
python_minor:
description: 'Python minor version'
required: true
@ -22,7 +22,7 @@ on:
description: 'Python patch version'
required: true
type: string
default: "7"
default: "9"
jobs:

View File

@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@ -28,4 +28,4 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements.txt

View File

@ -18,7 +18,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: '3.12'
- name: Install requirements
run: |
python -m pip install --upgrade pip

58
.github/workflows/update-version.yml vendored Normal file
View File

@ -0,0 +1,58 @@
name: Update Version File
on:
pull_request:
paths:
- "pyproject.toml"
branches:
- master
jobs:
update-version:
runs-on: ubuntu-latest
# Don't run on fork PRs
if: github.event.pull_request.head.repo.full_name == github.repository
permissions:
pull-requests: write
contents: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- name: Update comfyui_version.py
run: |
# Read version from pyproject.toml and update comfyui_version.py
python -c '
import tomllib
# Read version from pyproject.toml
with open("pyproject.toml", "rb") as f:
config = tomllib.load(f)
version = config["project"]["version"]
# Write version to comfyui_version.py
with open("comfyui_version.py", "w") as f:
f.write("# This file is automatically generated by the build process when version is\n")
f.write("# updated in pyproject.toml.\n")
f.write(f"__version__ = \"{version}\"\n")
'
- name: Commit changes
run: |
git config --local user.name "github-actions"
git config --local user.email "github-actions@github.com"
git fetch origin ${{ github.head_ref }}
git checkout -B ${{ github.head_ref }} origin/${{ github.head_ref }}
git add comfyui_version.py
git diff --quiet && git diff --staged --quiet || git commit -m "chore: Update comfyui_version.py to match pyproject.toml"
git push origin HEAD:${{ github.head_ref }}

View File

@ -17,7 +17,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "124"
default: "126"
python_minor:
description: 'python minor version'
@ -29,7 +29,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "7"
default: "9"
# push:
# branches:
# - master

View File

@ -7,19 +7,19 @@ on:
description: 'cuda version'
required: true
type: string
default: "124"
default: "128"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
default: "13"
python_patch:
description: 'python patch version'
required: true
type: string
default: "4"
default: "2"
# push:
# branches:
# - master
@ -34,7 +34,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
fetch-depth: 30
persist-credentials: false
- uses: actions/setup-python@v5
with:
@ -74,7 +74,7 @@ jobs:
pause" > ./update/update_comfyui_and_python_dependencies.bat
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
cd ComfyUI_windows_portable_nightly_pytorch

View File

@ -7,7 +7,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "124"
default: "126"
python_minor:
description: 'python minor version'
@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "7"
default: "9"
# push:
# branches:
# - master

View File

@ -11,13 +11,14 @@
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
# Frontend assets
/web/ @huchenlei @webfiltered @pythongosssss
# Extra nodes
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
# Node developers
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered

View File

@ -1,4 +1,4 @@
## Prerequisites For Docker
## Prerequisites For Docker
1. NVIDIA drivers installed on your host system.
2. Docker installed on your host system.
3. NVIDIA Container Toolkit (nvidia-docker2) OR WSL2 installed so that you can use GPU inside

View File

@ -1,9 +1,9 @@
from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
from folder_paths import folder_names_and_paths, get_directory_by_type
from api_server.services.terminal_service import TerminalService
import app.logger
import os
class InternalRoutes:
'''
@ -15,26 +15,10 @@ class InternalRoutes:
def __init__(self, prompt_server):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
self.file_service = FileService({
"models": models_dir,
"user": user_directory,
"output": output_directory
})
self.prompt_server = prompt_server
self.terminal_service = TerminalService(prompt_server)
def setup_routes(self):
@self.routes.get('/files')
async def list_files(request):
directory_key = request.query.get('directory', '')
try:
file_list = self.file_service.list_files(directory_key)
return web.json_response({"files": file_list})
except ValueError as e:
return web.json_response({"error": str(e)}, status=400)
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
@self.routes.get('/logs')
async def get_logs(request):
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
@ -67,6 +51,20 @@ class InternalRoutes:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)
@self.routes.get('/files/{directory_type}')
async def get_files(request: web.Request) -> web.Response:
directory_type = request.match_info['directory_type']
if directory_type not in ("output", "input", "temp"):
return web.json_response({"error": "Invalid directory type"}, status=400)
directory = get_directory_by_type(directory_type)
sorted_files = sorted(
(entry for entry in os.scandir(directory) if entry.is_file()),
key=lambda entry: -entry.stat().st_mtime
)
return web.json_response([entry.name for entry in sorted_files], status=200)
def get_app(self):
if self._app is None:
self._app = web.Application()

View File

@ -1,13 +0,0 @@
from typing import Dict, List, Optional
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
class FileService:
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
self.allowed_directories: Dict[str, str] = allowed_directories
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
def list_files(self, directory_key: str) -> List[FileSystemItem]:
if directory_key not in self.allowed_directories:
raise ValueError("Invalid directory key")
directory_path: str = self.allowed_directories[directory_key]
return self.file_system_ops.walk_directory(directory_path)

View File

@ -25,10 +25,10 @@ class TerminalService:
def update_size(self):
columns, lines = self.get_terminal_size()
changed = False
if columns != self.cols:
self.cols = columns
changed = True
changed = True
if lines != self.rows:
self.rows = lines
@ -48,9 +48,9 @@ class TerminalService:
def send_messages(self, entries):
if not len(entries) or not len(self.subscriptions):
return
new_size = self.update_size()
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
if client_id not in self.server.sockets:
# Automatically unsub if the socket has disconnected

View File

@ -39,4 +39,4 @@ class FileSystemOperations:
"path": relative_path,
"type": "directory"
})
return file_list
return file_list

View File

@ -1,6 +1,7 @@
import os
import json
from aiohttp import web
import logging
class AppSettings():
@ -11,8 +12,12 @@ class AppSettings():
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
if os.path.isfile(file):
with open(file) as f:
return json.load(f)
try:
with open(file) as f:
return json.load(f)
except:
logging.error(f"The user settings file is corrupted: {file}")
return {}
else:
return {}
@ -51,4 +56,4 @@ class AppSettings():
settings = self.get_settings(request)
settings[setting_id] = await request.json()
self.save_settings(request, settings)
return web.Response(status=200)
return web.Response(status=200)

134
app/custom_node_manager.py Normal file
View File

@ -0,0 +1,134 @@
from __future__ import annotations
import os
import folder_paths
import glob
from aiohttp import web
import json
import logging
from functools import lru_cache
from utils.json_util import merge_json_recursive
# Extra locale files to load into main.json
EXTRA_LOCALE_FILES = [
"nodeDefs.json",
"commands.json",
"settings.json",
]
def safe_load_json_file(file_path: str) -> dict:
if not os.path.exists(file_path):
return {}
try:
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except json.JSONDecodeError:
logging.error(f"Error loading {file_path}")
return {}
class CustomNodeManager:
@lru_cache(maxsize=1)
def build_translations(self):
"""Load all custom nodes translations during initialization. Translations are
expected to be loaded from `locales/` folder.
The folder structure is expected to be the following:
- custom_nodes/
- custom_node_1/
- locales/
- en/
- main.json
- commands.json
- settings.json
returned translations are expected to be in the following format:
{
"en": {
"nodeDefs": {...},
"commands": {...},
"settings": {...},
...{other main.json keys}
}
}
"""
translations = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
# Sort glob results for deterministic ordering
for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
locales_dir = os.path.join(custom_node_dir, "locales")
if not os.path.exists(locales_dir):
continue
for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
lang_code = os.path.basename(os.path.dirname(lang_dir))
if lang_code not in translations:
translations[lang_code] = {}
# Load main.json
main_file = os.path.join(lang_dir, "main.json")
node_translations = safe_load_json_file(main_file)
# Load extra locale files
for extra_file in EXTRA_LOCALE_FILES:
extra_file_path = os.path.join(lang_dir, extra_file)
key = extra_file.split(".")[0]
json_data = safe_load_json_file(extra_file_path)
if json_data:
node_translations[key] = json_data
if node_translations:
translations[lang_code] = merge_json_recursive(
translations[lang_code], node_translations
)
return translations
def add_routes(self, routes, webapp, loadedModules):
@routes.get("/workflow_templates")
async def get_workflow_templates(request):
"""Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
files = [
file
for folder in folder_paths.get_folder_paths("custom_nodes")
for file in glob.glob(
os.path.join(folder, "*/example_workflows/*.json")
)
]
workflow_templates_dict = (
{}
) # custom_nodes folder name -> example workflow names
for file in files:
custom_nodes_name = os.path.basename(
os.path.dirname(os.path.dirname(file))
)
workflow_name = os.path.splitext(os.path.basename(file))[0]
workflow_templates_dict.setdefault(custom_nodes_name, []).append(
workflow_name
)
return web.json_response(workflow_templates_dict)
# Serve workflow templates from custom nodes.
for module_name, module_dir in loadedModules:
workflows_dir = os.path.join(module_dir, "example_workflows")
if os.path.exists(workflows_dir):
webapp.add_routes(
[
web.static(
"/api/workflow_templates/" + module_name, workflows_dir
)
]
)
@routes.get("/i18n")
async def get_i18n(request):
"""Returns translations from all custom nodes' locales folders."""
return web.json_response(self.build_translations())

View File

@ -3,16 +3,51 @@ import argparse
import logging
import os
import re
import sys
import tempfile
import zipfile
import importlib
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict, Optional
from importlib.metadata import version
import requests
from typing_extensions import NotRequired
from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger
# The path to the requirements.txt file
req_path = Path(__file__).parents[1] / "requirements.txt"
def frontend_install_warning_message():
"""The warning message to display when the frontend version is not up to date."""
extra = ""
if sys.flags.no_user_site:
extra = "-s "
return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem"
def check_frontend_version():
"""Check if the frontend version is up to date."""
def parse_version(version: str) -> tuple[int, int, int]:
return tuple(map(int, version.split(".")))
try:
frontend_version_str = version("comfyui-frontend-package")
frontend_version = parse_version(frontend_version_str)
with open(req_path, "r", encoding="utf-8") as f:
required_frontend = parse_version(f.readline().split("=")[-1])
if frontend_version < required_frontend:
app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message()))
else:
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
except Exception as e:
logging.error(f"Failed to check frontend version: {e}")
REQUEST_TIMEOUT = 10 # seconds
@ -109,9 +144,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
def default_frontend_path(cls) -> str:
try:
import comfyui_frontend_package
return str(importlib.resources.files(comfyui_frontend_package) / "static")
except ImportError:
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
sys.exit(-1)
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
@ -148,7 +191,8 @@ class FrontendManager:
main error source might be request timeout or invalid URL.
"""
if version_string == DEFAULT_VERSION_STRING:
return cls.DEFAULT_FRONTEND_PATH
check_frontend_version()
return cls.default_frontend_path()
repo_owner, repo_name, version = cls.parse_version_string(version_string)
@ -201,4 +245,5 @@ class FrontendManager:
except Exception as e:
logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.")
return cls.DEFAULT_FRONTEND_PATH
check_frontend_version()
return cls.default_frontend_path()

View File

@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
logger.addHandler(stdout_handler)
logger.addHandler(stream_handler)
STARTUP_WARNINGS = []
def log_startup_warning(msg):
logging.warning(msg)
STARTUP_WARNINGS.append(msg)
def print_startup_warnings():
for s in STARTUP_WARNINGS:
logging.warning(s)
STARTUP_WARNINGS.clear()

View File

@ -177,7 +177,7 @@ class ModelFileManager:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):

View File

@ -1,7 +1,6 @@
import argparse
import enum
import os
from typing import Optional
import comfy.options
@ -43,10 +42,11 @@ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certific
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
@ -106,6 +106,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
@ -129,7 +130,12 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@ -160,13 +166,14 @@ parser.add_argument(
""",
)
def is_valid_directory(path: Optional[str]) -> Optional[str]:
"""Validate if the given path is a directory."""
if path is None:
return None
def is_valid_directory(path: str) -> str:
"""Validate if the given path is a directory, and check permissions."""
if not os.path.exists(path):
raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
if not os.path.isdir(path):
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
if not os.access(path, os.R_OK):
raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
return path
parser.add_argument(
@ -176,7 +183,9 @@ parser.add_argument(
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
if comfy.options.args_parsing:
args = parser.parse_args()
@ -188,3 +197,17 @@ if args.windows_standalone_build:
if args.disable_auto_launch:
args.auto_launch = False
if args.force_fp16:
args.fp16_unet = True
# '--fast' is not provided, use an empty set
if args.fast is None:
args.fast = set()
# '--fast' is provided with an empty list, enable all optimizations
elif args.fast == []:
args.fast = set(PerformanceFeature)
# '--fast' is provided with a list of performance features, use that list
else:
args.fast = set(args.fast)

View File

@ -97,14 +97,19 @@ class CLIPTextModel_(torch.nn.Module):
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
x = self.embeddings(input_tokens, dtype=dtype)
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
if embeds is not None:
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
else:
x = self.embeddings(input_tokens, dtype=dtype)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
@ -115,7 +120,10 @@ class CLIPTextModel_(torch.nn.Module):
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
if num_tokens is not None:
pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
else:
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
@ -203,6 +211,15 @@ class CLIPVision(torch.nn.Module):
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
class LlavaProjector(torch.nn.Module):
def __init__(self, in_dim, out_dim, dtype, device, operations):
super().__init__()
self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
def forward(self, x):
return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
@ -212,7 +229,16 @@ class CLIPVisionModelProjection(torch.nn.Module):
else:
self.visual_projection = lambda a: a
if "llava3" == config_dict.get("projector_type", None):
self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
else:
self.multi_modal_projector = None
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)
out = self.visual_projection(x[2])
return (x[0], x[1], out)
projected = None
if self.multi_modal_projector is not None:
projected = self.multi_modal_projector(x[1])
return (x[0], x[1], out, projected)

View File

@ -9,6 +9,7 @@ import comfy.model_patcher
import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
class Output:
def __getitem__(self, key):
@ -34,6 +35,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
IMAGE_ENCODERS = {
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
}
class ClipVisionModel():
def __init__(self, json_config):
with open(json_config) as f:
@ -42,10 +49,11 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
@ -65,6 +73,7 @@ class ClipVisionModel():
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
outputs["mm_projected"] = out[3]
return outputs
def convert_to_transformers(sd, prefix):
@ -104,9 +113,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
if "multi_modal_projector.linear_1.bias" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
elif "embeddings.patch_embeddings.projection.weight" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
else:
return None

View File

@ -0,0 +1,19 @@
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"image_size": 336,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-5,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"projector_type": "llava3",
"torch_dtype": "float32"
}

View File

@ -5,7 +5,7 @@ This module provides type hinting and concrete convenience types for node develo
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
```python
from comfy_types import IO, ComfyNodeABC, CheckLazyMixin
from comfy.comfy_types import IO, ComfyNodeABC, CheckLazyMixin
class ExampleNode(ComfyNodeABC):
@classmethod

View File

@ -1,6 +1,6 @@
import torch
from typing import Callable, Protocol, TypedDict, Optional, List
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin, FileLocator
class UnetApplyFunction(Protocol):
@ -42,4 +42,5 @@ __all__ = [
InputTypeDict.__name__,
ComfyNodeABC.__name__,
CheckLazyMixin.__name__,
FileLocator.__name__,
]

View File

@ -1,12 +1,12 @@
from comfy_types import IO, ComfyNodeABC, InputTypeDict
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
from inspect import cleandoc
class ExampleNode(ComfyNodeABC):
"""An example node that just adds 1 to an input integer.
* Requires an IDE configured with analysis paths etc to be worth looking at.
* Not intended for use in ComfyUI.
* Requires a modern IDE to provide any benefit (detail: an IDE configured with analysis paths etc).
* This node is intended as an example for developers only.
"""
DESCRIPTION = cleandoc(__doc__)

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from typing import Literal, TypedDict
from typing_extensions import NotRequired
from abc import ABC, abstractmethod
from enum import Enum
@ -26,6 +27,7 @@ class IO(StrEnum):
BOOLEAN = "BOOLEAN"
INT = "INT"
FLOAT = "FLOAT"
COMBO = "COMBO"
CONDITIONING = "CONDITIONING"
SAMPLER = "SAMPLER"
SIGMAS = "SIGMAS"
@ -67,12 +69,34 @@ class IO(StrEnum):
return not (b.issubset(a) or a.issubset(b))
class RemoteInputOptions(TypedDict):
route: str
"""The route to the remote source."""
refresh_button: bool
"""Specifies whether to show a refresh button in the UI below the widget."""
control_after_refresh: Literal["first", "last"]
"""Specifies the control after the refresh button is clicked. If "first", the first item will be automatically selected, and so on."""
timeout: int
"""The maximum amount of time to wait for a response from the remote source in milliseconds."""
max_retries: int
"""The maximum number of retries before aborting the request."""
refresh: int
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
class MultiSelectOptions(TypedDict):
placeholder: NotRequired[str]
"""The placeholder text to display in the multi-select widget when no items are selected."""
chip: NotRequired[bool]
"""Specifies whether to use chips instead of comma separated values for the multi-select widget."""
class InputTypeOptions(TypedDict):
"""Provides type hinting for the return type of the INPUT_TYPES node function.
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
"""
default: bool | str | float | int | list | tuple
@ -101,7 +125,7 @@ class InputTypeOptions(TypedDict):
# default: bool
label_on: str
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
label_on: str
label_off: str
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
# class InputTypeString(InputTypeOptions):
# default: str
@ -113,6 +137,29 @@ class InputTypeOptions(TypedDict):
# defaultVal: str
dynamicPrompts: bool
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
# class InputTypeCombo(InputTypeOptions):
image_upload: bool
"""Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
image_folder: Literal["input", "output", "temp"]
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
"""
remote: RemoteInputOptions
"""Specifies the configuration for a remote input.
Available after ComfyUI frontend v1.9.7
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
control_after_generate: bool
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
options: NotRequired[list[str | int | float]]
"""COMBO type only. Specifies the selectable options for the combo widget.
Prefer:
["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
Over:
[["Option 1", "Option 2", "Option 3"]]
"""
multi_select: NotRequired[MultiSelectOptions]
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
class HiddenInputTypeDict(TypedDict):
@ -133,7 +180,7 @@ class HiddenInputTypeDict(TypedDict):
class InputTypeDict(TypedDict):
"""Provides type hinting for node INPUT_TYPES.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
"""
required: dict[str, tuple[IO, InputTypeOptions]]
@ -143,14 +190,14 @@ class InputTypeDict(TypedDict):
hidden: HiddenInputTypeDict
"""Offers advanced functionality and server-client communication.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
"""
class ComfyNodeABC(ABC):
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview
"""
DESCRIPTION: str
@ -167,7 +214,7 @@ class ComfyNodeABC(ABC):
CATEGORY: str
"""The category of the node, as per the "Add Node" menu.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#category
"""
EXPERIMENTAL: bool
"""Flags a node as experimental, informing users that it may change or not work as expected."""
@ -181,9 +228,9 @@ class ComfyNodeABC(ABC):
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
* The ``optional`` key can be added to describe inputs which do not need to be connected.
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#input-types
"""
return {"required": {}}
@ -198,7 +245,7 @@ class ComfyNodeABC(ABC):
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
"""
INPUT_IS_LIST: bool
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
@ -209,7 +256,7 @@ class ComfyNodeABC(ABC):
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
"""
OUTPUT_IS_LIST: tuple[bool]
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
@ -227,7 +274,7 @@ class ComfyNodeABC(ABC):
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
specifying which outputs which should be so treated.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
"""
RETURN_TYPES: tuple[IO]
@ -237,19 +284,19 @@ class ComfyNodeABC(ABC):
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
"""
RETURN_NAMES: tuple[str]
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
"""
OUTPUT_TOOLTIPS: tuple[str]
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
FUNCTION: str
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#function
"""
@ -267,8 +314,19 @@ class CheckLazyMixin:
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status
"""
need = [name for name in kwargs if kwargs[name] is None]
return need
class FileLocator(TypedDict):
"""Provides type hinting for the file location"""
filename: str
"""The filename of the file."""
subfolder: str
"""The subfolder of the file."""
type: Literal["input", "output", "temp"]
"""The root folder of the file."""

View File

@ -3,9 +3,6 @@ import math
import comfy.utils
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
class CONDRegular:
def __init__(self, cond):
self.cond = cond
@ -46,7 +43,7 @@ class CONDCrossAttn(CONDRegular):
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
mult_min = lcm(s1[1], s2[1])
mult_min = math.lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
@ -57,7 +54,7 @@ class CONDCrossAttn(CONDRegular):
crossattn_max_len = self.cond.shape[1]
for x in others:
c = x.cond
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
crossattn_max_len = math.lcm(crossattn_max_len, c.shape[1])
conds.append(c)
out = []

View File

@ -418,10 +418,7 @@ def controlnet_config(sd, model_options={}):
weight_dtype = comfy.utils.weight_dtype(sd)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
@ -689,10 +686,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
if supported_inference_dtypes is None:
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
load_device = comfy.model_management.get_torch_device()

View File

@ -4,105 +4,6 @@ import logging
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
# =================#
# UNet Conversion #
# =================#
unet_conversion_map = [
# (stable-diffusion, HF Diffusers)
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"),
("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"),
("out.2.bias", "conv_out.bias"),
]
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0", "norm1"),
("in_layers.2", "conv1"),
("out_layers.0", "norm2"),
("out_layers.3", "conv2"),
("emb_layers.1", "time_emb_proj"),
("skip_connection", "conv_shortcut"),
]
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
def convert_unet_state_dict(unet_state_dict):
# buyer beware: this is a *brittle* function,
# and correct output requires that all of these pieces interact in
# the exact order in which I have arranged them.
mapping = {k: k for k in unet_state_dict.keys()}
for sd_name, hf_name in unet_conversion_map:
mapping[hf_name] = sd_name
for k, v in mapping.items():
if "resnets" in k:
for sd_part, hf_part in unet_conversion_map_resnet:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
for sd_part, hf_part in unet_conversion_map_layer:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
# ================#
# VAE Conversion #
# ================#
@ -213,6 +114,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
def cat_tensors(tensors):
x = 0
@ -229,6 +131,7 @@ def cat_tensors(tensors):
return out
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
new_state_dict = {}
capture_qkv_weight = {}
@ -284,5 +187,3 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
def convert_text_enc_state_dict(text_enc_dict):
return text_enc_dict

View File

@ -226,7 +226,7 @@ def model_wrapper(
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
``
The input `classifier_fn` has the following format:
``
@ -240,7 +240,7 @@ def model_wrapper(
The input `model` has the following format:
``
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
``
``
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
@ -254,7 +254,7 @@ def model_wrapper(
``
def model_fn(x, t_continuous) -> noise:
t_input = get_model_input_time(t_continuous)
return noise_pred(model, x, t_input, **model_kwargs)
return noise_pred(model, x, t_input, **model_kwargs)
``
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
@ -359,7 +359,7 @@ class UniPC:
max_val=1.,
variant='bh1',
):
"""Construct a UniPC.
"""Construct a UniPC.
We support both data_prediction and noise_prediction.
"""
@ -372,7 +372,7 @@ class UniPC:
def dynamic_thresholding_fn(self, x0, t=None):
"""
The dynamic thresholding method.
The dynamic thresholding method.
"""
dims = x0.dim()
p = self.dynamic_thresholding_ratio
@ -404,7 +404,7 @@ class UniPC:
def model_fn(self, x, t):
"""
Convert the model to the noise prediction model or the data prediction model.
Convert the model to the noise prediction model or the data prediction model.
"""
if self.predict_x0:
return self.data_prediction_fn(x, t)
@ -461,7 +461,7 @@ class UniPC:
def denoise_to_zero_fn(self, x, s):
"""
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
"""
return self.data_prediction_fn(x, s)
@ -510,7 +510,7 @@ class UniPC:
col = torch.ones_like(rks)
for k in range(1, K + 1):
C.append(col)
col = col * rks / (k + 1)
col = col * rks / (k + 1)
C = torch.stack(C, dim=1)
if len(D1s) > 0:
@ -626,7 +626,7 @@ class UniPC:
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= (i + 1)
h_phi_k = h_phi_k / hh - 1 / factorial_i
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=x.device)
@ -661,7 +661,7 @@ class UniPC:
if x_t is None:
if use_predictor:
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
@ -669,7 +669,7 @@ class UniPC:
if use_corrector:
model_t = self.model_fn(x_t, t)
if D1s is not None:
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = (model_t - model_prev_0)

View File

@ -16,130 +16,171 @@ import comfy.model_management
import comfy.patcher_extension
from node_helpers import conditioning_set_values
# #######################################################################################################
# Hooks explanation
# -------------------
# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
# make explicit special cases like it does for ControlNet and GLIGEN.
#
# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
# that should run special code when a 'marked' cond is used in sampling.
# #######################################################################################################
class EnumHookMode(enum.Enum):
'''
Priority of hook memory optimization vs. speed, mostly related to WeightHooks.
MinVram: No caching will occur for any operations related to hooks.
MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
'''
MinVram = "minvram"
MaxSpeed = "maxspeed"
class EnumHookType(enum.Enum):
'''
Hook types, each of which has different expected behavior.
'''
Weight = "weight"
Patch = "patch"
ObjectPatch = "object_patch"
AddModels = "add_models"
Callbacks = "callbacks"
Wrappers = "wrappers"
SetInjections = "add_injections"
AdditionalModels = "add_models"
TransformerOptions = "transformer_options"
Injections = "add_injections"
class EnumWeightTarget(enum.Enum):
Model = "model"
Clip = "clip"
class EnumHookScope(enum.Enum):
'''
Determines if hook should be limited in its influence over sampling.
AllConditioning: hook will affect all conds used in sampling.
HookedOnly: hook will only affect the conds it was attached to.
'''
AllConditioning = "all_conditioning"
HookedOnly = "hooked_only"
class _HookRef:
pass
# NOTE: this is an example of how the should_register function should look
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
'''Example for how custom_should_register function can look like.'''
return True
def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
'''Creates base dictionary for use with Hooks' target param.'''
d = {}
if target is not None:
d['target'] = target
d.update(kwargs)
return d
class Hook:
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
hook_keyframe: 'HookKeyframeGroup'=None):
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
self.hook_type = hook_type
'''Enum identifying the general class of this hook.'''
self.hook_ref = hook_ref if hook_ref else _HookRef()
'''Reference shared between hook clones that have the same value. Should NOT be modified.'''
self.hook_id = hook_id
'''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
'''Keyframe storage that can be referenced to get strength for current sampling step.'''
self.hook_scope = hook_scope
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
self.custom_should_register = default_should_register
self.auto_apply_to_nonpositive = False
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
@property
def strength(self):
return self.hook_keyframe.strength
def initialize_timesteps(self, model: 'BaseModel'):
def initialize_timesteps(self, model: BaseModel):
self.reset()
self.hook_keyframe.initialize_timesteps(model)
def reset(self):
self.hook_keyframe.reset()
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: Hook = subtype()
def clone(self):
c: Hook = self.__class__()
c.hook_type = self.hook_type
c.hook_ref = self.hook_ref
c.hook_id = self.hook_id
c.hook_keyframe = self.hook_keyframe
c.hook_scope = self.hook_scope
c.custom_should_register = self.custom_should_register
# TODO: make this do something
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
return c
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return self.custom_should_register(self, model, model_options, target, registered)
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
return self.custom_should_register(self, model, model_options, target_dict, registered)
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def __eq__(self, other: 'Hook'):
def __eq__(self, other: Hook):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
def __hash__(self):
return hash(self.hook_ref)
class WeightHook(Hook):
'''
Hook responsible for tracking weights to be applied to some model/clip.
Note, value of hook_scope is ignored and is treated as HookedOnly.
'''
def __init__(self, strength_model=1.0, strength_clip=1.0):
super().__init__(hook_type=EnumHookType.Weight)
super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
self.weights: dict = None
self.weights_clip: dict = None
self.need_weight_init = True
self._strength_model = strength_model
self._strength_clip = strength_clip
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs
@property
def strength_model(self):
return self._strength_model * self.strength
@property
def strength_clip(self):
return self._strength_clip * self.strength
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
weights = None
if target == EnumWeightTarget.Model:
strength = self._strength_model
else:
target = target_dict.get('target', None)
if target == EnumWeightTarget.Clip:
strength = self._strength_clip
else:
strength = self._strength_model
if self.need_weight_init:
key_map = {}
if target == EnumWeightTarget.Model:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
else:
if target == EnumWeightTarget.Clip:
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Model:
weights = self.weights
else:
if target == EnumWeightTarget.Clip:
weights = self.weights_clip
else:
weights = self.weights
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self)
registered.add(self)
return True
# TODO: add logs about any keys that were not applied
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WeightHook = super().clone(subtype)
def clone(self):
c: WeightHook = super().clone()
c.weights = self.weights
c.weights_clip = self.weights_clip
c.need_weight_init = self.need_weight_init
@ -147,127 +188,158 @@ class WeightHook(Hook):
c._strength_clip = self._strength_clip
return c
class PatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.Patch)
self.patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: PatchHook = super().clone(subtype)
c.patches = self.patches
return c
# TODO: add functionality
class ObjectPatchHook(Hook):
def __init__(self):
def __init__(self, object_patches: dict[str]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: ObjectPatchHook = super().clone(subtype)
self.object_patches = object_patches
self.hook_scope = hook_scope
def clone(self):
c: ObjectPatchHook = super().clone()
c.object_patches = self.object_patches
return c
# TODO: add functionality
class AddModelsHook(Hook):
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
super().__init__(hook_type=EnumHookType.AddModels)
self.key = key
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")
class AdditionalModelsHook(Hook):
'''
Hook responsible for telling model management any additional models that should be loaded.
Note, value of hook_scope is ignored and is treated as AllConditioning.
'''
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
super().__init__(hook_type=EnumHookType.AdditionalModels)
self.models = models
self.append_when_same = True
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddModelsHook = super().clone(subtype)
c.key = self.key
c.models = self.models.copy() if self.models else self.models
c.append_when_same = self.append_when_same
return c
# TODO: add functionality
class CallbackHook(Hook):
def __init__(self, key: str=None, callback: Callable=None):
super().__init__(hook_type=EnumHookType.Callbacks)
self.key = key
self.callback = callback
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: CallbackHook = super().clone(subtype)
def clone(self):
c: AdditionalModelsHook = super().clone()
c.models = self.models.copy() if self.models else self.models
c.key = self.key
c.callback = self.callback
return c
# TODO: add functionality
class WrapperHook(Hook):
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
super().__init__(hook_type=EnumHookType.Wrappers)
self.wrappers_dict = wrappers_dict
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WrapperHook = super().clone(subtype)
c.wrappers_dict = self.wrappers_dict
return c
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
add_model_options = {"transformer_options": self.wrappers_dict}
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self)
registered.add(self)
return True
class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
super().__init__(hook_type=EnumHookType.SetInjections)
class TransformerOptionsHook(Hook):
'''
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
'''
def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.TransformerOptions)
self.transformers_dict = transformers_dict
self.hook_scope = hook_scope
self._skip_adding = False
'''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
def clone(self):
c: TransformerOptionsHook = super().clone()
c.transformers_dict = self.transformers_dict
c._skip_adding = self._skip_adding
return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
self._skip_adding = False
if self.hook_scope == EnumHookScope.AllConditioning:
add_model_options = {"transformer_options": self.transformers_dict,
"to_load_options": self.transformers_dict}
# skip_adding if included in AllConditioning to avoid double loading
self._skip_adding = True
else:
add_model_options = {"to_load_options": self.transformers_dict}
registered.add(self)
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
return True
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
if not self._skip_adding:
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
WrapperHook = TransformerOptionsHook
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
class InjectionsHook(Hook):
def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.Injections)
self.key = key
self.injections = injections
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: SetInjectionsHook = super().clone(subtype)
self.hook_scope = hook_scope
def clone(self):
c: InjectionsHook = super().clone()
c.key = self.key
c.injections = self.injections.copy() if self.injections else self.injections
return c
def add_hook_injections(self, model: 'ModelPatcher'):
# TODO: add functionality
pass
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")
class HookGroup:
'''
Stores groups of hooks, and allows them to be queried by type.
To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
always use the provided functions on HookGroup.
'''
def __init__(self):
self.hooks: list[Hook] = []
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
def __len__(self):
return len(self.hooks)
def add(self, hook: Hook):
if hook not in self.hooks:
self.hooks.append(hook)
self._hook_dict.setdefault(hook.hook_type, []).append(hook)
def remove(self, hook: Hook):
if hook in self.hooks:
self.hooks.remove(hook)
self._hook_dict[hook.hook_type].remove(hook)
def get_type(self, hook_type: EnumHookType):
return self._hook_dict.get(hook_type, [])
def contains(self, hook: Hook):
return hook in self.hooks
def is_subset_of(self, other: HookGroup):
self_hooks = set(self.hooks)
other_hooks = set(other.hooks)
return self_hooks.issubset(other_hooks)
def new_with_common_hooks(self, other: HookGroup):
c = HookGroup()
for hook in self.hooks:
if other.contains(hook):
c.add(hook.clone())
return c
def clone(self):
c = HookGroup()
for hook in self.hooks:
c.add(hook.clone())
return c
def clone_and_combine(self, other: 'HookGroup'):
def clone_and_combine(self, other: HookGroup):
c = self.clone()
if other is not None:
for hook in other.hooks:
c.add(hook.clone())
return c
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
if hook_kf is None:
hook_kf = HookKeyframeGroup()
else:
@ -275,36 +347,29 @@ class HookGroup:
for hook in self.hooks:
hook.hook_keyframe = hook_kf
def get_dict_repr(self):
d: dict[EnumHookType, dict[Hook, None]] = {}
for hook in self.hooks:
with_type = d.setdefault(hook.hook_type, {})
with_type[hook] = None
return d
def get_hooks_for_clip_schedule(self):
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
for hook in self.hooks:
# only care about WeightHooks, for now
if hook.hook_type == EnumHookType.Weight:
hook_schedule = []
# if no hook keyframes, assign default value
if len(hook.hook_keyframe.keyframes) == 0:
hook_schedule.append(((0.0, 1.0), None))
scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
# only care about WeightHooks, for now
for hook in self.get_type(EnumHookType.Weight):
hook: WeightHook
hook_schedule = []
# if no hook keyframes, assign default value
if len(hook.hook_keyframe.keyframes) == 0:
hook_schedule.append(((0.0, 1.0), None))
scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
scheduled_hooks[hook] = hook_schedule
# hooks should not have their schedules in a list of tuples
all_ranges: list[tuple[float, float]] = []
for range_kfs in scheduled_hooks.values():
@ -336,7 +401,7 @@ class HookGroup:
hook.reset()
@staticmethod
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
actual: list[HookGroup] = []
for group in hooks_list:
if group is not None:
@ -365,10 +430,16 @@ class HookKeyframe:
self.start_percent = float(start_percent)
self.start_t = 999999999.9
self.guarantee_steps = guarantee_steps
def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
'''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
if self.start_t > max_sigma:
return 0
return self.guarantee_steps
def clone(self):
c = HookKeyframe(strength=self.strength,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
c.start_t = self.start_t
return c
@ -395,7 +466,7 @@ class HookKeyframeGroup:
self._current_strength = None
self.curr_t = -1.
self._set_first_as_current()
def add(self, keyframe: HookKeyframe):
# add to end of list, then sort
self.keyframes.append(keyframe)
@ -407,33 +478,40 @@ class HookKeyframeGroup:
self._current_keyframe = self.keyframes[0]
else:
self._current_keyframe = None
def has_guarantee_steps(self):
for kf in self.keyframes:
if kf.guarantee_steps > 0:
return True
return False
def has_index(self, index: int):
return index >= 0 and index < len(self.keyframes)
def is_empty(self):
return len(self.keyframes) == 0
def clone(self):
c = HookKeyframeGroup()
for keyframe in self.keyframes:
c.keyframes.append(keyframe.clone())
c._set_first_as_current()
return c
def initialize_timesteps(self, model: 'BaseModel'):
def initialize_timesteps(self, model: BaseModel):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
def prepare_current_keyframe(self, curr_t: float) -> bool:
def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool:
if self.is_empty():
return False
if curr_t == self._curr_t:
return False
max_sigma = torch.max(transformer_options["sample_sigmas"])
prev_index = self._current_index
prev_strength = self._current_strength
# if met guaranteed steps, look for next keyframe in case need to switch
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma):
# if has next index, loop through and see if need to switch
if self.has_index(self._current_index+1):
for i in range(self._current_index+1, len(self.keyframes)):
@ -446,7 +524,7 @@ class HookKeyframeGroup:
self._current_keyframe = eval_c
self._current_used_steps = 0
# if guarantee_steps greater than zero, stop searching for other keyframes
if self._current_keyframe.guarantee_steps > 0:
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
break
# if eval_c is outside the percent range, stop looking further
else: break
@ -509,6 +587,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
sorted_list.extend(object_list)
return sorted_list
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
# if no hooks or is not a ModelPatcher for sampling, return empty dict
if hooks is None or model.is_clip:
return {}
if transformer_options is None:
transformer_options = {}
for hook in hooks.get_type(EnumHookType.TransformerOptions):
hook: TransformerOptionsHook
hook.on_apply_hooks(model, transformer_options)
return transformer_options
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
@ -535,7 +624,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float
hook.need_weight_init = False
return hook_group
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
if model is None:
return None
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
@ -547,7 +636,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T
return patches_model
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
strength_model: float, strength_clip: float):
key_map = {}
if model is not None:
@ -565,7 +654,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
else:
k = ()
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
@ -599,24 +688,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H
else:
c_dict[hooks_key] = cache[hooks_tuple]
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
c = []
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
if cache is None:
cache = {}
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
if append_hooks and k == 'hooks':
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
_combine_hooks_from_values(n[1], values, cache)
else:
n[1][k] = values[k]
c.append(n)
return c
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
if hooks is None:
return cond
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
if timestep_range is None:
@ -651,9 +742,10 @@ def combine_with_new_conds(conds: list, new_conds: list):
def set_conds_props(conds: list, strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
final_conds = []
cache = {}
for c in conds:
# first, apply lora_hook to conditioning, if provided
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
# next, apply mask to conditioning
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
# apply timesteps, if present
@ -665,9 +757,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str,
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
cache = {}
for c, masked_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
# next, apply mask to new conditioning, if provided
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
# apply timesteps, if present
@ -679,9 +772,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.
def set_default_conds_and_combine(conds: list, new_conds: list,
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
cache = {}
for c, new_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
# next, add default_cond key to cond so that during sampling, it can be identified
new_c = conditioning_set_values(new_c, {'default': True})
# apply timesteps, if present

View File

@ -0,0 +1,141 @@
import torch
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
class Dino2AttentionOutput(torch.nn.Module):
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
def forward(self, x):
return self.dense(x)
class Dino2AttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
def forward(self, x, mask, optimized_attention):
return self.output(self.attention(x, mask, optimized_attention))
class LayerScale(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
in_features = out_features = dim
hidden_features = int(dim * 4)
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
def forward(self, x):
x = self.weights_in(x)
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
return self.weights_out(x)
class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, x, optimized_attention):
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
return x
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
for i, l in enumerate(self.layer):
x = l(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class Dino2PatchEmbeddings(torch.nn.Module):
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
super().__init__()
self.projection = operations.Conv2d(
in_channels=num_channels,
out_channels=dim,
kernel_size=patch_size,
stride=patch_size,
bias=True,
dtype=dtype,
device=device
)
def forward(self, pixel_values):
return self.projection(pixel_values).flatten(2).transpose(1, 2)
class Dino2Embeddings(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
patch_size = 14
image_size = 518
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
def forward(self, pixel_values):
x = self.patch_embeddings(pixel_values)
# TODO: mask_token?
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
return x
class Dinov2Model(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
num_layers = config_dict["num_hidden_layers"]
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values)
x, i = self.encoder(x, intermediate_output=intermediate_output)
x = self.layernorm(x)
pooled_output = x[:, 0, :]
return x, i, pooled_output, None

View File

@ -0,0 +1,21 @@
{
"attention_probs_dropout_prob": 0.0,
"drop_path_rate": 0.0,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 1536,
"image_size": 518,
"initializer_range": 0.02,
"layer_norm_eps": 1e-06,
"layerscale_value": 1.0,
"mlp_ratio": 4,
"model_type": "dinov2",
"num_attention_heads": 24,
"num_channels": 3,
"num_hidden_layers": 40,
"patch_size": 14,
"qkv_bias": true,
"use_swiglu_ffn": true,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}

View File

@ -40,7 +40,7 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
"""Constructs a continuous VP noise schedule."""
t = torch.linspace(1, eps_s, n, device=device)
sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
sigmas = torch.sqrt(torch.special.expm1(beta_d * t ** 2 / 2 + beta_min * t))
return append_zero(sigmas)
@ -70,8 +70,14 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
return sigma_down, sigma_up
def default_noise_sampler(x):
return lambda sigma, sigma_next: torch.randn_like(x)
def default_noise_sampler(x, seed=None):
if seed is not None:
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
else:
generator = None
return lambda sigma, sigma_next: torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator)
class BatchedBrownianTree:
@ -168,7 +174,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -189,7 +196,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -290,7 +298,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -318,7 +327,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -465,7 +475,7 @@ class DPMSolver(nn.Module):
return x_3, eps_cache
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
if not t_end > t_start and eta:
raise ValueError('eta must be 0 for reverse sampling')
@ -504,7 +514,7 @@ class DPMSolver(nn.Module):
return x
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
if order not in {2, 3}:
raise ValueError('order should be 2 or 3')
forward = t_end > t_start
@ -591,7 +601,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
@ -625,7 +636,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
@ -676,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
seed = extra_args.get("seed", None)
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
@ -750,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
if solver_type not in {'heun', 'midpoint'}:
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
old_denoised = None
@ -796,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
denoised_1, denoised_2 = None, None
@ -846,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@ -855,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@ -864,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
@ -882,7 +894,8 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
@ -902,7 +915,8 @@ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None,
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -1153,7 +1167,8 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
@ -1179,7 +1194,8 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
@ -1249,3 +1265,160 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
x = denoised + denoised_mix + torch.exp(-h) * x
old_uncond_denoised = uncond_denoised
return x
@torch.no_grad()
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False):
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
phi1_fn = lambda t: torch.expm1(t) / t
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
old_denoised = None
uncond_denoised = None
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
if cfg_pp:
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
if sigma_down == 0 or old_denoised is None:
# Euler method
if cfg_pp:
d = to_d(x, sigmas[i], uncond_denoised)
x = denoised + d * sigma_down
else:
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# Second order multistep method in https://arxiv.org/pdf/2308.02157
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
h = t_next - t
c2 = (t_prev - t) / h
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
b2 = torch.nan_to_num(phi2_val / c2, nan=0.0)
if cfg_pp:
x = x + (denoised - uncond_denoised)
x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised)
else:
x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised)
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
if cfg_pp:
old_denoised = uncond_denoised
else:
old_denoised = denoised
return x
@torch.no_grad()
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False)
@torch.no_grad()
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=True)
@torch.no_grad()
def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=False)
@torch.no_grad()
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
old_d = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
dt = sigmas[i + 1] - sigmas[i]
if i == 0:
# Euler method
x = x + d * dt
else:
# Gradient estimation
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
x = x + d_bar * dt
old_d = d
return x
@torch.no_grad()
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
"""
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
def default_noise_scaler(sigma):
return sigma * ((sigma ** 0.3).exp() + 10.0)
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
num_integration_points = 200.0
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
old_denoised = None
old_denoised_d = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
stage_used = min(max_stage, i + 1)
if sigmas[i + 1] == 0:
x = denoised
elif stage_used == 1:
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
else:
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
dt = sigmas[i + 1] - sigmas[i]
sigma_step_size = -dt / num_integration_points
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
scaled_pos = noise_scaler(sigma_pos)
# Stage 2
s = torch.sum(1 / scaled_pos) * sigma_step_size
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
if stage_used >= 3:
# Stage 3
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
old_denoised_d = denoised_d
if s_noise != 0 and sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
old_denoised = denoised
return x

View File

@ -3,6 +3,7 @@ import torch
class LatentFormat:
scale_factor = 1.0
latent_channels = 4
latent_dimensions = 2
latent_rgb_factors = None
latent_rgb_factors_bias = None
taesd_decoder_name = None
@ -143,6 +144,7 @@ class SD3(LatentFormat):
class StableAudio1(LatentFormat):
latent_channels = 64
latent_dimensions = 1
class Flux(SD3):
latent_channels = 16
@ -178,6 +180,7 @@ class Flux(SD3):
class Mochi(LatentFormat):
latent_channels = 12
latent_dimensions = 3
def __init__(self):
self.scale_factor = 1.0
@ -219,6 +222,8 @@ class Mochi(LatentFormat):
class LTXV(LatentFormat):
latent_channels = 128
latent_dimensions = 3
def __init__(self):
self.latent_rgb_factors = [
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
@ -355,6 +360,7 @@ class LTXV(LatentFormat):
class HunyuanVideo(LatentFormat):
latent_channels = 16
latent_dimensions = 3
scale_factor = 0.476986
latent_rgb_factors = [
[-0.0395, -0.0331, 0.0445],
@ -376,3 +382,77 @@ class HunyuanVideo(LatentFormat):
]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16
latent_dimensions = 3
latent_rgb_factors = [
[ 0.1817, 0.2284, 0.2423],
[-0.0586, -0.0862, -0.3108],
[-0.4703, -0.4255, -0.3995],
[ 0.0803, 0.1963, 0.1001],
[-0.0820, -0.1050, 0.0400],
[ 0.2511, 0.3098, 0.2787],
[-0.1830, -0.2117, -0.0040],
[-0.0621, -0.2187, -0.0939],
[ 0.3619, 0.1082, 0.1455],
[ 0.3164, 0.3922, 0.2575],
[ 0.1152, 0.0231, -0.0462],
[-0.1434, -0.3609, -0.3665],
[ 0.0635, 0.1471, 0.1680],
[-0.3635, -0.1963, -0.3248],
[-0.1865, 0.0365, 0.2346],
[ 0.0447, 0.0994, 0.0881]
]
latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
class Wan21(LatentFormat):
latent_channels = 16
latent_dimensions = 3
latent_rgb_factors = [
[-0.1299, -0.1692, 0.2932],
[ 0.0671, 0.0406, 0.0442],
[ 0.3568, 0.2548, 0.1747],
[ 0.0372, 0.2344, 0.1420],
[ 0.0313, 0.0189, -0.0328],
[ 0.0296, -0.0956, -0.0665],
[-0.3477, -0.4059, -0.2925],
[ 0.0166, 0.1902, 0.1975],
[-0.0412, 0.0267, -0.1364],
[-0.1293, 0.0740, 0.1636],
[ 0.0680, 0.3019, 0.1128],
[ 0.0032, 0.0581, 0.0639],
[-0.1251, 0.0927, 0.1699],
[ 0.0060, -0.0633, 0.0005],
[ 0.3477, 0.2275, 0.2950],
[ 0.1984, 0.0913, 0.1861]
]
latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360]
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]).view(1, self.latent_channels, 1, 1, 1)
self.taesd_decoder_name = None #TODO
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean

View File

@ -19,6 +19,10 @@
import torch
from torch import nn
from torch.autograd import Function
import comfy.ops
ops = comfy.ops.disable_weight_init
class vector_quantize(Function):
@staticmethod
@ -121,15 +125,15 @@ class ResBlock(nn.Module):
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.depthwise = nn.Sequential(
nn.ReplicationPad2d(1),
nn.Conv2d(c, c, kernel_size=3, groups=c)
ops.Conv2d(c, c, kernel_size=3, groups=c)
)
# channelwise
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
nn.Linear(c, c_hidden),
ops.Linear(c, c_hidden),
nn.GELU(),
nn.Linear(c_hidden, c),
ops.Linear(c_hidden, c),
)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
@ -171,16 +175,16 @@ class StageA(nn.Module):
# Encoder blocks
self.in_block = nn.Sequential(
nn.PixelUnshuffle(2),
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
)
down_blocks = []
for i in range(levels):
if i > 0:
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
block = ResBlock(c_levels[i], c_levels[i] * 4)
down_blocks.append(block)
down_blocks.append(nn.Sequential(
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
))
self.down_blocks = nn.Sequential(*down_blocks)
@ -191,7 +195,7 @@ class StageA(nn.Module):
# Decoder blocks
up_blocks = [nn.Sequential(
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
)]
for i in range(levels):
for j in range(bottleneck_blocks if i == 0 else 1):
@ -199,11 +203,11 @@ class StageA(nn.Module):
up_blocks.append(block)
if i < levels - 1:
up_blocks.append(
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
padding=1))
self.up_blocks = nn.Sequential(*up_blocks)
self.out_block = nn.Sequential(
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
nn.PixelShuffle(2),
)
@ -232,17 +236,17 @@ class Discriminator(nn.Module):
super().__init__()
d = max(depth - 3, 3)
layers = [
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
nn.LeakyReLU(0.2),
]
for i in range(depth - 1):
c_in = c_hidden // (2 ** max((d - i), 0))
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
layers.append(nn.InstanceNorm2d(c_out))
layers.append(nn.LeakyReLU(0.2))
self.encoder = nn.Sequential(*layers)
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
self.logits = nn.Sigmoid()
def forward(self, x, cond=None):

View File

@ -138,7 +138,7 @@ class StageB(nn.Module):
# nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
# nn.init.constant_(self.clf[1].weight, 0) # outputs
#
#
# # blocks
# for level_block in self.down_blocks + self.up_blocks:
# for block in level_block:
@ -148,7 +148,7 @@ class StageB(nn.Module):
# for layer in block.modules():
# if isinstance(layer, nn.Linear):
# nn.init.constant_(layer.weight, 0)
#
#
# def _init_weights(self, m):
# if isinstance(m, (nn.Conv2d, nn.Linear)):
# torch.nn.init.xavier_uniform_(m.weight)

View File

@ -142,7 +142,7 @@ class StageC(nn.Module):
# nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
# nn.init.constant_(self.clf[1].weight, 0) # outputs
#
#
# # blocks
# for level_block in self.down_blocks + self.up_blocks:
# for block in level_block:
@ -152,7 +152,7 @@ class StageC(nn.Module):
# for layer in block.modules():
# if isinstance(layer, nn.Linear):
# nn.init.constant_(layer.weight, 0)
#
#
# def _init_weights(self, m):
# if isinstance(m, (nn.Conv2d, nn.Linear)):
# torch.nn.init.xavier_uniform_(m.weight)

View File

@ -19,6 +19,9 @@ import torch
import torchvision
from torch import nn
import comfy.ops
ops = comfy.ops.disable_weight_init
# EfficientNet
class EfficientNetEncoder(nn.Module):
@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module):
super().__init__()
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
self.mapper = nn.Sequential(
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
)
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module):
def forward(self, x):
x = x * 0.5 + 0.5
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
o = self.mapper(self.backbone(x))
return o
@ -44,39 +47,39 @@ class Previewer(nn.Module):
def __init__(self, c_in=16, c_hidden=512, c_out=3):
super().__init__()
self.blocks = nn.Sequential(
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
nn.GELU(),
nn.BatchNorm2d(c_hidden),
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden),
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
)
def forward(self, x):

808
comfy/ldm/cosmos/blocks.py Normal file
View File

@ -0,0 +1,808 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import logging
import numpy as np
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
def get_normalization(name: str, channels: int, weight_args={}):
if name == "I":
return nn.Identity()
elif name == "R":
return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
else:
raise ValueError(f"Normalization {name} not found")
class BaseAttentionOp(nn.Module):
def __init__(self):
super().__init__()
class Attention(nn.Module):
"""
Generalized attention impl.
Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided.
If `context_dim` is None, self-attention is assumed.
Parameters:
query_dim (int): Dimension of each query vector.
context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
heads (int, optional): Number of attention heads. Defaults to 8.
dim_head (int, optional): Dimension of each head. Defaults to 64.
dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0.
attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default.
qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections.
Defaults to "SSI".
qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections.
Defaults to 'per_head'. Only support 'per_head'.
Examples:
>>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1)
>>> query = torch.randn(10, 128) # Batch size of 10
>>> context = torch.randn(10, 256) # Batch size of 10
>>> output = attn(query, context) # Perform the attention operation
Note:
https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
"""
def __init__(
self,
query_dim: int,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
attn_op: Optional[BaseAttentionOp] = None,
qkv_bias: bool = False,
out_bias: bool = False,
qkv_norm: str = "SSI",
qkv_norm_mode: str = "per_head",
backend: str = "transformer_engine",
qkv_format: str = "bshd",
weight_args={},
operations=None,
) -> None:
super().__init__()
self.is_selfattn = context_dim is None # self attention
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
self.heads = heads
self.dim_head = dim_head
self.qkv_norm_mode = qkv_norm_mode
self.qkv_format = qkv_format
if self.qkv_norm_mode == "per_head":
norm_dim = dim_head
else:
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
self.backend = backend
self.to_q = nn.Sequential(
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[0], norm_dim),
)
self.to_k = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[1], norm_dim),
)
self.to_v = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[2], norm_dim),
)
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, bias=out_bias, **weight_args),
nn.Dropout(dropout),
)
def cal_qkv(
self, x, context=None, mask=None, rope_emb=None, **kwargs
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
del kwargs
"""
self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
Before 07/24/2024, these modules normalize across all heads.
After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
we support to normalize per head.
To keep the checkpoint copatibility with the previous code,
we keep the nn.Sequential but call the projection and the normalization layers separately.
We use a flag `self.qkv_norm_mode` to control the normalization behavior.
The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.
"""
if self.qkv_norm_mode == "per_head":
q = self.to_q[0](x)
context = x if context is None else context
k = self.to_k[0](context)
v = self.to_v[0](context)
q, k, v = map(
lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head),
(q, k, v),
)
else:
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
q = self.to_q[1](q)
k = self.to_k[1](k)
v = self.to_v[1](v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
# apply_rotary_pos_emb inlined
q_shape = q.shape
q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
# apply_rotary_pos_emb inlined
k_shape = k.shape
k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
return q, k, v
def forward(
self,
x,
context=None,
mask=None,
rope_emb=None,
**kwargs,
):
"""
Args:
x (Tensor): The query tensor of shape [B, Mq, K]
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
class FeedForward(nn.Module):
"""
Transformer FFN with optional gating
Parameters:
d_model (int): Dimensionality of input features.
d_ff (int): Dimensionality of the hidden layer.
dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1.
activation (callable, optional): The activation function applied after the first linear layer.
Defaults to nn.ReLU().
is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer.
Defaults to False.
bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True.
Example:
>>> ff = FeedForward(d_model=512, d_ff=2048)
>>> x = torch.randn(64, 10, 512) # Example input tensor
>>> output = ff(x)
>>> print(output.shape) # Expected shape: (64, 10, 512)
"""
def __init__(
self,
d_model: int,
d_ff: int,
dropout: float = 0.1,
activation=nn.ReLU(),
is_gated: bool = False,
bias: bool = False,
weight_args={},
operations=None,
) -> None:
super().__init__()
self.layer1 = operations.Linear(d_model, d_ff, bias=bias, **weight_args)
self.layer2 = operations.Linear(d_ff, d_model, bias=bias, **weight_args)
self.dropout = nn.Dropout(dropout)
self.activation = activation
self.is_gated = is_gated
if is_gated:
self.linear_gate = operations.Linear(d_model, d_ff, bias=False, **weight_args)
def forward(self, x: torch.Tensor):
g = self.activation(self.layer1(x))
if self.is_gated:
x = g * self.linear_gate(x)
else:
x = g
assert self.dropout.p == 0.0, "we skip dropout"
return self.layer2(x)
class GPT2FeedForward(FeedForward):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False, weight_args={}, operations=None):
super().__init__(
d_model=d_model,
d_ff=d_ff,
dropout=dropout,
activation=nn.GELU(),
is_gated=False,
bias=bias,
weight_args=weight_args,
operations=operations,
)
def forward(self, x: torch.Tensor):
assert self.dropout.p == 0.0, "we skip dropout"
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
return x
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class Timesteps(nn.Module):
def __init__(self, num_channels):
super().__init__()
self.num_channels = num_channels
def forward(self, timesteps):
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - 0.0)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
emb = torch.cat([cos_emb, sin_emb], dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args={}, operations=None):
super().__init__()
logging.debug(
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
)
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args)
self.activation = nn.SiLU()
self.use_adaln_lora = use_adaln_lora
if use_adaln_lora:
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, **weight_args)
else:
self.linear_2 = operations.Linear(out_features, out_features, bias=True, **weight_args)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
emb = self.linear_1(sample)
emb = self.activation(emb)
emb = self.linear_2(emb)
if self.use_adaln_lora:
adaln_lora_B_3D = emb
emb_B_D = sample
else:
emb_B_D = emb
adaln_lora_B_3D = None
return emb_B_D, adaln_lora_B_3D
class FourierFeatures(nn.Module):
"""
Implements a layer that generates Fourier features from input tensors, based on randomly sampled
frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems.
[B] -> [B, D]
Parameters:
num_channels (int): The number of Fourier features to generate.
bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1.
normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize
the variance of the features. Defaults to False.
Example:
>>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
>>> x = torch.randn(10, 256) # Example input tensor
>>> output = layer(x)
>>> print(output.shape) # Expected shape: (10, 256)
"""
def __init__(self, num_channels, bandwidth=1, normalize=False):
super().__init__()
self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
self.gain = np.sqrt(2) if normalize else 1
def forward(self, x, gain: float = 1.0):
"""
Apply the Fourier feature transformation to the input tensor.
Args:
x (torch.Tensor): The input tensor.
gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1.
Returns:
torch.Tensor: The transformed tensor, with Fourier features applied.
"""
in_dtype = x.dtype
x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
x = x.cos().mul(self.gain * gain).to(in_dtype)
return x
class PatchEmbed(nn.Module):
"""
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
making it suitable for video and image processing tasks. It supports dividing the input into patches
and embedding each patch into a vector of size `out_channels`.
Parameters:
- spatial_patch_size (int): The size of each spatial patch.
- temporal_patch_size (int): The size of each temporal patch.
- in_channels (int): Number of input channels. Default: 3.
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
"""
def __init__(
self,
spatial_patch_size,
temporal_patch_size,
in_channels=3,
out_channels=768,
bias=True,
weight_args={},
operations=None,
):
super().__init__()
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.proj = nn.Sequential(
Rearrange(
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
r=temporal_patch_size,
m=spatial_patch_size,
n=spatial_patch_size,
),
operations.Linear(
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias, **weight_args
),
)
self.out = nn.Identity()
def forward(self, x):
"""
Forward pass of the PatchEmbed module.
Parameters:
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
B is the batch size,
C is the number of channels,
T is the temporal dimension,
H is the height, and
W is the width of the input.
Returns:
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
"""
assert x.dim() == 5
_, _, T, H, W = x.shape
assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
assert T % self.temporal_patch_size == 0
x = self.proj(x)
return self.out(x)
class FinalLayer(nn.Module):
"""
The final layer of video DiT.
"""
def __init__(
self,
hidden_size,
spatial_patch_size,
temporal_patch_size,
out_channels,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
weight_args={},
operations=None,
):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **weight_args)
self.linear = operations.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, **weight_args
)
self.hidden_size = hidden_size
self.n_adaln_chunks = 2
self.use_adaln_lora = use_adaln_lora
if use_adaln_lora:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, adaln_lora_dim, bias=False, **weight_args),
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, **weight_args),
)
else:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, **weight_args)
)
def forward(
self,
x_BT_HW_D,
emb_B_D,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
):
if self.use_adaln_lora:
assert adaln_lora_B_3D is not None
shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk(
2, dim=1
)
else:
shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1)
B = emb_B_D.shape[0]
T = x_BT_HW_D.shape[0] // B
shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T)
x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D)
x_BT_HW_D = self.linear(x_BT_HW_D)
return x_BT_HW_D
class VideoAttn(nn.Module):
"""
Implements video attention with optional cross-attention capabilities.
This module processes video features while maintaining their spatio-temporal structure. It can perform
self-attention within the video features or cross-attention with external context features.
Parameters:
x_dim (int): Dimension of input feature vectors
context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention
num_heads (int): Number of attention heads
bias (bool): Whether to include bias in attention projections. Default: False
qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head"
x_format (str): Format of input tensor. Must be "BTHWD". Default: "BTHWD"
Input shape:
- x: (T, H, W, B, D) video features
- context (optional): (M, B, D) context features for cross-attention
where:
T: temporal dimension
H: height
W: width
B: batch size
D: feature dimension
M: context sequence length
"""
def __init__(
self,
x_dim: int,
context_dim: Optional[int],
num_heads: int,
bias: bool = False,
qkv_norm_mode: str = "per_head",
x_format: str = "BTHWD",
weight_args={},
operations=None,
) -> None:
super().__init__()
self.x_format = x_format
self.attn = Attention(
x_dim,
context_dim,
num_heads,
x_dim // num_heads,
qkv_bias=bias,
qkv_norm="RRI",
out_bias=bias,
qkv_norm_mode=qkv_norm_mode,
qkv_format="sbhd",
weight_args=weight_args,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass for video attention.
Args:
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data.
context (Tensor): Context tensor of shape (B, M, D) or (M, B, D),
where M is the sequence length of the context.
crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms.
rope_emb_L_1_1_D (Optional[Tensor]):
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
Returns:
Tensor: The output tensor with applied attention, maintaining the input shape.
"""
x_T_H_W_B_D = x
context_M_B_D = context
T, H, W, B, D = x_T_H_W_B_D.shape
x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d")
x_THW_B_D = self.attn(
x_THW_B_D,
context_M_B_D,
crossattn_mask,
rope_emb=rope_emb_L_1_1_D,
)
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
return x_T_H_W_B_D
def adaln_norm_state(norm_state, x, scale, shift):
normalized = norm_state(x)
return normalized * (1 + scale) + shift
class DITBuildingBlock(nn.Module):
"""
A building block for the DiT (Diffusion Transformer) architecture that supports different types of
attention and MLP operations with adaptive layer normalization.
Parameters:
block_type (str): Type of block - one of:
- "cross_attn"/"ca": Cross-attention
- "full_attn"/"fa": Full self-attention
- "mlp"/"ff": MLP/feedforward block
x_dim (int): Dimension of input features
context_dim (Optional[int]): Dimension of context features for cross-attention
num_heads (int): Number of attention heads
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
bias (bool): Whether to use bias in layers. Default: False
mlp_dropout (float): Dropout rate for MLP. Default: 0.0
qkv_norm_mode (str): QKV normalization mode. Default: "per_head"
x_format (str): Input tensor format. Default: "BTHWD"
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
"""
def __init__(
self,
block_type: str,
x_dim: int,
context_dim: Optional[int],
num_heads: int,
mlp_ratio: float = 4.0,
bias: bool = False,
mlp_dropout: float = 0.0,
qkv_norm_mode: str = "per_head",
x_format: str = "BTHWD",
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
weight_args={},
operations=None
) -> None:
block_type = block_type.lower()
super().__init__()
self.x_format = x_format
if block_type in ["cross_attn", "ca"]:
self.block = VideoAttn(
x_dim,
context_dim,
num_heads,
bias=bias,
qkv_norm_mode=qkv_norm_mode,
x_format=self.x_format,
weight_args=weight_args,
operations=operations,
)
elif block_type in ["full_attn", "fa"]:
self.block = VideoAttn(
x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format, weight_args=weight_args, operations=operations
)
elif block_type in ["mlp", "ff"]:
self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias, weight_args=weight_args, operations=operations)
else:
raise ValueError(f"Unknown block type: {block_type}")
self.block_type = block_type
self.use_adaln_lora = use_adaln_lora
self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
self.n_adaln_chunks = 3
if use_adaln_lora:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, **weight_args),
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args),
)
else:
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args))
def forward(
self,
x: torch.Tensor,
emb_B_D: torch.Tensor,
crossattn_emb: torch.Tensor,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass for dynamically configured blocks with adaptive normalization.
Args:
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D).
emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation.
crossattn_emb (Tensor): Tensor for cross-attention blocks.
crossattn_mask (Optional[Tensor]): Optional mask for cross-attention.
rope_emb_L_1_1_D (Optional[Tensor]):
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
Returns:
Tensor: The output tensor after processing through the configured block and adaptive normalization.
"""
if self.use_adaln_lora:
shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk(
self.n_adaln_chunks, dim=1
)
else:
shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1)
shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = (
shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
)
if self.block_type in ["mlp", "ff"]:
x = x + gate_1_1_1_B_D * self.block(
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
)
elif self.block_type in ["full_attn", "fa"]:
x = x + gate_1_1_1_B_D * self.block(
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=None,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
)
elif self.block_type in ["cross_attn", "ca"]:
x = x + gate_1_1_1_B_D * self.block(
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=crossattn_emb,
crossattn_mask=crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
)
else:
raise ValueError(f"Unknown block type: {self.block_type}")
return x
class GeneralDITTransformerBlock(nn.Module):
"""
A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer.
Each block in the sequence is specified by a block configuration string.
Parameters:
x_dim (int): Dimension of input features
context_dim (int): Dimension of context features for cross-attention blocks
num_heads (int): Number of attention heads
block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention,
full-attention, then MLP)
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
x_format (str): Input tensor format. Default: "BTHWD"
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
The block_config string uses "-" to separate block types:
- "ca"/"cross_attn": Cross-attention block
- "fa"/"full_attn": Full self-attention block
- "mlp"/"ff": MLP/feedforward block
Example:
block_config = "ca-fa-mlp" creates a sequence of:
1. Cross-attention block
2. Full self-attention block
3. MLP block
"""
def __init__(
self,
x_dim: int,
context_dim: int,
num_heads: int,
block_config: str,
mlp_ratio: float = 4.0,
x_format: str = "BTHWD",
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
weight_args={},
operations=None
):
super().__init__()
self.blocks = nn.ModuleList()
self.x_format = x_format
for block_type in block_config.split("-"):
self.blocks.append(
DITBuildingBlock(
block_type,
x_dim,
context_dim,
num_heads,
mlp_ratio,
x_format=self.x_format,
use_adaln_lora=use_adaln_lora,
adaln_lora_dim=adaln_lora_dim,
weight_args=weight_args,
operations=operations,
)
)
def forward(
self,
x: torch.Tensor,
emb_B_D: torch.Tensor,
crossattn_emb: torch.Tensor,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
for block in self.blocks:
x = block(
x,
emb_B_D,
crossattn_emb,
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
)
return x

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,377 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The patcher and unpatcher implementation for 2D and 3D data.
The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions.
One on the rows and one on the columns.
For example, in 1D signal, we have [a, b], then the low-freq compoenent is [a + b] / 2 and high-freq is [a - b] / 2.
We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component.
For H component, we can use a 1D convolution with kernel [1, -1] and stride 2.
Although in principle, we typically only do additional Haar wavelet over the LL component. But here we do it for all
as we need to support downsampling for more than 2x.
For example, 4x downsampling can be done by 2x Haar and additional 2x Haar, and the shape would be.
[3, 256, 256] -> [12, 128, 128] -> [48, 64, 64]
"""
import torch
import torch.nn.functional as F
from einops import rearrange
_WAVELETS = {
"haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
"rearrange": torch.tensor([1.0, 1.0]),
}
_PERSISTENT = False
class Patcher(torch.nn.Module):
"""A module to convert image tensors into patches using torch operations.
The main difference from `class Patching` is that this module implements
all operations using torch, rather than python or numpy, for efficiency purpose.
It's bit-wise identical to the Patching module outputs, with the added
benefit of being torch.jit scriptable.
"""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__()
self.patch_size = patch_size
self.patch_method = patch_method
self.register_buffer(
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
)
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
self.register_buffer(
"_arange",
torch.arange(_WAVELETS[patch_method].shape[0]),
persistent=_PERSISTENT,
)
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
if self.patch_method == "haar":
return self._haar(x)
elif self.patch_method == "rearrange":
return self._arrange(x)
else:
raise ValueError("Unknown patch method: " + self.patch_method)
def _dwt(self, x, mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
n = h.shape[0]
g = x.shape[1]
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hh = hh.to(dtype=dtype)
hl = hl.to(dtype=dtype)
x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))
out = torch.cat([xll, xlh, xhl, xhh], dim=1)
if rescale:
out = out / 2
return out
def _haar(self, x):
for _ in self.range:
x = self._dwt(x, rescale=True)
return x
def _arrange(self, x):
x = rearrange(
x,
"b c (h p1) (w p2) -> b (c p1 p2) h w",
p1=self.patch_size,
p2=self.patch_size,
).contiguous()
return x
class Patcher3D(Patcher):
"""A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos."""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__(patch_method=patch_method, patch_size=patch_size)
self.register_buffer(
"patch_size_buffer",
patch_size * torch.ones([1], dtype=torch.int32),
persistent=_PERSISTENT,
)
def _dwt(self, x, wavelet, mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
n = h.shape[0]
g = x.shape[1]
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hh = hh.to(dtype=dtype)
hl = hl.to(dtype=dtype)
# Handles temporal axis.
x = F.pad(
x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode
).to(dtype)
xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
# Handles spatial axes.
xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
if rescale:
out = out / (2 * torch.sqrt(torch.tensor(2.0)))
return out
def _haar(self, x):
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
for _ in self.range:
x = self._dwt(x, "haar", rescale=True)
return x
def _arrange(self, x):
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
x = rearrange(
x,
"b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w",
p1=self.patch_size,
p2=self.patch_size,
p3=self.patch_size,
).contiguous()
return x
class UnPatcher(torch.nn.Module):
"""A module to convert patches into image tensorsusing torch operations.
The main difference from `class Unpatching` is that this module implements
all operations using torch, rather than python or numpy, for efficiency purpose.
It's bit-wise identical to the Unpatching module outputs, with the added
benefit of being torch.jit scriptable.
"""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__()
self.patch_size = patch_size
self.patch_method = patch_method
self.register_buffer(
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
)
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
self.register_buffer(
"_arange",
torch.arange(_WAVELETS[patch_method].shape[0]),
persistent=_PERSISTENT,
)
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
if self.patch_method == "haar":
return self._ihaar(x)
elif self.patch_method == "rearrange":
return self._iarrange(x)
else:
raise ValueError("Unknown patch method: " + self.patch_method)
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
n = h.shape[0]
g = x.shape[1] // 4
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hh = hh.to(dtype=dtype)
hl = hl.to(dtype=dtype)
xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)
# Inverse transform.
yl = torch.nn.functional.conv_transpose2d(
xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
yl += torch.nn.functional.conv_transpose2d(
xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
yh = torch.nn.functional.conv_transpose2d(
xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
yh += torch.nn.functional.conv_transpose2d(
xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
y = torch.nn.functional.conv_transpose2d(
yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
)
y += torch.nn.functional.conv_transpose2d(
yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
)
if rescale:
y = y * 2
return y
def _ihaar(self, x):
for _ in self.range:
x = self._idwt(x, "haar", rescale=True)
return x
def _iarrange(self, x):
x = rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.patch_size,
p2=self.patch_size,
)
return x
class UnPatcher3D(UnPatcher):
"""A 3D inverse discrete wavelet transform for video wavelet decompositions."""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__(patch_method=patch_method, patch_size=patch_size)
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors.
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hl = hl.to(dtype=dtype)
hh = hh.to(dtype=dtype)
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
del x
# Height height transposed convolutions.
xll = F.conv_transpose3d(
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlll
xll += F.conv_transpose3d(
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xllh
xlh = F.conv_transpose3d(
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlhl
xlh += F.conv_transpose3d(
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlhh
xhl = F.conv_transpose3d(
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhll
xhl += F.conv_transpose3d(
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhlh
xhh = F.conv_transpose3d(
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhhl
xhh += F.conv_transpose3d(
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhhh
# Handles width transposed convolutions.
xl = F.conv_transpose3d(
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xll
xl += F.conv_transpose3d(
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xlh
xh = F.conv_transpose3d(
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xhl
xh += F.conv_transpose3d(
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xhh
# Handles time axis transposed convolutions.
x = F.conv_transpose3d(
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
del xl
x += F.conv_transpose3d(
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
if rescale:
x = x * (2 * torch.sqrt(torch.tensor(2.0)))
return x
def _ihaar(self, x):
for _ in self.range:
x = self._idwt(x, "haar", rescale=True)
x = x[:, :, self.patch_size - 1 :, ...]
return x
def _iarrange(self, x):
x = rearrange(
x,
"b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
p1=self.patch_size,
p2=self.patch_size,
p3=self.patch_size,
)
x = x[:, :, self.patch_size - 1 :, ...]
return x

View File

@ -0,0 +1,112 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for the networks module."""
from typing import Any
import torch
from einops import rearrange
import comfy.ops
ops = comfy.ops.disable_weight_init
def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
batch_size = x.shape[0]
return rearrange(x, "b c t h w -> (b t) c h w"), batch_size
def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
batch_size, height = x.shape[0], x.shape[-2]
return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height
def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)
def cast_tuple(t: Any, length: int = 1) -> Any:
return t if isinstance(t, tuple) else ((t,) * length)
def replication_pad(x):
return torch.cat([x[:, :, :1, ...], x], dim=2)
def divisible_by(num: int, den: int) -> bool:
return (num % den) == 0
def is_odd(n: int) -> bool:
return not divisible_by(n, 2)
def nonlinearity(x):
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
class CausalNormalize(torch.nn.Module):
def __init__(self, in_channels, num_groups=1):
super().__init__()
self.norm = ops.GroupNorm(
num_groups=num_groups,
num_channels=in_channels,
eps=1e-6,
affine=True,
)
self.num_groups = num_groups
def forward(self, x):
# if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose.
# All new models should use num_groups=1, otherwise causality is not guaranteed.
if self.num_groups == 1:
x, batch_size = time2batch(x)
return batch2time(self.norm(x), batch_size)
return self.norm(x)
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg
return None
def round_ste(z: torch.Tensor) -> torch.Tensor:
"""Round with straight through gradients."""
zhat = z.round()
return z + (zhat - z).detach()
def log(t, eps=1e-5):
return t.clamp(min=eps).log()
def entropy(prob):
return (-prob * log(prob)).sum(dim=-1)

514
comfy/ldm/cosmos/model.py Normal file
View File

@ -0,0 +1,514 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
"""
from typing import Optional, Tuple
import torch
from einops import rearrange
from torch import nn
from torchvision import transforms
from enum import Enum
import logging
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
from .blocks import (
FinalLayer,
GeneralDITTransformerBlock,
PatchEmbed,
TimestepEmbedding,
Timesteps,
)
from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
class DataType(Enum):
IMAGE = "image"
VIDEO = "video"
class GeneralDIT(nn.Module):
"""
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
Args:
max_img_h (int): Maximum height of the input images.
max_img_w (int): Maximum width of the input images.
max_frames (int): Maximum number of frames in the video sequence.
in_channels (int): Number of input channels (e.g., RGB channels for color images).
out_channels (int): Number of output channels.
patch_spatial (tuple): Spatial resolution of patches for input processing.
patch_temporal (int): Temporal resolution of patches for input processing.
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
block_config (str): Configuration of the transformer block. See Notes for supported block types.
model_channels (int): Base number of channels used throughout the model.
num_blocks (int): Number of transformer blocks.
num_heads (int): Number of heads in the multi-head attention layers.
mlp_ratio (float): Expansion ratio for MLP blocks.
block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD').
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
use_cross_attn_mask (bool): Whether to use mask in cross-attention.
pos_emb_cls (str): Type of positional embeddings.
pos_emb_learnable (bool): Whether positional embeddings are learnable.
pos_emb_interpolation (str): Method for interpolating positional embeddings.
affline_emb_norm (bool): Whether to normalize affine embeddings.
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings.
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
Notes:
Supported block types in block_config:
* cross_attn, ca: Cross attention
* full_attn: Full attention on all flattened tokens
* mlp, ff: Feed forward block
"""
def __init__(
self,
max_img_h: int,
max_img_w: int,
max_frames: int,
in_channels: int,
out_channels: int,
patch_spatial: tuple,
patch_temporal: int,
concat_padding_mask: bool = True,
# attention settings
block_config: str = "FA-CA-MLP",
model_channels: int = 768,
num_blocks: int = 10,
num_heads: int = 16,
mlp_ratio: float = 4.0,
block_x_format: str = "BTHWD",
# cross attention settings
crossattn_emb_channels: int = 1024,
use_cross_attn_mask: bool = False,
# positional embedding settings
pos_emb_cls: str = "sincos",
pos_emb_learnable: bool = False,
pos_emb_interpolation: str = "crop",
affline_emb_norm: bool = False, # whether or not to normalize the affine embedding
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
rope_h_extrapolation_ratio: float = 1.0,
rope_w_extrapolation_ratio: float = 1.0,
rope_t_extrapolation_ratio: float = 1.0,
extra_per_block_abs_pos_emb: bool = False,
extra_per_block_abs_pos_emb_type: str = "sincos",
extra_h_extrapolation_ratio: float = 1.0,
extra_w_extrapolation_ratio: float = 1.0,
extra_t_extrapolation_ratio: float = 1.0,
image_model=None,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
self.max_img_h = max_img_h
self.max_img_w = max_img_w
self.max_frames = max_frames
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_spatial = patch_spatial
self.patch_temporal = patch_temporal
self.num_heads = num_heads
self.num_blocks = num_blocks
self.model_channels = model_channels
self.use_cross_attn_mask = use_cross_attn_mask
self.concat_padding_mask = concat_padding_mask
# positional embedding settings
self.pos_emb_cls = pos_emb_cls
self.pos_emb_learnable = pos_emb_learnable
self.pos_emb_interpolation = pos_emb_interpolation
self.affline_emb_norm = affline_emb_norm
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower()
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
self.dtype = dtype
weight_args = {"device": device, "dtype": dtype}
in_channels = in_channels + 1 if concat_padding_mask else in_channels
self.x_embedder = PatchEmbed(
spatial_patch_size=patch_spatial,
temporal_patch_size=patch_temporal,
in_channels=in_channels,
out_channels=model_channels,
bias=False,
weight_args=weight_args,
operations=operations,
)
self.build_pos_embed(device=device, dtype=dtype)
self.block_x_format = block_x_format
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
self.t_embedder = nn.ModuleList(
[Timesteps(model_channels),
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
)
self.blocks = nn.ModuleDict()
for idx in range(num_blocks):
self.blocks[f"block{idx}"] = GeneralDITTransformerBlock(
x_dim=model_channels,
context_dim=crossattn_emb_channels,
num_heads=num_heads,
block_config=block_config,
mlp_ratio=mlp_ratio,
x_format=self.block_x_format,
use_adaln_lora=use_adaln_lora,
adaln_lora_dim=adaln_lora_dim,
weight_args=weight_args,
operations=operations,
)
if self.affline_emb_norm:
logging.debug("Building affine embedding normalization layer")
self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
else:
self.affline_norm = nn.Identity()
self.final_layer = FinalLayer(
hidden_size=self.model_channels,
spatial_patch_size=self.patch_spatial,
temporal_patch_size=self.patch_temporal,
out_channels=self.out_channels,
use_adaln_lora=self.use_adaln_lora,
adaln_lora_dim=self.adaln_lora_dim,
weight_args=weight_args,
operations=operations,
)
def build_pos_embed(self, device=None, dtype=None):
if self.pos_emb_cls == "rope3d":
cls_type = VideoRopePosition3DEmb
else:
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
kwargs = dict(
model_channels=self.model_channels,
len_h=self.max_img_h // self.patch_spatial,
len_w=self.max_img_w // self.patch_spatial,
len_t=self.max_frames // self.patch_temporal,
is_learnable=self.pos_emb_learnable,
interpolation=self.pos_emb_interpolation,
head_dim=self.model_channels // self.num_heads,
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
device=device,
)
self.pos_embedder = cls_type(
**kwargs,
)
if self.extra_per_block_abs_pos_emb:
assert self.extra_per_block_abs_pos_emb_type in [
"learnable",
], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}"
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
kwargs["device"] = device
kwargs["dtype"] = dtype
self.extra_pos_embedder = LearnablePosEmbAxis(
**kwargs,
)
def prepare_embedded_sequence(
self,
x_B_C_T_H_W: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
Args:
x_B_C_T_H_W (torch.Tensor): video
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
If None, a default value (`self.base_fps`) will be used.
padding_mask (Optional[torch.Tensor]): current it is not used
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
- An optional positional embedding tensor, returned only if the positional embedding class
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
Notes:
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
the `self.pos_embedder` with the shape [T, H, W].
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
`self.pos_embedder` with the fps tensor.
- Otherwise, the positional embeddings are generated without considering fps.
"""
if self.concat_padding_mask:
if padding_mask is not None:
padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
else:
padding_mask = torch.zeros((x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[-2], x_B_C_T_H_W.shape[-1]), dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
x_B_C_T_H_W = torch.cat(
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
else:
extra_pos_emb = None
if "rope" in self.pos_emb_cls.lower():
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
if "fps_aware" in self.pos_emb_cls:
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
else:
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
return x_B_T_H_W_D, None, extra_pos_emb
def decoder_head(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_D: torch.Tensor,
crossattn_emb: torch.Tensor,
origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W]
crossattn_mask: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
del crossattn_emb, crossattn_mask
B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape
x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D")
x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D)
# This is to ensure x_BT_HW_D has the correct shape because
# when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D).
x_BT_HW_D = x_BT_HW_D.view(
B * T_before_patchify // self.patch_temporal,
H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial,
-1,
)
x_B_D_T_H_W = rearrange(
x_BT_HW_D,
"(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
p1=self.patch_spatial,
p2=self.patch_spatial,
H=H_before_patchify // self.patch_spatial,
W=W_before_patchify // self.patch_spatial,
t=self.patch_temporal,
B=B,
)
return x_B_D_T_H_W
def forward_before_blocks(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
crossattn_emb: torch.Tensor,
crossattn_mask: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
image_size: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
scalar_feature: Optional[torch.Tensor] = None,
data_type: Optional[DataType] = DataType.VIDEO,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Args:
x: (B, C, T, H, W) tensor of spatial-temp inputs
timesteps: (B, ) tensor of timesteps
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
crossattn_mask: (B, N) tensor of cross-attention masks
"""
del kwargs
assert isinstance(
data_type, DataType
), f"Expected DataType, got {type(data_type)}. We need discuss this flag later."
original_shape = x.shape
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
x,
fps=fps,
padding_mask=padding_mask,
latent_condition=latent_condition,
latent_condition_sigma=latent_condition_sigma,
)
# logging affline scale information
affline_scale_log_info = {}
timesteps_B_D, adaln_lora_B_3D = self.t_embedder[1](self.t_embedder[0](timesteps.flatten()).to(x.dtype))
affline_emb_B_D = timesteps_B_D
affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach()
if scalar_feature is not None:
raise NotImplementedError("Scalar feature is not implemented yet.")
affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach()
affline_emb_B_D = self.affline_norm(affline_emb_B_D)
if self.use_cross_attn_mask:
if crossattn_mask is not None and not torch.is_floating_point(crossattn_mask):
crossattn_mask = (crossattn_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
crossattn_mask = crossattn_mask[:, None, None, :] # .to(dtype=torch.bool) # [B, 1, 1, length]
else:
crossattn_mask = None
if self.blocks["block0"].x_format == "THWBD":
x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D")
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange(
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D"
)
crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D")
if crossattn_mask:
crossattn_mask = rearrange(crossattn_mask, "B M -> M B")
elif self.blocks["block0"].x_format == "BTHWD":
x = x_B_T_H_W_D
else:
raise ValueError(f"Unknown x_format {self.blocks[0].x_format}")
output = {
"x": x,
"affline_emb_B_D": affline_emb_B_D,
"crossattn_emb": crossattn_emb,
"crossattn_mask": crossattn_mask,
"rope_emb_L_1_1_D": rope_emb_L_1_1_D,
"adaln_lora_B_3D": adaln_lora_B_3D,
"original_shape": original_shape,
"extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
}
return output
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
# crossattn_emb: torch.Tensor,
# crossattn_mask: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
image_size: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
scalar_feature: Optional[torch.Tensor] = None,
data_type: Optional[DataType] = DataType.VIDEO,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Args:
x: (B, C, T, H, W) tensor of spatial-temp inputs
timesteps: (B, ) tensor of timesteps
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
crossattn_mask: (B, N) tensor of cross-attention masks
condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to
augment condition input, the lvg model will condition on the condition_video_augment_sigma value;
we need forward_before_blocks pass to the forward_before_blocks function.
"""
crossattn_emb = context
crossattn_mask = attention_mask
inputs = self.forward_before_blocks(
x=x,
timesteps=timesteps,
crossattn_emb=crossattn_emb,
crossattn_mask=crossattn_mask,
fps=fps,
image_size=image_size,
padding_mask=padding_mask,
scalar_feature=scalar_feature,
data_type=data_type,
latent_condition=latent_condition,
latent_condition_sigma=latent_condition_sigma,
condition_video_augment_sigma=condition_video_augment_sigma,
**kwargs,
)
x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = (
inputs["x"],
inputs["affline_emb_B_D"],
inputs["crossattn_emb"],
inputs["crossattn_mask"],
inputs["rope_emb_L_1_1_D"],
inputs["adaln_lora_B_3D"],
inputs["original_shape"],
)
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
del inputs
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
assert (
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
for _, block in self.blocks.items():
assert (
self.blocks["block0"].x_format == block.x_format
), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
x = block(
x,
affline_emb_B_D,
crossattn_emb,
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
x_B_D_T_H_W = self.decoder_head(
x_B_T_H_W_D=x_B_T_H_W_D,
emb_B_D=affline_emb_B_D,
crossattn_emb=None,
origin_shape=original_shape,
crossattn_mask=None,
adaln_lora_B_3D=adaln_lora_B_3D,
)
return x_B_D_T_H_W

View File

@ -0,0 +1,208 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
from einops import rearrange, repeat
from torch import nn
import math
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
"""
Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.
Args:
x (torch.Tensor): The input tensor to normalize.
dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
eps (float, optional): A small constant to ensure numerical stability during division.
Returns:
torch.Tensor: The normalized tensor.
"""
if dim is None:
dim = list(range(1, x.ndim))
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype)
class VideoPositionEmb(nn.Module):
def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
"""
It delegates the embedding generation to generate_embeddings function.
"""
B_T_H_W_C = x_B_T_H_W_C.shape
embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)
return embeddings
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None):
raise NotImplementedError
class VideoRopePosition3DEmb(VideoPositionEmb):
def __init__(
self,
*, # enforce keyword arguments
head_dim: int,
len_h: int,
len_w: int,
len_t: int,
base_fps: int = 24,
h_extrapolation_ratio: float = 1.0,
w_extrapolation_ratio: float = 1.0,
t_extrapolation_ratio: float = 1.0,
device=None,
**kwargs, # used for compatibility with other positional embeddings; unused in this class
):
del kwargs
super().__init__()
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
self.base_fps = base_fps
self.max_h = len_h
self.max_w = len_w
dim = head_dim
dim_h = dim // 6 * 2
dim_w = dim_h
dim_t = dim - 2 * dim_h
assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
self.register_buffer(
"dim_spatial_range",
torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
persistent=False,
)
self.register_buffer(
"dim_temporal_range",
torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
persistent=False,
)
self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))
def generate_embeddings(
self,
B_T_H_W_C: torch.Size,
fps: Optional[torch.Tensor] = None,
h_ntk_factor: Optional[float] = None,
w_ntk_factor: Optional[float] = None,
t_ntk_factor: Optional[float] = None,
device=None,
dtype=None,
):
"""
Generate embeddings for the given input size.
Args:
B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.
Returns:
Not specified in the original code snippet.
"""
h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor
h_theta = 10000.0 * h_ntk_factor
w_theta = 10000.0 * w_ntk_factor
t_theta = 10000.0 * t_ntk_factor
h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
B, T, H, W, _ = B_T_H_W_C
uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
assert (
uniform_fps or B == 1 or T == 1
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
assert (
H <= self.max_h and W <= self.max_w
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
# apply sequence scaling in temporal dimension
if fps is None: # image case
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
else:
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)
em_T_H_W_D = torch.cat(
[
repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
]
, dim=-2,
)
return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
class LearnablePosEmbAxis(VideoPositionEmb):
def __init__(
self,
*, # enforce keyword arguments
interpolation: str,
model_channels: int,
len_h: int,
len_w: int,
len_t: int,
device=None,
dtype=None,
**kwargs,
):
"""
Args:
interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet.
"""
del kwargs # unused
super().__init__()
self.interpolation = interpolation
assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
B, T, H, W, _ = B_T_H_W_C
if self.interpolation == "crop":
emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
emb = (
repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
+ repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
+ repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
)
assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
else:
raise ValueError(f"Unknown interpolation method {self.interpolation}")
return normalize(emb, dim=-1, eps=1e-6)

131
comfy/ldm/cosmos/vae.py Normal file
View File

@ -0,0 +1,131 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The causal continuous video tokenizer with VAE or AE formulation for 3D data.."""
import logging
import torch
from torch import nn
from enum import Enum
import math
from .cosmos_tokenizer.layers3d import (
EncoderFactorized,
DecoderFactorized,
CausalConv3d,
)
class IdentityDistribution(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, parameters):
return parameters, (torch.tensor([0.0]), torch.tensor([0.0]))
class GaussianDistribution(torch.nn.Module):
def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
super().__init__()
self.min_logvar = min_logvar
self.max_logvar = max_logvar
def sample(self, mean, logvar):
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
def forward(self, parameters):
mean, logvar = torch.chunk(parameters, 2, dim=1)
logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar)
return self.sample(mean, logvar), (mean, logvar)
class ContinuousFormulation(Enum):
VAE = GaussianDistribution
AE = IdentityDistribution
class CausalContinuousVideoTokenizer(nn.Module):
def __init__(
self, z_channels: int, z_factor: int, latent_channels: int, **kwargs
) -> None:
super().__init__()
self.name = kwargs.get("name", "CausalContinuousVideoTokenizer")
self.latent_channels = latent_channels
self.sigma_data = 0.5
# encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name)
self.encoder = EncoderFactorized(
z_channels=z_factor * z_channels, **kwargs
)
if kwargs.get("temporal_compression", 4) == 4:
kwargs["channels_mult"] = [2, 4]
# decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name)
self.decoder = DecoderFactorized(
z_channels=z_channels, **kwargs
)
self.quant_conv = CausalConv3d(
z_factor * z_channels,
z_factor * latent_channels,
kernel_size=1,
padding=0,
)
self.post_quant_conv = CausalConv3d(
latent_channels, z_channels, kernel_size=1, padding=0
)
# formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
self.distribution = IdentityDistribution() # ContinuousFormulation[formulation_name].value()
num_parameters = sum(param.numel() for param in self.parameters())
logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
logging.debug(
f"z_channels={z_channels}, latent_channels={self.latent_channels}."
)
latent_temporal_chunk = 16
self.latent_mean = nn.Parameter(torch.zeros([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
self.latent_std = nn.Parameter(torch.ones([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
z, posteriors = self.distribution(moments)
latent_ch = z.shape[1]
latent_t = z.shape[2]
in_dtype = z.dtype
mean = self.latent_mean.view(latent_ch, -1)
std = self.latent_std.view(latent_ch, -1)
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
return ((z - mean) / std) * self.sigma_data
def decode(self, z):
in_dtype = z.dtype
latent_ch = z.shape[1]
latent_t = z.shape[2]
mean = self.latent_mean.view(latent_ch, -1)
std = self.latent_std.view(latent_ch, -1)
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
z = z / self.sigma_data
z = z * std + mean
z = self.post_quant_conv(z)
return self.decoder(z)

View File

@ -105,7 +105,9 @@ class Modulation(nn.Module):
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
if vec.ndim == 2:
vec = vec[:, None, :]
out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
@ -113,6 +115,20 @@ class Modulation(nn.Module):
)
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
if modulation_dims is None:
if m_add is not None:
return tensor * m_mult + m_add
else:
return tensor * m_mult
else:
for d in modulation_dims:
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
if m_add is not None:
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
return tensor
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
@ -143,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@ -179,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
# calculate the txt bloks
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@ -228,10 +244,9 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
@ -240,7 +255,7 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output
x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
@ -253,8 +268,11 @@ class LastLayer(nn.Module):
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
if vec.ndim == 2:
vec = vec[:, None, :]
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
x = self.linear(x)
return x

View File

@ -5,8 +5,16 @@ from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q, k = apply_rope(q, k, pe)
q_shape = q.shape
k_shape = k.shape
if pe is not None:
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
@ -15,7 +23,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
device = torch.device("cpu")
else:
device = pos.device
@ -29,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@ -109,15 +109,17 @@ class Flux(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
if img_ids is not None:
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
else:
pe = None
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
@ -168,7 +170,7 @@ class Flux(nn.Module):
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
@ -186,7 +188,7 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

View File

@ -227,6 +227,7 @@ class HunyuanVideo(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
guiding_frame_index=None,
control=None,
transformer_options={},
) -> Tensor:
@ -237,12 +238,21 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
modulation_dims_txt = [(0, None, 1)]
else:
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
modulation_dims = None
modulation_dims_txt = None
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
@ -265,14 +275,14 @@ class HunyuanVideo(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
if control is not None: # Controlnet
control_i = control.get("input")
@ -287,13 +297,13 @@ class HunyuanVideo(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
if control is not None: # Controlnet
control_o = control.get("output")
@ -304,17 +314,17 @@ class HunyuanVideo(nn.Module):
img = img[:, : img_len]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
shape = initial_shape[-3:]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
return img
def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@ -326,5 +336,5 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
return out

View File

@ -159,7 +159,7 @@ class CrossAttention(nn.Module):
q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
k = k.transpose(-2, -3).contiguous() # k -> B, L2, H, C - B, H, C, L2
v = v.transpose(-2, -3).contiguous()
v = v.transpose(-2, -3).contiguous()
context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)

View File

@ -7,7 +7,7 @@ from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
def get_timestep_embedding(
@ -377,12 +377,16 @@ class LTXVModel(torch.nn.Module):
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.generator = None
self.vae_scale_factors = vae_scale_factors
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
@ -416,51 +420,31 @@ class LTXVModel(torch.nn.Module):
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
orig_height=x.shape[3],
orig_width=x.shape[4],
batch_size=x.shape[0],
scale_grid=((1 / frame_rate) * 8, 32, 32),
device=x.device,
)
if guiding_latent is not None:
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
if guiding_latent_noise_scale > 0:
if self.generator is None:
self.generator = torch.Generator(device=x.device).manual_seed(42)
elif self.generator.device != x.device:
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
scale = guiding_latent_noise_scale * (input_ts ** 2)
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
orig_shape = list(x.shape)
x = self.patchifier.patchify(x)
x, latent_coords = self.patchifier.patchify(x)
pixel_coords = latent_to_pixel_coords(
latent_coords=latent_coords,
scale_factors=self.vae_scale_factors,
causal_fix=self.causal_temporal_positioning,
)
if keyframe_idxs is not None:
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
x = self.patchify_proj(x)
timestep = timestep * 1000.0
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
@ -520,8 +504,4 @@ class LTXVModel(torch.nn.Module):
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
)
if guiding_latent is not None:
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
# print("res", x)
return x

View File

@ -6,16 +6,29 @@ from einops import rearrange
from torch import Tensor
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
elif dims_to_append == 0:
return x
return x[(...,) + (None,) * dims_to_append]
def latent_to_pixel_coords(
latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
) -> Tensor:
"""
Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
configuration.
Args:
latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
containing the latent corner coordinates of each token.
scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space.
causal_fix (bool): Whether to take into account the different temporal scale
of the first frame. Default = False for backwards compatibility.
Returns:
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
"""
pixel_coords = (
latent_coords
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
)
if causal_fix:
# Fix temporal scale for first frame to 1 due to causality
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords
class Patchifier(ABC):
@ -44,29 +57,26 @@ class Patchifier(ABC):
def patch_size(self):
return self._patch_size
def get_grid(
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
def get_latent_coords(
self, latent_num_frames, latent_height, latent_width, batch_size, device
):
f = orig_num_frames // self._patch_size[0]
h = orig_height // self._patch_size[1]
w = orig_width // self._patch_size[2]
grid_h = torch.arange(h, dtype=torch.float32, device=device)
grid_w = torch.arange(w, dtype=torch.float32, device=device)
grid_f = torch.arange(f, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
if scale_grid is not None:
for i in range(3):
if isinstance(scale_grid[i], Tensor):
scale = append_dims(scale_grid[i], grid.ndim - 1)
else:
scale = scale_grid[i]
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
return grid
"""
Return a tensor of shape [batch_size, 3, num_patches] containing the
top-left corner latent coordinates of each latent patch.
The tensor is repeated for each batch element.
"""
latent_sample_coords = torch.meshgrid(
torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
torch.arange(0, latent_height, self._patch_size[1], device=device),
torch.arange(0, latent_width, self._patch_size[2], device=device),
indexing="ij",
)
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_coords = rearrange(
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
)
return latent_coords
class SymmetricPatchifier(Patchifier):
@ -74,6 +84,8 @@ class SymmetricPatchifier(Patchifier):
self,
latents: Tensor,
) -> Tuple[Tensor, Tensor]:
b, _, f, h, w = latents.shape
latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
latents = rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
@ -81,7 +93,7 @@ class SymmetricPatchifier(Patchifier):
p2=self._patch_size[1],
p3=self._patch_size[2],
)
return latents
return latents, latent_coords
def unpatchify(
self,

View File

@ -15,6 +15,7 @@ class CausalConv3d(nn.Module):
stride: Union[int, Tuple[int]] = 1,
dilation: int = 1,
groups: int = 1,
spatial_padding_mode: str = "zeros",
**kwargs,
):
super().__init__()
@ -38,7 +39,7 @@ class CausalConv3d(nn.Module):
stride=stride,
dilation=dilation,
padding=padding,
padding_mode="zeros",
padding_mode=spatial_padding_mode,
groups=groups,
)

View File

@ -1,13 +1,15 @@
from __future__ import annotations
import torch
from torch import nn
from functools import partial
import math
from einops import rearrange
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
ops = comfy.ops.disable_weight_init
class Encoder(nn.Module):
@ -32,7 +34,7 @@ class Encoder(nn.Module):
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
latent_log_var (`str`, *optional*, defaults to `per_channel`):
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
"""
def __init__(
@ -40,12 +42,13 @@ class Encoder(nn.Module):
dims: Union[int, Tuple[int, int]] = 3,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
base_channels: int = 128,
norm_num_groups: int = 32,
patch_size: Union[int, Tuple[int]] = 1,
norm_layer: str = "group_norm", # group_norm, pixel_norm
latent_log_var: str = "per_channel",
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.patch_size = patch_size
@ -65,6 +68,7 @@ class Encoder(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.down_blocks = nn.ModuleList([])
@ -82,6 +86,7 @@ class Encoder(nn.Module):
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
@ -92,6 +97,7 @@ class Encoder(nn.Module):
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = make_conv_nd(
@ -101,6 +107,7 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 1, 1),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = make_conv_nd(
@ -110,6 +117,7 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(1, 2, 2),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
block = make_conv_nd(
@ -119,6 +127,7 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 2, 2),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
@ -129,6 +138,34 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 2, 2),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(2, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unknown block: {block_name}")
@ -152,10 +189,18 @@ class Encoder(nn.Module):
conv_out_channels *= 2
elif latent_log_var == "uniform":
conv_out_channels += 1
elif latent_log_var == "constant":
conv_out_channels += 1
elif latent_log_var != "none":
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
self.conv_out = make_conv_nd(
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
dims,
output_channel,
conv_out_channels,
3,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.gradient_checkpointing = False
@ -197,6 +242,15 @@ class Encoder(nn.Module):
sample = torch.cat([sample, repeated_last_channel], dim=1)
else:
raise ValueError(f"Invalid input shape: {sample.shape}")
elif self.latent_log_var == "constant":
sample = sample[:, :-1, ...]
approx_ln_0 = (
-30
) # this is the minimal clamp value in DiagonalGaussianDistribution objects
sample = torch.cat(
[sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
dim=1,
)
return sample
@ -231,7 +285,7 @@ class Decoder(nn.Module):
dims,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
base_channels: int = 128,
layers_per_block: int = 2,
norm_num_groups: int = 32,
@ -239,6 +293,7 @@ class Decoder(nn.Module):
norm_layer: str = "group_norm",
causal: bool = True,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.patch_size = patch_size
@ -264,6 +319,7 @@ class Decoder(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.up_blocks = nn.ModuleList([])
@ -283,6 +339,7 @@ class Decoder(nn.Module):
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
@ -294,6 +351,7 @@ class Decoder(nn.Module):
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
@ -306,14 +364,21 @@ class Decoder(nn.Module):
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=False,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
output_channel = output_channel // block_params.get("multiplier", 1)
@ -323,6 +388,7 @@ class Decoder(nn.Module):
stride=(2, 2, 2),
residual=block_params.get("residual", False),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unknown layer: {block_name}")
@ -340,7 +406,13 @@ class Decoder(nn.Module):
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims, output_channel, out_channels, 3, padding=1, causal=True
dims,
output_channel,
out_channels,
3,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.gradient_checkpointing = False
@ -433,6 +505,12 @@ class UNetMidBlock3D(nn.Module):
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
inject_noise (`bool`, *optional*, defaults to `False`):
Whether to inject noise into the hidden states.
timestep_conditioning (`bool`, *optional*, defaults to `False`):
Whether to condition the hidden states on the timestep.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
@ -451,6 +529,7 @@ class UNetMidBlock3D(nn.Module):
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
resnet_groups = (
@ -476,13 +555,17 @@ class UNetMidBlock3D(nn.Module):
norm_layer=norm_layer,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
for _ in range(num_layers)
]
)
def forward(
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
self,
hidden_states: torch.FloatTensor,
causal: bool = True,
timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
timestep_embed = None
if self.timestep_conditioning:
@ -507,9 +590,62 @@ class UNetMidBlock3D(nn.Module):
return hidden_states
class SpaceToDepthDownsample(nn.Module):
def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
super().__init__()
self.stride = stride
self.group_size = in_channels * math.prod(stride) // out_channels
self.conv = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=out_channels // math.prod(stride),
kernel_size=3,
stride=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
def forward(self, x, causal: bool = True):
if self.stride[0] == 2:
x = torch.cat(
[x[:, :, :1, :, :], x], dim=2
) # duplicate first frames for padding
# skip connection
x_in = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
x_in = x_in.mean(dim=2)
# conv
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
x = x + x_in
return x
class DepthToSpaceUpsample(nn.Module):
def __init__(
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
self,
dims,
in_channels,
stride,
residual=False,
out_channels_reduction_factor=1,
spatial_padding_mode="zeros",
):
super().__init__()
self.stride = stride
@ -523,6 +659,7 @@ class DepthToSpaceUpsample(nn.Module):
kernel_size=3,
stride=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
@ -558,7 +695,7 @@ class DepthToSpaceUpsample(nn.Module):
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
super().__init__()
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.norm = ops.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, x):
x = rearrange(x, "b c d h w -> b d h w c")
@ -591,6 +728,7 @@ class ResnetBlock3D(nn.Module):
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.in_channels = in_channels
@ -617,6 +755,7 @@ class ResnetBlock3D(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
if inject_noise:
@ -641,6 +780,7 @@ class ResnetBlock3D(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
if inject_noise:
@ -801,9 +941,44 @@ class processor(nn.Module):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
def __init__(self, version=0):
def __init__(self, version=0, config=None):
super().__init__()
if config is None:
config = self.guess_config(version)
self.timestep_conditioning = config.get("timestep_conditioning", False)
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=self.timestep_conditioning,
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
)
self.per_channel_statistics = processor()
def guess_config(self, version):
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
@ -830,7 +1005,7 @@ class VideoVAE(nn.Module):
"use_quant_conv": False,
"causal_decoder": False,
}
else:
elif version == 1:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
@ -866,37 +1041,47 @@ class VideoVAE(nn.Module):
"causal_decoder": False,
"timestep_conditioning": True,
}
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=config.get("timestep_conditioning", False),
)
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.per_channel_statistics = processor()
else:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"encoder_blocks": [
["res_x", {"num_layers": 4}],
["compress_space_res", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_time_res", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_all_res", {"multiplier": 2}],
["res_x", {"num_layers": 2}],
["compress_all_res", {"multiplier": 2}],
["res_x", {"num_layers": 2}]
],
"decoder_blocks": [
["res_x", {"num_layers": 5, "inject_noise": False}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 5, "inject_noise": False}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 5, "inject_noise": False}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 5, "inject_noise": False}]
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
"timestep_conditioning": True
}
return config
def encode(self, x):
frames_count = x.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)

View File

@ -17,7 +17,11 @@ def make_conv_nd(
groups=1,
bias=True,
causal=False,
spatial_padding_mode="zeros",
temporal_padding_mode="zeros",
):
if not (spatial_padding_mode == temporal_padding_mode or causal):
raise NotImplementedError("spatial and temporal padding modes must be equal")
if dims == 2:
return ops.Conv2d(
in_channels=in_channels,
@ -28,6 +32,7 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=spatial_padding_mode,
)
elif dims == 3:
if causal:
@ -40,6 +45,7 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
spatial_padding_mode=spatial_padding_mode,
)
return ops.Conv3d(
in_channels=in_channels,
@ -50,6 +56,7 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=spatial_padding_mode,
)
elif dims == (2, 1):
return DualConv3d(
@ -59,6 +66,7 @@ def make_conv_nd(
stride=stride,
padding=padding,
bias=bias,
padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unsupported dimensions: {dims}")

View File

@ -18,11 +18,13 @@ class DualConv3d(nn.Module):
dilation: Union[int, Tuple[int, int, int]] = 1,
groups=1,
bias=True,
padding_mode="zeros",
):
super(DualConv3d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.padding_mode = padding_mode
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
@ -108,6 +110,7 @@ class DualConv3d(nn.Module):
self.padding1,
self.dilation1,
self.groups,
padding_mode=self.padding_mode,
)
if skip_time_conv:
@ -122,6 +125,7 @@ class DualConv3d(nn.Module):
self.padding2,
self.dilation2,
self.groups,
padding_mode=self.padding_mode,
)
return x
@ -137,7 +141,16 @@ class DualConv3d(nn.Module):
stride1 = (self.stride1[1], self.stride1[2])
padding1 = (self.padding1[1], self.padding1[2])
dilation1 = (self.dilation1[1], self.dilation1[2])
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
x = F.conv2d(
x,
weight1,
self.bias1,
stride1,
padding1,
dilation1,
self.groups,
padding_mode=self.padding_mode,
)
_, _, h, w = x.shape
@ -154,7 +167,16 @@ class DualConv3d(nn.Module):
stride2 = self.stride2[0]
padding2 = self.padding2[0]
dilation2 = self.dilation2[0]
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
x = F.conv1d(
x,
weight2,
self.bias2,
stride2,
padding2,
dilation2,
self.groups,
padding_mode=self.padding_mode,
)
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
return x

622
comfy/ldm/lumina/model.py Normal file
View File

@ -0,0 +1,622 @@
# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
from __future__ import annotations
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.common_dit
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
def modulate(x, scale):
return x * (1 + scale.unsqueeze(1))
#############################################################################
# Core NextDiT Model #
#############################################################################
class JointAttention(nn.Module):
"""Multi-head attention module."""
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: Optional[int],
qk_norm: bool,
operation_settings={},
):
"""
Initialize the Attention module.
Args:
dim (int): Number of input dimensions.
n_heads (int): Number of heads.
n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
"""
super().__init__()
self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
self.n_local_heads = n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = dim // n_heads
self.qkv = operation_settings.get("operations").Linear(
dim,
(n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
bias=False,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
self.out = operation_settings.get("operations").Linear(
n_heads * self.head_dim,
dim,
bias=False,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
if qk_norm:
self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
else:
self.q_norm = self.k_norm = nn.Identity()
@staticmethod
def apply_rotary_emb(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency
tensor.
This function applies rotary embeddings to the given query 'xq' and
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors
contain rotary embeddings and are returned as real tensors.
Args:
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
exponentials.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
"""
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x_in.shape)
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
x_mask:
freqs_cis:
Returns:
"""
bsz, seqlen, _ = x.shape
xq, xk, xv = torch.split(
self.qkv(x),
[
self.n_local_heads * self.head_dim,
self.n_local_kv_heads * self.head_dim,
self.n_local_kv_heads * self.head_dim,
],
dim=-1,
)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq = self.q_norm(xq)
xk = self.k_norm(xk)
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
return self.out(output)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
operation_settings={},
):
"""
Initialize the FeedForward module.
Args:
dim (int): Input dimension.
hidden_dim (int): Hidden dimension of the feedforward layer.
multiple_of (int): Value to ensure hidden dimension is a multiple
of this value.
ffn_dim_multiplier (float, optional): Custom multiplier for hidden
dimension. Defaults to None.
"""
super().__init__()
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = operation_settings.get("operations").Linear(
dim,
hidden_dim,
bias=False,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
self.w2 = operation_settings.get("operations").Linear(
hidden_dim,
dim,
bias=False,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
self.w3 = operation_settings.get("operations").Linear(
dim,
hidden_dim,
bias=False,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
# @torch.compile
def _forward_silu_gating(self, x1, x3):
return F.silu(x1) * x3
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
class JointTransformerBlock(nn.Module):
def __init__(
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
multiple_of: int,
ffn_dim_multiplier: float,
norm_eps: float,
qk_norm: bool,
modulation=True,
operation_settings={},
) -> None:
"""
Initialize a TransformerBlock.
Args:
layer_id (int): Identifier for the layer.
dim (int): Embedding dimension of the input features.
n_heads (int): Number of attention heads.
n_kv_heads (Optional[int]): Number of attention heads in key and
value features (if using GQA), or set to None for the same as
query.
multiple_of (int):
ffn_dim_multiplier (float):
norm_eps (float):
"""
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
self.feed_forward = FeedForward(
dim=dim,
hidden_dim=4 * dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
operation_settings=operation_settings,
)
self.layer_id = layer_id
self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.modulation = modulation
if modulation:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None,
):
"""
Perform a forward pass through the TransformerBlock.
Args:
x (torch.Tensor): Input tensor.
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
Returns:
torch.Tensor: Output tensor after applying attention and
feedforward layers.
"""
if self.modulation:
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
self.attention(
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
)
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
self.feed_forward(
modulate(self.ffn_norm1(x), scale_mlp),
)
)
else:
assert adaln_input is None
x = x + self.attention_norm2(
self.attention(
self.attention_norm1(x),
x_mask,
freqs_cis,
)
)
x = x + self.ffn_norm2(
self.feed_forward(
self.ffn_norm1(x),
)
)
return x
class FinalLayer(nn.Module):
"""
The final layer of NextDiT.
"""
def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
super().__init__()
self.norm_final = operation_settings.get("operations").LayerNorm(
hidden_size,
elementwise_affine=False,
eps=1e-6,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
self.linear = operation_settings.get("operations").Linear(
hidden_size,
patch_size * patch_size * out_channels,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(hidden_size, 1024),
hidden_size,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
def forward(self, x, c):
scale = self.adaLN_modulation(c)
x = modulate(self.norm_final(x), scale)
x = self.linear(x)
return x
class NextDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
patch_size: int = 2,
in_channels: int = 4,
dim: int = 4096,
n_layers: int = 32,
n_refiner_layers: int = 2,
n_heads: int = 32,
n_kv_heads: Optional[int] = None,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
qk_norm: bool = False,
cap_feat_dim: int = 5120,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (1, 512, 512),
image_model=None,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.x_embedder = operation_settings.get("operations").Linear(
in_features=patch_size * patch_size * in_channels,
out_features=dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
self.noise_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=True,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
self.context_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=False,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
operation_settings.get("operations").Linear(
cap_feat_dim,
dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.layers = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
operation_settings=operation_settings,
)
for layer_id in range(n_layers)
]
)
self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
assert (dim // n_heads) == sum(axes_dims)
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
self.dim = dim
self.n_heads = n_heads
def unpatchify(
self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
) -> List[torch.Tensor]:
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
pH = pW = self.patch_size
imgs = []
for i in range(x.size(0)):
H, W = img_size[i]
begin = cap_size[i]
end = begin + (H // pH) * (W // pW)
imgs.append(
x[i][begin:end]
.view(H // pH, W // pW, pH, pW, self.out_channels)
.permute(4, 0, 2, 1, 3)
.flatten(3, 4)
.flatten(1, 2)
)
if return_tensor:
imgs = torch.stack(imgs, dim=0)
return imgs
def patchify_and_embed(
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
dtype = x[0].dtype
if cap_mask is not None:
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
else:
l_effective_cap_len = [num_tokens] * bsz
if cap_mask is not None and not torch.is_floating_point(cap_mask):
cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
img_sizes = [(img.size(1), img.size(2)) for img in x]
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
max_seq_len = max(
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
)
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
H, W = img_sizes[i]
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
# build freqs_cis for cap and image individually
cap_freqs_cis_shape = list(freqs_cis.shape)
# cap_freqs_cis_shape[1] = max_cap_len
cap_freqs_cis_shape[1] = cap_feats.shape[1]
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
# refine image
flat_x = []
for i in range(bsz):
img = x[i]
C, H, W = img.size()
img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
flat_x.append(img)
x = flat_x
padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
for i in range(bsz):
padded_img_embed[i, :l_effective_img_len[i]] = x[i]
padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1)
for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
else:
mask = None
padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
# def forward(self, x, t, cap_feats, cap_mask):
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
"""
Forward pass of NextDiT.
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of text tokens/features
"""
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
adaln_input = t
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
x_is_tensor = isinstance(x, torch.Tensor)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
freqs_cis = freqs_cis.to(x.device)
for layer in self.layers:
x = layer(x, mask, freqs_cis, adaln_input)
x = self.final_layer(x, adaln_input)
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
return -x

View File

@ -1,4 +1,6 @@
import math
import sys
import torch
import torch.nn.functional as F
from torch import nn, einsum
@ -16,7 +18,18 @@ if model_management.xformers_enabled():
import xformers.ops
if model_management.sage_attention_enabled():
from sageattention import sageattn
try:
from sageattention import sageattn
except ModuleNotFoundError:
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
exit(-1)
if model_management.flash_attention_enabled():
try:
from flash_attn import flash_attn_func
except ModuleNotFoundError:
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
from comfy.cli_args import args
import comfy.ops
@ -24,38 +37,24 @@ ops = comfy.ops.disable_weight_init
FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
def get_attn_precision(attn_precision):
def get_attn_precision(attn_precision, current_dtype):
if args.dont_upcast_attention:
return None
if FORCE_UPCAST_ATTENTION_DTYPE is not None:
return FORCE_UPCAST_ATTENTION_DTYPE
if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
return attn_precision
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
@ -89,8 +88,8 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
attn_precision = get_attn_precision(attn_precision)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
b, _, _, dim_head = q.shape
@ -142,17 +141,24 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
if skip_output_reshape:
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
)
else:
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return out
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
attn_precision = get_attn_precision(attn_precision)
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision, query.dtype)
if skip_reshape:
b, _, _, dim_head = query.shape
@ -215,12 +221,14 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
)
hidden_states = hidden_states.to(dtype)
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
if skip_output_reshape:
hidden_states = hidden_states.unflatten(0, (-1, heads))
else:
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
attn_precision = get_attn_precision(attn_precision)
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
b, _, _, dim_head = q.shape
@ -326,12 +334,18 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
del q, k, v
r1 = (
r1.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
if skip_output_reshape:
r1 = (
r1.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
)
else:
r1 = (
r1.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return r1
BROKEN_XFORMERS = False
@ -342,7 +356,7 @@ try:
except:
pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
@ -395,9 +409,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
out = (
out.reshape(b, -1, heads * dim_head)
)
if skip_output_reshape:
out = out.permute(0, 2, 1, 3)
else:
out = (
out.reshape(b, -1, heads * dim_head)
)
return out
@ -408,7 +425,7 @@ else:
SDP_BATCH_LIMIT = 2**31
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@ -429,9 +446,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if SDP_BATCH_LIMIT >= b:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
for i in range(0, b, SDP_BATCH_LIMIT):
@ -450,7 +468,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
return out
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
@ -473,11 +491,72 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
if tensor_layout == "HND":
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
if skip_output_reshape:
out = out.transpose(1, 2)
else:
out = out.reshape(b, -1, heads * dim_head)
return out
try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
@flash_attn_wrapper.register_fake
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
# Output shape is the same as q
return q.new_empty(q.shape)
except AttributeError as error:
FLASH_ATTN_ERROR = error
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
try:
assert mask is None
out = flash_attn_wrapper(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
dropout_p=0.0,
causal=False,
).transpose(1, 2)
except Exception as e:
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = out.reshape(b, -1, heads * dim_head)
return out
@ -489,6 +568,9 @@ if model_management.sage_attention_enabled():
elif model_management.xformers_enabled():
logging.info("Using xformers attention")
optimized_attention = attention_xformers
elif model_management.flash_attention_enabled():
logging.info("Using Flash Attention")
optimized_attention = attention_flash
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention")
optimized_attention = attention_pytorch

View File

@ -321,7 +321,7 @@ class SelfAttention(nn.Module):
class RMSNorm(torch.nn.Module):
def __init__(
self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None, **kwargs
):
"""
Initialize the RMSNorm normalization layer.

View File

@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
return out
def vae_attention():
if model_management.xformers_enabled_vae():
logging.info("Using xformers attention in VAE")
return xformers_attention
elif model_management.pytorch_attention_enabled_vae():
logging.info("Using pytorch attention in VAE")
return pytorch_attention
else:
logging.info("Using split attention in VAE")
return normal_attention
class AttnBlock(nn.Module):
def __init__(self, in_channels, conv_op=ops.Conv2d):
super().__init__()
@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
stride=1,
padding=0)
if model_management.xformers_enabled_vae():
logging.info("Using xformers attention in VAE")
self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention
else:
logging.info("Using split attention in VAE")
self.optimized_attention = normal_attention
self.optimized_attention = vae_attention()
def forward(self, x):
h_ = x
@ -699,9 +702,6 @@ class Decoder(nn.Module):
padding=1)
def forward(self, z, **kwargs):
#assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None

View File

@ -17,10 +17,10 @@ import math
import logging
try:
from typing import Optional, NamedTuple, List, Protocol
from typing import Optional, NamedTuple, List, Protocol
except ImportError:
from typing import Optional, NamedTuple, List
from typing_extensions import Protocol
from typing import Optional, NamedTuple, List
from typing_extensions import Protocol
from typing import List

View File

@ -194,4 +194,4 @@ class AdamWwithEMAandWings(optim.Optimizer):
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
return loss
return loss

485
comfy/ldm/wan/model.py Normal file
View File

@ -0,0 +1,485 @@
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import torch
import torch.nn as nn
from einops import repeat
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
import comfy.ldm.common_dit
import comfy.model_management
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float32)
# calculation
sinusoid = torch.outer(
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x
class WanSelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6, operation_settings={}):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
# layers
self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, freqs):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n * d)
return q, k, v
q, k, v = qkv_fn(x)
q, k = apply_rope(q, k, freqs)
x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
v,
heads=self.num_heads,
)
x = self.o(x)
return x
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
"""
# compute query, key, value
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(context))
v = self.v(context)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads)
x = self.o(x)
return x
class WanI2VCrossAttention(WanSelfAttention):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6, operation_settings={}):
super().__init__(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)
self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
"""
context_img = context[:, :257]
context = context[:, 257:]
# compute query, key, value
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(context))
v = self.v(context)
k_img = self.norm_k_img(self.k_img(context_img))
v_img = self.v_img(context_img)
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads)
# output
x = x + img_x
x = self.o(x)
return x
WAN_CROSSATTENTION_CLASSES = {
't2v_cross_attn': WanT2VCrossAttention,
'i2v_cross_attn': WanI2VCrossAttention,
}
class WanAttentionBlock(nn.Module):
def __init__(self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6, operation_settings={}):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
eps, operation_settings=operation_settings)
self.norm3 = operation_settings.get("operations").LayerNorm(
dim, eps,
elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if cross_attn_norm else nn.Identity()
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
num_heads,
(-1, -1),
qk_norm,
eps, operation_settings=operation_settings)
self.norm2 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.ffn = nn.Sequential(
operation_settings.get("operations").Linear(dim, ffn_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
operation_settings.get("operations").Linear(ffn_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
# modulation
self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
def forward(
self,
x,
e,
freqs,
context,
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, 6, C]
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
# assert e.dtype == torch.float32
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
# assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x) * (1 + e[1]) + e[0],
freqs)
x = x + y * e[2]
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
return x
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.patch_size = patch_size
self.eps = eps
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.head = operation_settings.get("operations").Linear(dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
# modulation
self.modulation = nn.Parameter(torch.empty(1, 2, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
def forward(self, x, e):
r"""
Args:
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, C]
"""
# assert e.dtype == torch.float32
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim, operation_settings={}):
super().__init__()
self.proj = torch.nn.Sequential(
operation_settings.get("operations").LayerNorm(in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear(in_dim, in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class WanModel(torch.nn.Module):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='t2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
image_model=None,
device=None,
dtype=None,
operations=None,
):
r"""
Initialize the diffusion model backbone.
Args:
model_type (`str`, *optional*, defaults to 't2v'):
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
Fixed length for text embeddings
in_dim (`int`, *optional*, defaults to 16):
Input video channels (C_in)
dim (`int`, *optional*, defaults to 2048):
Hidden dimension of the transformer
ffn_dim (`int`, *optional*, defaults to 8192):
Intermediate dimension in feed-forward network
freq_dim (`int`, *optional*, defaults to 256):
Dimension for sinusoidal time embeddings
text_dim (`int`, *optional*, defaults to 4096):
Input dimension for text embeddings
out_dim (`int`, *optional*, defaults to 16):
Output video channels (C_out)
num_heads (`int`, *optional*, defaults to 16):
Number of attention heads
num_layers (`int`, *optional*, defaults to 32):
Number of transformer blocks
window_size (`tuple`, *optional*, defaults to (-1, -1)):
Window size for local attention (-1 indicates global attention)
qk_norm (`bool`, *optional*, defaults to True):
Enable query/key normalization
cross_attn_norm (`bool`, *optional*, defaults to False):
Enable cross-attention normalization
eps (`float`, *optional*, defaults to 1e-6):
Epsilon value for normalization layers
"""
super().__init__()
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
assert model_type in ['t2v', 'i2v']
self.model_type = model_type
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# embeddings
self.patch_embedding = operations.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
self.text_embedding = nn.Sequential(
operations.Linear(text_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
self.time_embedding = nn.Sequential(
operations.Linear(freq_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.SiLU(), operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
self.time_projection = nn.Sequential(nn.SiLU(), operations.Linear(dim, dim * 6, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
# blocks
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = nn.ModuleList([
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
for _ in range(num_layers)
])
# head
self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
d = dim // num_heads
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
if model_type == 'i2v':
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
else:
self.img_emb = None
def forward_orig(
self,
x,
t,
context,
clip_fea=None,
freqs=None,
transformer_options={},
):
r"""
Forward pass through the diffusion model
Args:
x (Tensor):
List of input video tensors with shape [B, C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [B, L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
clip_fea (Tensor, *optional*):
CLIP image features for image-to-video mode
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(context)
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
freqs = self.rope_embedder(img_ids).movedim(1, 2)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
r"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Tensor]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Tensor):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Tensor]:
Reconstructed video tensors with shape [L, C_out, F, H / 8, W / 8]
"""
c = self.out_dim
u = x
b = u.shape[0]
u = u[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
return u

567
comfy/ldm/wan/vae.py Normal file
View File

@ -0,0 +1,567 @@
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.diffusionmodules.model import vae_attention
import comfy.ops
ops = comfy.ops.disable_weight_init
CACHE_T = 2
class CausalConv3d(ops.Conv3d):
"""
Causal 3d convolusion.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else None
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = ops.Conv2d(dim, dim * 3, 1)
self.proj = ops.Conv2d(dim, dim, 1)
self.optimized_attention = vae_attention()
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.norm(x)
# compute query, key, value
q, k, v = self.to_qkv(x).chunk(3, dim=1)
x = self.optimized_attention(q, k, v)
# output
x = self.proj(x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[
i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
count += 1
return count
class WanVAE(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x):
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
# z: [b,c,t,h,w]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@ -307,7 +307,6 @@ def model_lora_keys_unet(model, key_map={}):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
else:
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
@ -327,6 +326,13 @@ def model_lora_keys_unet(model, key_map={}):
diffusers_lora_key = diffusers_lora_key[:-2]
key_map[diffusers_lora_key] = unet_key
if isinstance(model, comfy.model_base.StableCascade_C):
for k in sdk:
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_prior_unet_{}".format(key_lora)] = k
if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:

View File

@ -33,6 +33,9 @@ import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.model_management
import comfy.patcher_extension
@ -105,7 +108,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
else:
operations = model_config.custom_operations
@ -147,7 +150,9 @@ class BaseModel(torch.nn.Module):
xc = xc.to(dtype)
t = self.model_sampling.timestep(t).float()
context = context.to(dtype)
if context is not None:
context = context.to(dtype)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
@ -156,15 +161,16 @@ class BaseModel(torch.nn.Module):
extra = extra.to(dtype)
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def process_timestep(self, timestep, **kwargs):
return timestep
def get_dtype(self):
return self.diffusion_model.dtype
def is_adm(self):
return self.adm_channels > 0
def encode_adm(self, **kwargs):
return None
@ -183,14 +189,20 @@ class BaseModel(torch.nn.Module):
if concat_latent_image.shape[1:] != noise.shape[1:]:
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
if noise.ndim == 5:
if concat_latent_image.shape[-3] < noise.shape[-3]:
concat_latent_image = torch.nn.functional.pad(concat_latent_image, (0, 0, 0, 0, 0, noise.shape[-3] - concat_latent_image.shape[-3]), "constant", 0)
else:
concat_latent_image = concat_latent_image[:, :, :noise.shape[-3]]
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
if denoise_mask is not None:
if len(denoise_mask.shape) == len(noise.shape):
denoise_mask = denoise_mask[:,:1]
denoise_mask = denoise_mask[:, :1]
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
num_dim = noise.ndim - 2
denoise_mask = denoise_mask.reshape((-1, 1) + tuple(denoise_mask.shape[-num_dim:]))
if denoise_mask.shape[-2:] != noise.shape[-2:]:
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
@ -200,12 +212,21 @@ class BaseModel(torch.nn.Module):
if ck == "mask":
cond_concat.append(denoise_mask.to(device))
elif ck == "masked_image":
cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
cond_concat.append(concat_latent_image.to(device)) # NOTE: the latent_image should be masked by the mask in pixel space
elif ck == "mask_inverted":
cond_concat.append(1.0 - denoise_mask.to(device))
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
cond_concat.append(torch.ones_like(noise)[:, :1])
elif ck == "masked_image":
cond_concat.append(self.blank_inpaint_image_like(noise))
elif ck == "mask_inverted":
cond_concat.append(torch.zeros_like(noise)[:, :1])
if ck == "concat_image":
if concat_latent_image is not None:
cond_concat.append(concat_latent_image.to(device))
else:
cond_concat.append(torch.zeros_like(noise))
data = torch.cat(cond_concat, dim=1)
return data
return None
@ -293,6 +314,9 @@ class BaseModel(torch.nn.Module):
return blank_image
self.blank_inpaint_image_like = blank_inpaint_image_like
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
@ -540,6 +564,10 @@ class SD_X4Upscaler(BaseModel):
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
out['y'] = comfy.conds.CONDRegular(noise_level)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
return out
class IP2P:
@ -787,7 +815,7 @@ class Flux(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
# upscale the attention mask, since now we
# upscale the attention mask, since now we
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
shape = kwargs["noise"].shape
@ -797,7 +825,10 @@ class Flux(BaseModel):
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
guidance = kwargs.get("guidance", 3.5)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class GenmoMochi(BaseModel):
@ -828,17 +859,26 @@ class LTXV(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guiding_latent = kwargs.get("guiding_latent", None)
if guiding_latent is not None:
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
if guiding_latent_noise_scale is not None:
out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
keyframe_idxs = kwargs.get("keyframe_idxs", None)
if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
return self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
@ -854,5 +894,122 @@ class HunyuanVideo(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
guiding_frame_index = kwargs.get("guiding_frame_index", None)
if guiding_frame_index is not None:
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
return out
def scale_latent_inpaint(self, latent_image, **kwargs):
return latent_image
class HunyuanVideoI2V(HunyuanVideo):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
self.concat_keys = ("concat_image", "mask_inverted")
def scale_latent_inpaint(self, latent_image, **kwargs):
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
self.concat_keys = ("concat_image",)
def scale_latent_inpaint(self, latent_image, **kwargs):
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
class CosmosVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)
self.image_to_video = image_to_video
if self.image_to_video:
self.concat_keys = ("mask_inverted",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
return out
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
sigma_noise_augmentation = 0 #TODO
if sigma_noise_augmentation != 0:
latent_image = latent_image + noise
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
class Lumina2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
return None
image = kwargs.get("concat_latent_image", None)
device = kwargs["device"]
if image is None:
image = torch.zeros_like(noise)
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = self.process_latent_in(image)
image = utils.resize_to_batch_size(image, noise.shape[0])
if not self.image_to_video:
return image
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :4]
else:
mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = mask.repeat(1, 4, 1, 1, 1)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((mask, image), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
clip_vision_output = kwargs.get("clip_vision_output", None)
if clip_vision_output is not None:
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
return out

View File

@ -1,3 +1,4 @@
import json
import comfy.supported_models
import comfy.supported_models_base
import comfy.utils
@ -33,7 +34,7 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
return None
def detect_unet_config(state_dict, key_prefix):
def detect_unet_config(state_dict, key_prefix, metadata=None):
state_dict_keys = list(state_dict.keys())
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
@ -136,7 +137,7 @@ def detect_unet_config(state_dict, key_prefix):
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
dit_config["image_model"] = "hunyuan_video"
dit_config["in_channels"] = 16
dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = [1, 2, 2]
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
@ -210,6 +211,8 @@ def detect_unet_config(state_dict, key_prefix):
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
if metadata is not None and "config" in metadata:
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
return dit_config
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
@ -239,6 +242,87 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["micro_condition"] = False
return dit_config
if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys: # Cosmos
dit_config = {}
dit_config["image_model"] = "cosmos"
dit_config["max_img_h"] = 240
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
dit_config["block_config"] = "FA-CA-MLP"
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["pos_emb_cls"] = "rope3d"
dit_config["pos_emb_learnable"] = False
dit_config["pos_emb_interpolation"] = "crop"
dit_config["block_x_format"] = "THWBD"
dit_config["affline_emb_norm"] = True
dit_config["use_adaln_lora"] = True
dit_config["adaln_lora_dim"] = 256
if dit_config["model_channels"] == 4096:
# 7B
dit_config["num_blocks"] = 28
dit_config["num_heads"] = 32
dit_config["extra_per_block_abs_pos_emb"] = True
dit_config["rope_h_extrapolation_ratio"] = 1.0
dit_config["rope_w_extrapolation_ratio"] = 1.0
dit_config["rope_t_extrapolation_ratio"] = 2.0
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
else: # 5120
# 14B
dit_config["num_blocks"] = 36
dit_config["num_heads"] = 40
dit_config["extra_per_block_abs_pos_emb"] = True
dit_config["rope_h_extrapolation_ratio"] = 2.0
dit_config["rope_w_extrapolation_ratio"] = 2.0
dit_config["rope_t_extrapolation_ratio"] = 2.0
dit_config["extra_h_extrapolation_ratio"] = 2.0
dit_config["extra_w_extrapolation_ratio"] = 2.0
dit_config["extra_t_extrapolation_ratio"] = 2.0
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
dit_config = {}
dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
dit_config["dim"] = 2304
dit_config["cap_feat_dim"] = 2304
dit_config["n_layers"] = 26
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True
dit_config["axes_dims"] = [32, 32, 32]
dit_config["axes_lens"] = [300, 512, 512]
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"
dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
dit_config["dim"] = dim
dit_config["num_heads"] = dim // 128
dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
dit_config["patch_size"] = (1, 2, 2)
dit_config["freq_dim"] = 256
dit_config["window_size"] = (-1, -1)
dit_config["qk_norm"] = True
dit_config["cross_attn_norm"] = True
dit_config["eps"] = 1e-6
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"
else:
dit_config["model_type"] = "t2v"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@ -373,8 +457,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
logging.error("no match {}".format(unet_config))
return None
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix)
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
if unet_config is None:
return None
model_config = model_config_from_unet_config(unet_config, state_dict)
@ -387,12 +471,17 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn
if scaled_fp8_weight.nelement() == 2:
model_config.optimizations["fp8"] = False
else:
model_config.optimizations["fp8"] = True
return model_config
def unet_prefix_from_state_dict(state_dict):
candidates = ["model.diffusion_model.", #ldm/sgm models
"model.model.", #audio models
"net.", #cosmos
]
counts = {k: 0 for k in candidates}
for k in state_dict:
@ -576,7 +665,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]

View File

@ -19,7 +19,7 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args
from comfy.cli_args import args, PerformanceFeature
import torch
import sys
import platform
@ -50,7 +50,9 @@ xpu_available = False
torch_version = ""
try:
torch_version = torch.version.__version__
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
temp = torch_version.split(".")
torch_version_numeric = (int(temp[0]), int(temp[1]))
xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
except:
pass
@ -93,6 +95,13 @@ try:
except:
npu_available = False
try:
import torch_mlu # noqa: F401
_ = torch.mlu.device_count()
mlu_available = torch.mlu.is_available()
except:
mlu_available = False
if args.cpu:
cpu_state = CPUState.CPU
@ -110,6 +119,12 @@ def is_ascend_npu():
return True
return False
def is_mlu():
global mlu_available
if mlu_available:
return True
return False
def get_torch_device():
global directml_enabled
global cpu_state
@ -125,6 +140,8 @@ def get_torch_device():
return torch.device("xpu", torch.xpu.current_device())
elif is_ascend_npu():
return torch.device("npu", torch.npu.current_device())
elif is_mlu():
return torch.device("mlu", torch.mlu.current_device())
else:
return torch.device(torch.cuda.current_device())
@ -151,6 +168,12 @@ def get_total_memory(dev=None, torch_total_too=False):
_, mem_total_npu = torch.npu.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_npu
elif is_mlu():
stats = torch.mlu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
_, mem_total_mlu = torch.mlu.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_mlu
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@ -163,12 +186,21 @@ def get_total_memory(dev=None, torch_total_too=False):
else:
return mem_total
def mac_version():
try:
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
except:
return None
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
try:
logging.info("pytorch version: {}".format(torch_version))
mac_ver = mac_version()
if mac_ver is not None:
logging.info("Mac Version {}".format(mac_ver))
except:
pass
@ -218,7 +250,7 @@ def is_amd():
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
MIN_WEIGHT_MEMORY_RATIO = 0.2
MIN_WEIGHT_MEMORY_RATIO = 0.0
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
@ -227,22 +259,45 @@ if args.use_pytorch_cross_attention:
try:
if is_nvidia():
if int(torch_version[0]) >= 2:
if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if is_intel_xpu() or is_ascend_npu():
if is_intel_xpu() or is_ascend_npu() or is_mlu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
pass
try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
logging.info("AMD arch: {}".format(arch))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx1100", "gfx1101"]): # TODO: more arches
ENABLE_PYTORCH_ATTENTION = True
except:
pass
if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try:
if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast:
torch.backends.cuda.matmul.allow_fp16_accumulation = True
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
logging.info("Enabled fp16 accumulation.")
except:
pass
try:
if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except:
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
@ -256,15 +311,10 @@ elif args.highvram or args.gpu_only:
vram_state = VRAMState.HIGH_VRAM
FORCE_FP32 = False
FORCE_FP16 = False
if args.force_fp32:
logging.info("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True
if args.force_fp16:
logging.info("Forcing FP16.")
FORCE_FP16 = True
if lowvram_available:
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
@ -297,6 +347,8 @@ def get_torch_device_name(device):
return "{} {}".format(device, torch.xpu.get_device_name(device))
elif is_ascend_npu():
return "{} {}".format(device, torch.npu.get_device_name(device))
elif is_mlu():
return "{} {}".format(device, torch.mlu.get_device_name(device))
else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@ -535,14 +587,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
model_size = loaded_model.model_memory_required(torch_dev)
loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 0.1
@ -635,7 +684,7 @@ def unet_inital_load_device(parameters, dtype):
def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32], weight_dtype=None):
if model_params < 0:
model_params = 1000000000000000000000
if args.fp32_unet:
@ -653,10 +702,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
fp8_dtype = None
try:
for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if dtype in supported_dtypes:
fp8_dtype = dtype
break
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
fp8_dtype = weight_dtype
except:
pass
@ -668,6 +715,10 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
if model_params * 2 > free_model_memory:
return fp8_dtype
if PRIORITIZE_FP16 or weight_dtype == torch.float16:
if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
return torch.float16
for dt in supported_dtypes:
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
if torch.float16 in supported_dtypes:
@ -700,6 +751,9 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
if PRIORITIZE_FP16 and fp16_supported and torch.float16 in supported_dtypes:
return torch.float16
for dt in supported_dtypes:
if dt == torch.float16 and fp16_supported:
return torch.float16
@ -876,6 +930,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
def sage_attention_enabled():
return args.use_sage_attention
def flash_attention_enabled():
return args.use_flash_attention
def xformers_enabled():
global directml_enabled
global cpu_state
@ -885,6 +942,8 @@ def xformers_enabled():
return False
if is_ascend_npu():
return False
if is_mlu():
return False
if directml_enabled:
return False
return XFORMERS_IS_AVAILABLE
@ -901,6 +960,11 @@ def pytorch_attention_enabled():
global ENABLE_PYTORCH_ATTENTION
return ENABLE_PYTORCH_ATTENTION
def pytorch_attention_enabled_vae():
if is_amd():
return False # enabling pytorch attention on AMD currently causes crash when doing high res
return pytorch_attention_enabled()
def pytorch_attention_flash_attention():
global ENABLE_PYTORCH_ATTENTION
if ENABLE_PYTORCH_ATTENTION:
@ -911,23 +975,21 @@ def pytorch_attention_flash_attention():
return True
if is_ascend_npu():
return True
if is_mlu():
return True
if is_amd():
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
return False
def mac_version():
try:
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
except:
return None
def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
macos_version = mac_version()
if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS
upcast = True
if upcast:
return torch.float32
return {torch.float16: torch.float32}
else:
return None
@ -957,6 +1019,13 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_npu, _ = torch.npu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_npu + mem_free_torch
elif is_mlu():
stats = torch.mlu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_mlu + mem_free_torch
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
@ -993,21 +1062,26 @@ def is_device_mps(device):
def is_device_cuda(device):
return is_device_type(device, 'cuda')
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
def is_directml_enabled():
global directml_enabled
if directml_enabled:
return True
return False
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
if device is not None:
if is_device_cpu(device):
return False
if FORCE_FP16:
if args.force_fp16:
return True
if FORCE_FP32:
return False
if directml_enabled:
return False
if is_directml_enabled():
return True
if (device is not None and is_device_mps(device)) or mps_mode():
return True
@ -1021,6 +1095,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_ascend_npu():
return True
if is_mlu():
return True
if torch.version.hip:
return True
@ -1078,13 +1155,28 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_intel_xpu():
return True
if is_ascend_npu():
return True
if is_amd():
arch = torch.cuda.get_device_properties(device).gcnArchName
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
if manual_cast:
return True
return False
props = torch.cuda.get_device_properties(device)
if is_mlu():
if props.major > 3:
return True
if props.major >= 8:
return True
bf16_works = torch.cuda.is_bf16_supported()
if bf16_works or manual_cast:
if bf16_works and manual_cast:
free_model_memory = maximum_vram_for_weights(device)
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True
@ -1103,11 +1195,11 @@ def supports_fp8_compute(device=None):
if props.minor < 9:
return False
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
return False
if WINDOWS:
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
return False
return True
@ -1121,18 +1213,13 @@ def soft_empty_cache(force=False):
elif is_ascend_npu():
torch.npu.empty_cache()
elif torch.cuda.is_available():
if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def unload_all_models():
free_memory(1e30, get_torch_device())
def resolve_lowvram_weight(weight, model, key): #TODO: remove
logging.warning("The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
return weight
#TODO: might be cleaner to put this somewhere else
import threading

View File

@ -83,7 +83,7 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
@ -96,8 +96,28 @@ def wipe_lowvram_weight(m):
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
m.weight_function = None
m.bias_function = None
if hasattr(m, "weight_function"):
m.weight_function = []
if hasattr(m, "bias_function"):
m.bias_function = []
def move_weight_functions(m, device):
if device is None:
return 0
memory = 0
if hasattr(m, "weight_function"):
for f in m.weight_function:
if hasattr(f, "move_to"):
memory += f.move_to(device=device)
if hasattr(m, "bias_function"):
for f in m.bias_function:
if hasattr(f, "move_to"):
memory += f.move_to(device=device)
return memory
class LowVramPatch:
def __init__(self, key, patches):
@ -141,7 +161,7 @@ class AutoPatcherEjector:
self.was_injected = False
self.prev_skip_injection = False
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
def __enter__(self):
self.was_injected = False
self.prev_skip_injection = self.model.skip_injection
@ -164,7 +184,7 @@ class MemoryCounter:
self.value = initial
self.minimum = minimum
# TODO: add a safe limit besides 0
def use(self, weight: torch.Tensor):
weight_size = weight.nelement() * weight.element_size()
if self.is_useable(weight_size):
@ -192,11 +212,13 @@ class ModelPatcher:
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self.weight_wrapper_patches = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update
self.force_cast_weights = False
self.patches_uuid = uuid.uuid4()
self.parent = None
@ -210,7 +232,7 @@ class ModelPatcher:
self.injections: dict[str, list[PatcherInjection]] = {}
self.hook_patches: dict[comfy.hooks._HookRef] = {}
self.hook_patches_backup: dict[comfy.hooks._HookRef] = {}
self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
@ -250,11 +272,14 @@ class ModelPatcher:
n.patches_uuid = self.patches_uuid
n.object_patches = self.object_patches.copy()
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
n.force_cast_weights = self.force_cast_weights
# attachments
n.attachments = {}
for k in self.attachments:
@ -282,7 +307,7 @@ class ModelPatcher:
n.injections[k] = i.copy()
# hooks
n.hook_patches = create_hook_patches_clone(self.hook_patches)
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup)
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
for group in self.cached_hook_patches:
n.cached_hook_patches[group] = {}
for k in self.cached_hook_patches[group]:
@ -402,7 +427,30 @@ class ModelPatcher:
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def get_model_object(self, name):
def set_model_compute_dtype(self, dtype):
self.add_object_patch("manual_cast_dtype", dtype)
if dtype is not None:
self.force_cast_weights = True
self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
def add_weight_wrapper(self, name, function):
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
self.patches_uuid = uuid.uuid4()
def get_model_object(self, name: str) -> torch.nn.Module:
"""Retrieves a nested attribute from an object using dot notation considering
object patches.
Args:
name (str): The attribute path using dot notation (e.g. "model.layer.weight")
Returns:
The value of the requested attribute
Example:
patcher = ModelPatcher()
weight = patcher.get_model_object("layer1.conv.weight")
"""
if name in self.object_patches:
return self.object_patches[name]
else:
@ -553,6 +601,9 @@ class ModelPatcher:
lowvram_weight = False
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if not full_load and hasattr(m, "comfy_cast_weights"):
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
@ -560,34 +611,46 @@ class ModelPatcher:
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
cast_weight = self.force_cast_weights
if lowvram_weight:
if hasattr(m, "comfy_cast_weights"):
m.weight_function = []
m.bias_function = []
if weight_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = LowVramPatch(weight_key, self.patches)
m.weight_function = [LowVramPatch(weight_key, self.patches)]
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = LowVramPatch(bias_key, self.patches)
m.bias_function = [LowVramPatch(bias_key, self.patches)]
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
cast_weight = True
else:
if hasattr(m, "comfy_cast_weights"):
if m.comfy_cast_weights:
wipe_lowvram_weight(m)
wipe_lowvram_weight(m)
if full_load or mem_counter + module_mem < lowvram_model_memory:
mem_counter += module_mem
load_completely.append((module_mem, n, m, params))
if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
if weight_key in self.weight_wrapper_patches:
m.weight_function.extend(self.weight_wrapper_patches[weight_key])
if bias_key in self.weight_wrapper_patches:
m.bias_function.extend(self.weight_wrapper_patches[bias_key])
mem_counter += move_weight_functions(m, device_to)
load_completely.sort(reverse=True)
for x in load_completely:
n = x[1]
@ -649,6 +712,7 @@ class ModelPatcher:
self.unpatch_hooks()
if self.model.model_lowvram:
for m in self.model.modules():
move_weight_functions(m, device_to)
wipe_lowvram_weight(m)
self.model.model_lowvram = False
@ -683,6 +747,7 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
patch_counter = 0
unload_list = self._load_list()
@ -706,24 +771,32 @@ class ModelPatcher:
move_weight = False
break
if not hooks_unpatched:
self.unpatch_hooks()
hooks_unpatched = True
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
m.to(device_to)
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self.patches)
m.weight_function.append(LowVramPatch(weight_key, self.patches))
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self.patches)
m.bias_function.append(LowVramPatch(bias_key, self.patches))
patch_counter += 1
cast_weight = True
if cast_weight:
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
@ -789,7 +862,7 @@ class ModelPatcher:
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
c.append(callback)
def remove_callbacks_with_key(self, call_type: str, key: str):
c = self.callbacks.get(call_type, {})
if key in c:
@ -797,7 +870,7 @@ class ModelPatcher:
def get_callbacks(self, call_type: str, key: str):
return self.callbacks.get(call_type, {}).get(key, [])
def get_all_callbacks(self, call_type: str):
c_list = []
for c in self.callbacks.get(call_type, {}).values():
@ -810,7 +883,7 @@ class ModelPatcher:
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
w.append(wrapper)
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
w = self.wrappers.get(wrapper_type, {})
if key in w:
@ -831,7 +904,7 @@ class ModelPatcher:
def remove_attachments(self, key: str):
if key in self.attachments:
self.attachments.pop(key)
def get_attachment(self, key: str):
return self.attachments.get(key, None)
@ -842,6 +915,9 @@ class ModelPatcher:
if key in self.injections:
self.injections.pop(key)
def get_injections(self, key: str):
return self.injections.get(key, None)
def set_additional_models(self, key: str, models: list['ModelPatcher']):
self.additional_models[key] = models
@ -851,7 +927,7 @@ class ModelPatcher:
def get_additional_models_with_key(self, key: str):
return self.additional_models.get(key, [])
def get_additional_models(self):
all_models = []
for models in self.additional_models.values():
@ -906,24 +982,25 @@ class ModelPatcher:
self.model.current_patcher = self
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
def prepare_state(self, timestep):
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep)
def restore_hook_patches(self):
if len(self.hook_patches_backup) > 0:
if self.hook_patches_backup is not None:
self.hook_patches = self.hook_patches_backup
self.hook_patches_backup = {}
self.hook_patches_backup = None
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
self.hook_mode = hook_mode
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0]
reset_current_hooks = False
transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling
if changed:
@ -939,25 +1016,26 @@ class ModelPatcher:
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):
self.restore_hook_patches()
registered_hooks: list[comfy.hooks.Hook] = []
# handle WrapperHooks, if model_options provided
if model_options is not None:
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
hook.add_hook_patches(self, model_options, target, registered_hooks)
if registered is None:
registered = comfy.hooks.HookGroup()
# handle WeightHooks
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
if hook.hook_ref not in self.hook_patches:
weight_hooks_to_register.append(hook)
else:
registered.add(hook)
if len(weight_hooks_to_register) > 0:
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
for hook in weight_hooks_to_register:
hook.add_hook_patches(self, model_options, target, registered_hooks)
hook.add_hook_patches(self, model_options, target_dict, registered)
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
callback(self, hooks_dict, target)
callback(self, hooks, target_dict, model_options, registered)
return registered
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
with self.use_ejected():
@ -975,7 +1053,7 @@ class ModelPatcher:
key = k[0]
if len(k) > 2:
function = k[2]
if key in model_sd:
p.add(k)
current_patches: list[tuple] = current_hook_patches.get(key, [])
@ -1008,15 +1086,14 @@ class ModelPatcher:
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
# TODO: return transformer_options dict with any additions from hooks
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
return {}
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
self.patch_hooks(hooks=hooks)
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
callback(self, hooks)
return {}
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected():
self.unpatch_hooks()
if hooks is not None:
model_sd_keys = list(self.model_state_dict().keys())
memory_counter = None
@ -1027,12 +1104,16 @@ class ModelPatcher:
# if have cached weights for hooks, use it
cached_weights = self.cached_hook_patches.get(hooks, None)
if cached_weights is not None:
model_sd_keys_set = set(model_sd_keys)
for key in cached_weights:
if key not in model_sd_keys:
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
model_sd_keys_set.remove(key)
self.unpatch_hooks(model_sd_keys_set)
else:
self.unpatch_hooks()
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
original_weights = None
if len(relevant_patches) > 0:
@ -1043,6 +1124,8 @@ class ModelPatcher:
continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
memory_counter=memory_counter)
else:
self.unpatch_hooks()
self.current_hooks = hooks
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
@ -1063,7 +1146,7 @@ class ModelPatcher:
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
if key not in combined_patches:
return
weight, set_func, convert_func = get_key_weight(self.model, key)
weight: torch.Tensor
if key not in self.hook_backup:
@ -1098,18 +1181,24 @@ class ModelPatcher:
del temp_weight
del out_weight
del weight
def unpatch_hooks(self) -> None:
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
with self.use_ejected():
if len(self.hook_backup) == 0:
self.current_hooks = None
return
keys = list(self.hook_backup.keys())
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.clear()
self.current_hooks = None
if whitelist_keys_set:
for k in keys:
if k in whitelist_keys_set:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.pop(k)
else:
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.clear()
self.current_hooks = None
def clean_hooks(self):
self.unpatch_hooks()

View File

@ -31,6 +31,7 @@ class EPS:
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
@ -61,9 +62,11 @@ class CONST:
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return sigma * noise + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
return latent / (1.0 - sigma)
class ModelSamplingDiscrete(torch.nn.Module):

View File

@ -17,8 +17,9 @@
"""
import torch
import logging
import comfy.model_management
from comfy.cli_args import args
from comfy.cli_args import args, PerformanceFeature
import comfy.float
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@ -38,21 +39,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = s.bias_function is not None
has_function = len(s.bias_function) > 0
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
bias = s.bias_function(bias)
for f in s.bias_function:
bias = f(bias)
has_function = s.weight_function is not None
has_function = len(s.weight_function) > 0
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
weight = s.weight_function(weight)
for f in s.weight_function:
weight = f(weight)
return weight, bias
class CastWeightBiasOp:
comfy_cast_weights = False
weight_function = None
bias_function = None
weight_function = []
bias_function = []
class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
@ -64,7 +67,7 @@ class disable_weight_init:
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -78,7 +81,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -92,7 +95,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -106,7 +109,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -120,12 +123,11 @@ class disable_weight_init:
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -139,7 +141,7 @@ class disable_weight_init:
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -160,7 +162,7 @@ class disable_weight_init:
output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -181,7 +183,7 @@ class disable_weight_init:
output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -199,7 +201,7 @@ class disable_weight_init:
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
if "out_dtype" in kwargs:
@ -307,6 +309,7 @@ class fp8_ops(manual_cast):
return torch.nn.functional.linear(input, weight, bias)
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
class scaled_fp8_op(manual_cast):
class Linear(manual_cast.Linear):
def __init__(self, *args, **kwargs):
@ -357,9 +360,13 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
if (
fp8_compute and
(fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
not disable_fast_fp8
):
return fp8_ops
if compute_dtype is None or weight_dtype == compute_dtype:

View File

@ -96,12 +96,12 @@ class WrapperExecutor:
self.wrappers = wrappers.copy()
self.idx = idx
self.is_last = idx == len(wrappers)
def __call__(self, *args, **kwargs):
"""Calls the next wrapper or original function, whichever is appropriate."""
new_executor = self._create_next_executor()
return new_executor.execute(*args, **kwargs)
def execute(self, *args, **kwargs):
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
args = list(args)
@ -121,7 +121,7 @@ class WrapperExecutor:
@classmethod
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
@classmethod
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
return cls(original, class_obj, wrappers, idx=idx)

View File

@ -25,9 +25,11 @@ def prepare_noise(latent_image, seed, noise_inds=None):
return noises
def fix_empty_latent_channels(model, latent_image):
latent_channels = model.get_model_object("latent_format").latent_channels #Resize the empty latent image so it has the right number of channels
if latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_channels, dim=1)
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
latent_image = latent_image.unsqueeze(2)
return latent_image
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):

View File

@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type):
models += [c[model_type]]
return models
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
cnets: list[ControlBase] = []
for c in cond:
if 'hooks' in c:
for hook in c['hooks'].hooks:
hook: comfy.hooks.Hook
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
full_hooks.add(hook)
if 'control' in c:
cnets.append(c['control'])
@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
if extra_hooks is not None:
for hook in extra_hooks.hooks:
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
full_hooks.add(hook)
return hooks_dict
return full_hooks
def convert_cond(cond):
out = []
@ -61,7 +58,6 @@ def convert_cond(cond):
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
temp["uuid"] = uuid.uuid4()
@ -73,13 +69,11 @@ def get_additional_models(conds, dtype):
cnets: list[ControlBase] = []
gligen = []
add_models = []
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
for k in conds:
cnets += get_models_from_cond(conds[k], "control")
gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models")
get_hooks_from_cond(conds[k], hooks)
control_nets = set(cnets)
@ -90,11 +84,20 @@ def get_additional_models(conds, dtype):
inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen]
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
models = control_models + gligen + add_models + hook_models
models = control_models + gligen + add_models
return models, inference_memory
def get_additional_models_from_model_options(model_options: dict[str]=None):
"""loads additional models from registered AddModels hooks"""
models = []
if model_options is not None and "registered_hooks" in model_options:
registered: comfy.hooks.HookGroup = model_options["registered_hooks"]
for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels):
hook: comfy.hooks.AdditionalModelsHook
models.extend(hook.models)
return models
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
@ -102,9 +105,10 @@ def cleanup_additional_models(models):
m.cleanup()
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
real_model: 'BaseModel' = None
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
@ -123,12 +127,35 @@ def cleanup_models(conds, models):
cleanup_additional_models(set(control_cleanup))
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
'''
Registers hooks from conds.
'''
# check for hooks in conds - if not registered, see if can be applied
hooks = {}
hooks = comfy.hooks.HookGroup()
for k in conds:
get_hooks_from_cond(conds[k], hooks)
# add wrappers and callbacks from ModelPatcher to transformer_options
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
# register hooks on model/model_options
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
# begin registering hooks
registered = comfy.hooks.HookGroup()
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
# handle all TransformerOptionsHooks
for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
hook: comfy.hooks.TransformerOptionsHook
hook.add_hook_patches(model, model_options, target_dict, registered)
# handle all AddModelsHooks
for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels):
hook: comfy.hooks.AdditionalModelsHook
hook.add_hook_patches(model, model_options, target_dict, registered)
# handle all WeightHooks by registering on ModelPatcher
model.register_all_hook_patches(hooks, target_dict, model_options, registered)
# add registered_hooks onto model_options for further reference
if len(registered) > 0:
model_options["registered_hooks"] = registered
# merge original wrappers and callbacks with hooked wrappers and callbacks
to_load_options: dict[str] = model_options.setdefault("to_load_options", {})
for wc_name in ["wrappers", "callbacks"]:
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options

View File

@ -1,17 +1,17 @@
from __future__ import annotations
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, NamedTuple
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
import torch
from functools import partial
import collections
from comfy import model_management
import math
import logging
import comfy.samplers
import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
@ -19,6 +19,12 @@ import comfy.hooks
import scipy.stats
import numpy
def add_area_dims(area, num_dims):
while (len(area) // 2) < num_dims:
area = [2147483648] + area[:len(area) // 2] + [0] + area[len(area) // 2:]
return area
def get_area_and_mult(conds, x_in, timestep_in):
dims = tuple(x_in.shape[2:])
area = None
@ -34,6 +40,10 @@ def get_area_and_mult(conds, x_in, timestep_in):
return None
if 'area' in conds:
area = list(conds['area'])
area = add_area_dims(area, len(dims))
if (len(area) // 2) > len(dims):
area = area[:len(dims)] + area[len(area) // 2:(len(area) // 2) + len(dims)]
if 'strength' in conds:
strength = conds['strength']
@ -50,7 +60,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert(mask.shape[1:] == x_in.shape[2:])
assert (mask.shape[1:] == x_in.shape[2:])
mask = mask[:input_x.shape[0]]
if area is not None:
@ -64,16 +74,17 @@ def get_area_and_mult(conds, x_in, timestep_in):
mult = mask * strength
if 'mask' not in conds and area is not None:
rr = 8
fuzz = 8
for i in range(len(dims)):
rr = min(fuzz, mult.shape[2 + i] // 4)
if area[len(dims) + i] != 0:
for t in range(rr):
m = mult.narrow(i + 2, t, 1)
m *= ((1.0/rr) * (t + 1))
m *= ((1.0 / rr) * (t + 1))
if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
for t in range(rr):
m = mult.narrow(i + 2, area[i] - 1 - t, 1)
m *= ((1.0/rr) * (t + 1))
m *= ((1.0 / rr) * (t + 1))
conditioning = {}
model_conds = conds["model_conds"]
@ -144,7 +155,7 @@ def cond_cat(c_list):
return out
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options):
# need to figure out remaining unmasked area for conds
default_mults = []
for _ in default_conds:
@ -177,13 +188,13 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
cond = default_conds[i]
for x in cond:
# do get_area_and_mult to get all the expected values
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
# replace p's mult with calculated mult
p = p._replace(mult=mult)
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
@ -214,17 +225,17 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
default_c.append(x)
has_default_conds = True
continue
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep)
@ -375,7 +386,7 @@ class KSamplerX0Inpaint:
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
latent_mask = 1. - denoise_mask
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
if denoise_mask is not None:
out = out * denoise_mask + self.latent_image * latent_mask
@ -467,6 +478,13 @@ def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, line
sigma_schedule = [1.0 - x for x in sigma_schedule]
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
# Referenced from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Tensor:
adj_idxs = torch.arange(n, dtype=torch.float).div_(n - 1)
sigmas = adj_idxs.new_zeros(n + 1)
sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_()
return sigmas
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@ -541,25 +559,37 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)
def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2
def create_cond_with_same_area_if_none(conds, c):
if 'area' not in c:
return
def area_inside(a, area_cmp):
a = add_area_dims(a, len(area_cmp) // 2)
area_cmp = add_area_dims(area_cmp, len(a) // 2)
a_l = len(a) // 2
area_cmp_l = len(area_cmp) // 2
for i in range(min(a_l, area_cmp_l)):
if a[a_l + i] < area_cmp[area_cmp_l + i]:
return False
for i in range(min(a_l, area_cmp_l)):
if (a[i] + a[a_l + i]) > (area_cmp[i] + area_cmp[area_cmp_l + i]):
return False
return True
c_area = c['area']
smallest = None
for x in conds:
if 'area' in x:
a = x['area']
if c_area[2] >= a[2] and c_area[3] >= a[3]:
if a[0] + a[2] >= c_area[0] + c_area[2]:
if a[1] + a[3] >= c_area[1] + c_area[3]:
if smallest is None:
smallest = x
elif 'area' not in smallest:
smallest = x
else:
if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
smallest = x
if area_inside(c_area, a):
if smallest is None:
smallest = x
elif 'area' not in smallest:
smallest = x
else:
if math.prod(smallest['area'][:len(smallest['area']) // 2]) > math.prod(a[:len(a) // 2]):
smallest = x
else:
if smallest is None:
smallest = x
@ -679,7 +709,8 @@ class Sampler:
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis"]
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation", "er_sde"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
@ -802,6 +833,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
for cond in conds_to_modify:
cond['hooks'] = hooks
def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]):
'''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for
HookGroups that have the same reference.'''
registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None)
# if None were registered, make sure all hooks are cleaned from conds
if registered is None:
for k in conds:
for kk in conds[k]:
kk.pop('hooks', None)
return
# find conds that contain hooks to be replaced - group by common HookGroup refs
hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {}
for k in conds:
for kk in conds[k]:
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
if hooks is not None:
if not hooks.is_subset_of(registered):
to_replace = hook_replacement.setdefault(hooks, [])
to_replace.append(kk)
# for each hook to replace, create a new proper HookGroup and assign to all common conds
for hooks, conds_to_modify in hook_replacement.items():
new_hooks = hooks.new_with_common_hooks(registered)
if len(new_hooks) == 0:
new_hooks = None
for kk in conds_to_modify:
kk['hooks'] = new_hooks
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
hooks_set = set()
@ -811,9 +869,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
return len(hooks_set)
def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
'''
If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
'''
if model_options is None:
return
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
casts = []
if device is not None:
casts.append(device)
if dtype is not None:
casts.append(dtype)
# if nothing to apply, do nothing
if len(casts) == 0:
return
# try to call .to on patches
if "patches" in to_load_options:
patches = to_load_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
for cast in casts:
patch_list[k] = patch_list[k].to(cast)
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
if hasattr(wc_list[i], "to"):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider:
def __init__(self, model_patcher):
self.model_patcher: 'ModelPatcher' = model_patcher
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
self.model_options = model_patcher.model_options
self.original_conds = {}
self.cfg = 1.0
@ -840,7 +947,9 @@ class CFGGuider:
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
extra_args = {"model_options": extra_model_options, "seed": seed}
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
sampler.sample,
@ -851,7 +960,7 @@ class CFGGuider:
return self.inner_model.process_latent_out(samples.to(torch.float32))
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
if denoise_mask is not None:
@ -860,6 +969,7 @@ class CFGGuider:
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
try:
self.model_patcher.pre_run()
@ -889,6 +999,7 @@ class CFGGuider:
if get_total_hook_groups_in_conds(self.conds) <= 1:
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
filter_registered_hooks_on_conds(self.conds, self.model_options)
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
self.outer_sample,
self,
@ -896,6 +1007,7 @@ class CFGGuider:
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options
self.model_patcher.hook_mode = orig_hook_mode
self.model_patcher.restore_hook_patches()
@ -911,29 +1023,37 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas(model_sampling, scheduler_name, steps):
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
elif scheduler_name == "normal":
sigmas = normal_scheduler(model_sampling, steps)
elif scheduler_name == "simple":
sigmas = simple_scheduler(model_sampling, steps)
elif scheduler_name == "ddim_uniform":
sigmas = ddim_scheduler(model_sampling, steps)
elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
elif scheduler_name == "beta":
sigmas = beta_scheduler(model_sampling, steps)
elif scheduler_name == "linear_quadratic":
sigmas = linear_quadratic_schedule(model_sampling, steps)
else:
logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas
class SchedulerHandler(NamedTuple):
handler: Callable[..., torch.Tensor]
# Boolean indicates whether to call the handler like:
# scheduler_function(model_sampling, steps) or
# scheduler_function(n, sigma_min: float, sigma_max: float)
use_ms: bool = True
SCHEDULER_HANDLERS = {
"normal": SchedulerHandler(normal_scheduler),
"karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
"exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
"simple": SchedulerHandler(simple_scheduler),
"ddim_uniform": SchedulerHandler(ddim_scheduler),
"beta": SchedulerHandler(beta_scheduler),
"linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
"kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
}
SCHEDULER_NAMES = list(SCHEDULER_HANDLERS)
def calculate_sigmas(model_sampling: object, scheduler_name: str, steps: int) -> torch.Tensor:
handler = SCHEDULER_HANDLERS.get(scheduler_name)
if handler is None:
err = f"error invalid scheduler {scheduler_name}"
logging.error(err)
raise ValueError(err)
if handler.use_ms:
return handler.handler(model_sampling, steps)
return handler.handler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
def sampler_object(name):
if name == "uni_pc":

View File

@ -1,4 +1,5 @@
from __future__ import annotations
import json
import torch
from enum import Enum
import logging
@ -11,6 +12,8 @@ from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import yaml
import math
@ -34,6 +37,9 @@ import comfy.text_encoders.long_clipl
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.model_patcher
import comfy.lora
@ -111,7 +117,7 @@ class CLIP:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
self.use_clip_schedule = False
logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
def clone(self):
n = CLIP(no_init=True)
@ -129,8 +135,8 @@ class CLIP:
def clip_layer(self, layer_idx):
self.layer_idx = layer_idx
def tokenize(self, text, return_word_ids=False):
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def tokenize(self, text, return_word_ids=False, **kwargs):
return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)
def add_hooks_to_dict(self, pooled_dict: dict[str]):
if self.apply_hooks_to_conds:
@ -244,7 +250,7 @@ class CLIP:
return self.patcher.get_key_patches()
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None):
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
@ -352,7 +358,12 @@ class VAE:
version = 0
elif tensor_conv1.shape[0] == 1024:
version = 1
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version)
if "encoder.down_blocks.1.conv.conv.bias" in sd:
version = 2
vae_config = None
if metadata is not None and "config" in metadata:
vae_config = json.loads(metadata["config"]).get("vae", None)
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
self.latent_channels = 128
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
@ -376,6 +387,31 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.unpatcher3d.wavelets" in sd:
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
self.upscale_index_formula = (8, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 8, 8)
self.downscale_index_formula = (8, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
#TODO: these values are a bit off because this is not a standard VAE
self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.middle.0.residual.0.gamma" in sd:
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -404,6 +440,10 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
downscale_ratio = self.spacial_compression_encode()
@ -459,6 +499,7 @@ class VAE:
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in):
self.throw_exception_if_invalid()
pixel_samples = None
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
@ -489,6 +530,7 @@ class VAE:
return pixel_samples
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
dims = samples.ndim - 2
@ -517,9 +559,10 @@ class VAE:
return output.movedim(1, -1)
def encode(self, pixel_samples):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
if self.latent_dim == 3:
if self.latent_dim == 3 and pixel_samples.ndim < 5:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
@ -549,6 +592,7 @@ class VAE:
return samples
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
@ -641,6 +685,9 @@ class CLIPType(Enum):
LTXV = 8
HUNYUAN_VIDEO = 9
PIXART = 10
COSMOS = 11
LUMINA2 = 12
WAN = 13
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@ -658,6 +705,8 @@ class TEModel(Enum):
T5_XL = 5
T5_BASE = 6
LLAMA3_8 = 7
T5_XXL_OLD = 8
GEMMA_2_2B = 9
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -672,8 +721,12 @@ def detect_te_model(sd):
return TEModel.T5_XXL
elif weight.shape[-1] == 2048:
return TEModel.T5_XL
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
if "model.layers.0.post_attention_layernorm.weight" in sd:
return TEModel.LLAMA3_8
return None
@ -681,9 +734,10 @@ def detect_te_model(sd):
def t5xxl_detect(clip_data):
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
weight_name_old = "encoder.block.23.layer.1.DenseReluDense.wi.weight"
for sd in clip_data:
if weight_name in sd:
if weight_name in sd or weight_name_old in sd:
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
return {}
@ -710,6 +764,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
tokenizer_data = {}
clip_target = EmptyClass()
clip_target.params = {}
if len(clip_data) == 1:
@ -737,15 +792,26 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.PIXART:
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
elif clip_type == CLIPType.WAN:
clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
elif te_model == TEModel.T5_XXL_OLD:
clip_target.clip = comfy.text_encoders.cosmos.te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.cosmos.CosmosT5Tokenizer
elif te_model == TEModel.T5_XL:
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif te_model == TEModel.T5_BASE:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
else:
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
@ -775,7 +841,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
parameters = 0
tokenizer_data = {}
for c in clip_data:
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
@ -822,13 +887,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
sd = comfy.utils.load_torch_file(ckpt_path)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
clip = None
clipvision = None
vae = None
@ -840,19 +905,24 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
return None
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
if diffusion_model is None:
return None
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype)
if model_config.scaled_fp8 is not None:
weight_dtype = None
model_config.custom_operations = model_options.get("custom_operations", None)
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
@ -869,7 +939,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd)
vae = VAE(sd=vae_sd, metadata=metadata)
if output_clip:
clip_target = model_config.clip_target(state_dict=sd)
@ -898,7 +968,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"):
logging.info("loaded straight to GPU")
logging.info("loaded diffusion model directly to GPU")
model_management.load_models_gpu([model_patcher], force_full_load=True)
return (model_patcher, clip, vae, clipvision)
@ -943,11 +1013,11 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype)
if model_config.scaled_fp8 is not None:
weight_dtype = None
if dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
else:
unet_dtype = dtype

View File

@ -158,71 +158,93 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
def set_up_textual_embeddings(self, tokens, current_embeds):
out_tokens = []
next_new_token = token_dict_size = current_embeds.weight.shape[0]
embedding_weights = []
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
if end_token is None:
cmp_token = self.special_tokens.get("pad", -1)
else:
cmp_token = end_token
embeds_out = []
attention_masks = []
num_tokens = []
for x in tokens:
attention_mask = []
tokens_temp = []
other_embeds = []
eos = False
index = 0
for y in x:
if isinstance(y, numbers.Integral):
tokens_temp += [int(y)]
else:
if y.shape[0] == current_embeds.weight.shape[1]:
embedding_weights += [y]
tokens_temp += [next_new_token]
next_new_token += 1
if eos:
attention_mask.append(0)
else:
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
while len(tokens_temp) < len(x):
tokens_temp += [self.special_tokens["pad"]]
out_tokens += [tokens_temp]
attention_mask.append(1)
token = int(y)
tokens_temp += [token]
if not eos and token == cmp_token:
if end_token is None:
attention_mask[-1] = 0
eos = True
else:
other_embeds.append((index, y))
index += 1
n = token_dict_size
if len(embedding_weights) > 0:
new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
new_embedding.weight[:token_dict_size] = current_embeds.weight
for x in embedding_weights:
new_embedding.weight[n] = x
n += 1
self.transformer.set_input_embeddings(new_embedding)
tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long)
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
index = 0
pad_extra = 0
for o in other_embeds:
emb = o[1]
if torch.is_tensor(emb):
emb = {"type": "embedding", "data": emb}
processed_tokens = []
for x in out_tokens:
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
emb_type = emb.get("type", None)
if emb_type == "embedding":
emb = emb.get("data", None)
else:
if hasattr(self.transformer, "preprocess_embed"):
emb = self.transformer.preprocess_embed(emb, device=device)
else:
emb = None
return processed_tokens
if emb is None:
index += -1
continue
ind = index + o[0]
emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
emb_shape = emb.shape[1]
if emb.shape[-1] == tokens_embed.shape[-1]:
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
index += emb_shape - 1
else:
index += -1
pad_extra += emb_shape
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))
if pad_extra > 0:
padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
attention_mask = attention_mask + [0] * pad_extra
embeds_out.append(tokens_embed)
attention_masks.append(attention_mask)
num_tokens.append(sum(attention_mask))
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
def forward(self, tokens):
backup_embeds = self.transformer.get_input_embeddings()
device = backup_embeds.weight.device
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
attention_mask = None
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
attention_mask = torch.zeros_like(tokens)
end_token = self.special_tokens.get("end", None)
if end_token is None:
cmp_token = self.special_tokens.get("pad", -1)
else:
cmp_token = end_token
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == cmp_token:
if end_token is None:
attention_mask[x, y] = 0
break
device = self.transformer.get_input_embeddings().weight.device
embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
attention_mask_model = None
if self.enable_attention_masks:
attention_mask_model = attention_mask
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
self.transformer.set_input_embeddings(backup_embeds)
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
if self.layer == "last":
z = outputs[0].float()
@ -388,13 +410,10 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
import safetensors.torch
embed = safetensors.torch.load_file(embed_path, device="cpu")
else:
if 'weights_only' in torch.load.__code__.co_varnames:
try:
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
except:
embed_out = safe_load_embed_zip(embed_path)
else:
embed = torch.load(embed_path, map_location="cpu")
try:
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
except:
embed_out = safe_load_embed_zip(embed_path)
except Exception:
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
return None
@ -424,10 +443,10 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
self.max_length = max_length
self.min_length = min_length
self.end_token = None
@ -485,7 +504,7 @@ class SDTokenizer:
return (embed, leftover)
def tokenize_with_weights(self, text:str, return_word_ids=False):
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
'''
Takes a prompt and converts it to a list of (token, weight, word id) elements.
Tokens can both be integer tokens and pre computed CLIP tensors.
@ -588,13 +607,18 @@ class SDTokenizer:
return {}
class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
if name is not None:
self.clip_name = name
self.clip = "{}".format(self.clip_name)
else:
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
def tokenize_with_weights(self, text:str, return_word_ids=False):
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
return out
@ -603,7 +627,7 @@ class SD1Tokenizer:
return getattr(self, self.clip).untokenize(token_weight_pair)
def state_dict(self):
return {}
return getattr(self, self.clip).state_dict()
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):

View File

@ -26,7 +26,7 @@ class SDXLTokenizer:
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)

View File

@ -14,6 +14,9 @@ import comfy.text_encoders.flux
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
from . import supported_models_base
from . import latent_formats
@ -759,7 +762,7 @@ class LTXV(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.LTXV
memory_usage_factor = 2.7
memory_usage_factor = 5.5 # TODO: img2vid is about 2x vs txt2vid
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -787,7 +790,7 @@ class HunyuanVideo(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 2.0 #TODO
memory_usage_factor = 1.8 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -823,6 +826,139 @@ class HunyuanVideo(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
class HunyuanVideoI2V(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"in_channels": 33,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideoI2V(self, device=device)
return out
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"in_channels": 32,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideoSkyreelsI2V(self, device=device)
return out
class CosmosT2V(supported_models_base.BASE):
unet_config = {
"image_model": "cosmos",
"in_channels": 16,
}
sampling_settings = {
"sigma_data": 0.5,
"sigma_max": 80.0,
"sigma_min": 0.002,
}
unet_extra_config = {}
latent_format = latent_formats.Cosmos1CV8x8x8
memory_usage_factor = 1.6 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosVideo(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
class CosmosI2V(CosmosT2V):
unet_config = {
"image_model": "cosmos",
"in_channels": 17,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
return out
class Lumina2(supported_models_base.BASE):
unet_config = {
"image_model": "lumina2",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 6.0,
}
memory_usage_factor = 1.2
unet_extra_config = {}
latent_format = latent_formats.Flux
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Lumina2(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
"model_type": "t2v",
}
sampling_settings = {
"shift": 8.0,
}
unet_extra_config = {}
latent_format = latent_formats.Wan21
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
class WAN21_I2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "i2v",
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, image_to_video=True, device=device)
return out
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
models += [SVD_img2vid]

View File

@ -93,8 +93,11 @@ class BertEmbeddings(torch.nn.Module):
self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, input_tokens, token_type_ids=None, dtype=None):
x = self.word_embeddings(input_tokens, out_dtype=dtype)
def forward(self, input_tokens, embeds=None, token_type_ids=None, dtype=None):
if embeds is not None:
x = embeds
else:
x = self.word_embeddings(input_tokens, out_dtype=dtype)
x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
if token_type_ids is not None:
x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
@ -113,12 +116,12 @@ class BertModel_(torch.nn.Module):
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
x = self.embeddings(input_tokens, dtype=dtype)
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
x, i = self.encoder(x, mask, intermediate_output)
return x, i

View File

@ -0,0 +1,42 @@
from comfy import sd1_clip
import comfy.text_encoders.t5
import os
from transformers import T5TokenizerFast
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
if t5xxl_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
class CosmosT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)
class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
class CosmosTEModel_(CosmosT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return CosmosTEModel_

View File

@ -18,7 +18,7 @@ class FluxTokenizer:
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)

View File

@ -4,6 +4,7 @@ import comfy.text_encoders.llama
from transformers import LlamaTokenizerFast
import torch
import os
import numbers
def llama_detect(state_dict, prefix=""):
@ -22,7 +23,7 @@ def llama_detect(state_dict, prefix=""):
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length)
class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
@ -38,15 +39,26 @@ class HunyuanVideoTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
def tokenize_with_weights(self, text:str, return_word_ids=False):
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
llama_text = "{}{}".format(self.llama_template, text)
out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids)
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
embed_count = 0
for r in llama_text_tokens:
for i in range(len(r)):
if r[i][0] == 128257:
if image_embeds is not None and embed_count < image_embeds.shape[0]:
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:]
embed_count += 1
out["llama"] = llama_text_tokens
return out
def untokenize(self, token_weight_pair):
@ -80,20 +92,51 @@ class HunyuanVideoClipModel(torch.nn.Module):
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
template_end = 0
for i, v in enumerate(token_weight_pairs_llama[0]):
if v[0] == 128007: # <|end_header_id|>
template_end = i
extra_template_end = 0
extra_sizes = 0
user_end = 9999999999999
images = []
tok_pairs = token_weight_pairs_llama[0]
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 128006:
if tok_pairs[i + 1][0] == 882:
if tok_pairs[i + 2][0] == 128007:
template_end = i + 2
user_end = -1
if elem == 128009 and user_end == -1:
user_end = i + 1
else:
if elem.get("original_type") == "image":
elem_size = elem.get("data").shape[0]
if template_end > 0:
if user_end == -1:
extra_template_end += elem_size - 1
else:
image_start = i + extra_sizes
image_end = i + elem_size + extra_sizes
images.append((image_start, image_end, elem.get("image_interleave", 1)))
extra_sizes += elem_size - 1
if llama_out.shape[1] > (template_end + 2):
if token_weight_pairs_llama[0][template_end + 1][0] == 271:
if tok_pairs[template_end + 1][0] == 271:
template_end += 2
llama_out = llama_out[:, template_end:]
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
if len(images) > 0:
out = []
for i in images:
out.append(llama_out[:, i[0]: i[1]: i[2]])
llama_output = torch.cat(out + [llama_output], dim=1)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return llama_out, l_pooled, llama_extra_out
return llama_output, l_pooled, llama_extra_out
def load_sd(self, sd):
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:

View File

@ -37,7 +37,7 @@ class HyditTokenizer:
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)

View File

@ -1,6 +1,5 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Any
@ -21,15 +20,41 @@ class Llama2Config:
max_position_embeddings: int = 8192
rms_norm_eps: float = 1e-5
rope_theta: float = 500000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
@dataclass
class Gemma2_2B_Config:
vocab_size: int = 256000
hidden_size: int = 2304
intermediate_size: int = 9216
num_hidden_layers: int = 26
num_attention_heads: int = 8
num_key_value_heads: int = 4
max_position_embeddings: int = 8192
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
transformer_type: str = "gemma2"
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, device=None, dtype=None):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
self.add = add
def forward(self, x: torch.Tensor):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
w = self.weight
if self.add:
w = w + 1.0
return comfy.ldm.common_dit.rms_norm(x, w, self.eps)
def rotate_half(x):
@ -68,13 +93,15 @@ class Attention(nn.Module):
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.hidden_size = config.hidden_size
self.head_dim = self.hidden_size // self.num_heads
self.head_dim = config.head_dim
self.inner_size = self.num_heads * self.head_dim
ops = ops or nn
self.q_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.o_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(
self,
@ -84,7 +111,6 @@ class Attention(nn.Module):
optimized_attention=None,
):
batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states)
xk = self.k_proj(hidden_states)
xv = self.v_proj(hidden_states)
@ -108,9 +134,13 @@ class MLP(nn.Module):
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
if config.mlp_activation == "silu":
self.activation = torch.nn.functional.silu
elif config.mlp_activation == "gelu_pytorch_tanh":
self.activation = lambda a: torch.nn.functional.gelu(a, approximate="tanh")
def forward(self, x):
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
@ -146,6 +176,45 @@ class TransformerBlock(nn.Module):
return x
class TransformerBlockGemma2(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
):
# Self Attention
residual = x
x = self.input_layernorm(x)
x = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
)
x = self.post_attention_layernorm(x)
x = residual + x
# MLP
residual = x
x = self.pre_feedforward_layernorm(x)
x = self.mlp(x)
x = self.post_feedforward_layernorm(x)
x = residual + x
return x
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@ -158,17 +227,30 @@ class Llama2_(nn.Module):
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2":
transformer = TransformerBlockGemma2
self.normalize_in = True
else:
transformer = TransformerBlock
self.normalize_in = False
self.layers = nn.ModuleList([
TransformerBlock(config, device=device, dtype=dtype, ops=ops)
transformer(config, device=device, dtype=dtype, ops=ops)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
x = self.embed_tokens(x, out_dtype=dtype)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
if embeds is not None:
x = embeds
else:
x = self.embed_tokens(x, out_dtype=dtype)
freqs_cis = precompute_freqs_cis(self.config.hidden_size // self.config.num_attention_heads,
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
freqs_cis = precompute_freqs_cis(self.config.head_dim,
x.shape[1],
self.config.rope_theta,
device=x.device)
@ -206,16 +288,7 @@ class Llama2_(nn.Module):
return x, intermediate
class Llama2(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Llama2Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class BaseLlama:
def get_input_embeddings(self):
return self.model.embed_tokens
@ -224,3 +297,23 @@ class Llama2(torch.nn.Module):
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)
class Llama2(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Llama2Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma2_2B_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

View File

@ -0,0 +1,39 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.llama
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False})
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return LuminaTEModel_

View File

@ -43,7 +43,7 @@ class SD3Tokenizer:
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)

View File

@ -1,21 +1,21 @@
import torch
class SPieceTokenizer:
add_eos = True
@staticmethod
def from_pretrained(path):
return SPieceTokenizer(path)
def from_pretrained(path, **kwargs):
return SPieceTokenizer(path, **kwargs)
def __init__(self, tokenizer_path):
def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
self.add_bos = add_bos
self.add_eos = add_eos
import sentencepiece
if torch.is_tensor(tokenizer_path):
tokenizer_path = tokenizer_path.numpy().tobytes()
if isinstance(tokenizer_path, bytes):
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos)
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
else:
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos)
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
def get_vocab(self):
out = {}

View File

@ -203,7 +203,7 @@ class T5Stack(torch.nn.Module):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
intermediate = None
optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
@ -227,8 +227,9 @@ class T5(torch.nn.Module):
super().__init__()
self.num_layers = config_dict["num_layers"]
model_dim = config_dict["d_model"]
inner_dim = config_dict["d_kv"] * config_dict["num_heads"]
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
self.encoder = T5Stack(self.num_layers, model_dim, inner_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
self.dtype = dtype
self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)
@ -238,8 +239,11 @@ class T5(torch.nn.Module):
def set_input_embeddings(self, embeddings):
self.shared = embeddings
def forward(self, input_ids, *args, **kwargs):
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
def forward(self, input_ids, attention_mask, embeds=None, num_tokens=None, **kwargs):
if input_ids is None:
x = embeds
else:
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
x = torch.nan_to_num(x) #Fix for fp8 T5 base
return self.encoder(x, *args, **kwargs)
return self.encoder(x, attention_mask=attention_mask, **kwargs)

View File

@ -0,0 +1,22 @@
{
"d_ff": 65536,
"d_kv": 128,
"d_model": 1024,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "relu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": false,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 24,
"num_heads": 128,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 32128
}

View File

@ -0,0 +1,22 @@
{
"d_ff": 10240,
"d_kv": 64,
"d_model": 4096,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "umt5",
"num_decoder_layers": 24,
"num_heads": 64,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 256384
}

View File

@ -0,0 +1,37 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.t5
import os
class UMT5XXlModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class WanT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="umt5xxl", tokenizer=UMT5XXlTokenizer)
class WanT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
class WanTEModel(WanT5Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return WanTEModel

View File

@ -29,17 +29,45 @@ import itertools
from torch.nn.functional import interpolate
from einops import rearrange
def load_torch_file(ckpt, safe_load=False, device=None):
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
class ModelCheckpoint:
pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
from numpy.core.multiarray import scalar
from numpy import dtype
from numpy.dtypes import Float64DType
from _codecs import encode
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
device = torch.device("cpu")
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
sd = safetensors.torch.load_file(ckpt, device=device.type)
try:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
sd[k] = f.get_tensor(k)
if return_metadata:
metadata = f.metadata()
except Exception as e:
if len(e.args) > 0:
message = e.args[0]
if "HeaderTooLarge" in message:
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt))
if "MetadataIncompleteBuffer" in message:
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
raise e
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load:
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
@ -55,7 +83,7 @@ def load_torch_file(ckpt, safe_load=False, device=None):
sd = pl_sd
else:
sd = pl_sd
return sd
return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None):
if metadata is not None:
@ -693,7 +721,25 @@ def copy_to_param(obj, attr, value):
prev = getattr(obj, attrs[-1])
prev.data.copy_(value)
def get_attr(obj, attr):
def get_attr(obj, attr: str):
"""Retrieves a nested attribute from an object using dot notation.
Args:
obj: The object to get the attribute from
attr (str): The attribute path using dot notation (e.g. "model.layer.weight")
Returns:
The value of the requested attribute
Example:
model = MyModel()
weight = get_attr(model, "layer1.conv.weight")
# Equivalent to: model.layer1.conv.weight
Important:
Always prefer `comfy.model_patcher.ModelPatcher.get_model_object` when
accessing nested model objects under `ModelPatcher.model`.
"""
attrs = attr.split(".")
for name in attrs:
obj = getattr(obj, name)
@ -727,7 +773,7 @@ def bislerp(samples, width, height):
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
#edge cases for same or polar opposites
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
return res
@ -893,7 +939,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
for it in itertools.product(*positions):
s_in = s

View File

@ -54,8 +54,8 @@ class DynamicPrompt:
def get_original_prompt(self):
return self.original_prompt
def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
def get_input_info(class_def, input_name, valid_inputs=None):
valid_inputs = valid_inputs or class_def.INPUT_TYPES()
input_info = None
input_category = None
if "required" in valid_inputs and input_name in valid_inputs["required"]:
@ -131,7 +131,7 @@ class TopologicalSort:
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))
for link in links:
self.add_strong_link(*link)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import torchaudio
import torch
import comfy.model_management
@ -10,6 +12,7 @@ import random
import hashlib
import node_helpers
from comfy.cli_args import args
from comfy.comfy_types import FileLocator
class EmptyLatentAudio:
def __init__(self):
@ -164,7 +167,7 @@ class SaveAudio:
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results = list()
results: list[FileLocator] = []
metadata = {}
if not args.disable_metadata:

View File

@ -0,0 +1,82 @@
import nodes
import torch
import comfy.model_management
import comfy.utils
class EmptyCosmosLatentVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/video"
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples": latent}, )
def vae_encode_with_padding(vae, image, width, height, length, padding=0):
pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
pixel_len = min(pixels.shape[0], length)
padded_length = min(length, (((pixel_len - 1) // 8) + 1 + padding) * 8 - 7)
padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5
padded_pixels[:pixel_len] = pixels[:pixel_len]
latent_len = ((pixel_len - 1) // 8) + 1
latent_temp = vae.encode(padded_pixels)
return latent_temp[:, :, :latent_len]
class CosmosImageToVideoLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": {"vae": ("VAE", ),
"width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is None and end_image is None:
out_latent = {}
out_latent["samples"] = latent
return (out_latent,)
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
if start_image is not None:
latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
latent[:, :, :latent_temp.shape[-3]] = latent_temp
mask[:, :, :latent_temp.shape[-3]] *= 0.0
if end_image is not None:
latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
latent[:, :, -latent_temp.shape[-3]:] = latent_temp
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
out_latent = {}
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
NODE_CLASS_MAPPINGS = {
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
"CosmosImageToVideoLatent": CosmosImageToVideoLatent,
}

View File

@ -231,6 +231,24 @@ class FlipSigmas:
sigmas[0] = 0.0001
return (sigmas,)
class SetFirstSigma:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"sigmas": ("SIGMAS", ),
"sigma": ("FLOAT", {"default": 136.0, "min": 0.0, "max": 20000.0, "step": 0.001, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "set_first_sigma"
def set_first_sigma(self, sigmas, sigma):
sigmas = sigmas.clone()
sigmas[0] = sigma
return (sigmas, )
class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
@ -436,7 +454,7 @@ class SamplerCustom:
return {"required":
{"model": ("MODEL",),
"add_noise": ("BOOLEAN", {"default": True}),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
@ -587,10 +605,16 @@ class DisableNoise:
class RandomNoise(DisableNoise):
@classmethod
def INPUT_TYPES(s):
return {"required":{
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
}
}
return {
"required": {
"noise_seed": ("INT", {
"default": 0,
"min": 0,
"max": 0xffffffffffffffff,
"control_after_generate": True,
}),
}
}
def get_noise(self, noise_seed):
return (Noise_RandomNoise(noise_seed),)
@ -710,6 +734,7 @@ NODE_CLASS_MAPPINGS = {
"SplitSigmas": SplitSigmas,
"SplitSigmasDenoise": SplitSigmasDenoise,
"FlipSigmas": FlipSigmas,
"SetFirstSigma": SetFirstSigma,
"CFGGuider": CFGGuider,
"DualCFGGuider": DualCFGGuider,

View File

@ -38,7 +38,26 @@ class FluxGuidance:
return (c, )
class FluxDisableGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
def append(self, conditioning):
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
return (c, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
"FluxDisableGuidance": FluxDisableGuidance,
}

Some files were not shown because too many files have changed in this diff Show More