Merge branch 'comfyanonymous:master' into sa_solver

Authored by chaObserv on 2025-01-06 18:09:03 +08:00, committed by GitHub
commit 812dc34f46
238 changed files with 674399 additions and 115970 deletions

View File

@@ -28,17 +28,17 @@ def pull(repo, remote_name='origin', branch='master'):
     if repo.index.conflicts is not None:
         for conflict in repo.index.conflicts:
-            print('Conflicts found in:', conflict[0].path)
+            print('Conflicts found in:', conflict[0].path) # noqa: T201
         raise AssertionError('Conflicts, ahhhhh!!')
     user = repo.default_signature
     tree = repo.index.write_tree()
     commit = repo.create_commit('HEAD',
                                 user,
                                 user,
                                 'Merge!',
                                 tree,
                                 [repo.head.target, remote_master_id])
     # We need to do this or git CLI will think we are still merging.
     repo.state_cleanup()
 else:
@@ -49,18 +49,18 @@ repo_path = str(sys.argv[1])
 repo = pygit2.Repository(repo_path)
 ident = pygit2.Signature('comfyui', 'comfy@ui')
 try:
-    print("stashing current changes")
+    print("stashing current changes") # noqa: T201
     repo.stash(ident)
 except KeyError:
-    print("nothing to stash")
+    print("nothing to stash") # noqa: T201
 backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
-print("creating backup branch: {}".format(backup_branch_name))
+print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
 try:
     repo.branches.local.create(backup_branch_name, repo.head.peel())
 except:
     pass
-print("checking out master branch")
+print("checking out master branch") # noqa: T201
 branch = repo.lookup_branch('master')
 if branch is None:
     ref = repo.lookup_reference('refs/remotes/origin/master')
@@ -72,7 +72,7 @@ else:
     ref = repo.lookup_reference(branch.name)
 repo.checkout(ref)
-print("pulling latest changes")
+print("pulling latest changes") # noqa: T201
 pull(repo)
 if "--stable" in sys.argv:
@@ -94,7 +94,7 @@ if "--stable" in sys.argv:
     if latest_tag is not None:
         repo.checkout(latest_tag)
-print("Done!")
+print("Done!") # noqa: T201
 self_update = True
 if len(sys.argv) > 2:

View File

@@ -3,8 +3,8 @@ name: Python Linting
 on: [push, pull_request]
 jobs:
-  pylint:
-    name: Run Pylint
+  ruff:
+    name: Run Ruff
     runs-on: ubuntu-latest
     steps:
@@ -16,8 +16,8 @@ jobs:
        with:
          python-version: 3.x
-      - name: Install Pylint
-        run: pip install pylint
-      - name: Run Pylint
-        run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")
+      - name: Install Ruff
+        run: pip install ruff
+      - name: Run Ruff
+        run: ruff check .

View File

@@ -20,7 +20,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [macos, linux, windows]
+        # os: [macos, linux, windows]
+        os: [macos, linux]
         python_version: ["3.9", "3.10", "3.11", "3.12"]
         cuda_version: ["12.1"]
         torch_version: ["stable"]
@@ -31,9 +32,9 @@ jobs:
         - os: linux
           runner_label: [self-hosted, Linux]
           flags: ""
-        - os: windows
-          runner_label: [self-hosted, Windows]
-          flags: ""
+        # - os: windows
+        #   runner_label: [self-hosted, Windows]
+        #   flags: ""
     runs-on: ${{ matrix.runner_label }}
     steps:
       - name: Test Workflows
@@ -45,28 +46,28 @@ jobs:
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}
-  test-win-nightly:
-    strategy:
-      fail-fast: true
-      matrix:
-        os: [windows]
-        python_version: ["3.9", "3.10", "3.11", "3.12"]
-        cuda_version: ["12.1"]
-        torch_version: ["nightly"]
-      include:
-        - os: windows
-          runner_label: [self-hosted, Windows]
-          flags: ""
-    runs-on: ${{ matrix.runner_label }}
-    steps:
-      - name: Test Workflows
-        uses: comfy-org/comfy-action@main
-        with:
-          os: ${{ matrix.os }}
-          python_version: ${{ matrix.python_version }}
-          torch_version: ${{ matrix.torch_version }}
-          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
-          comfyui_flags: ${{ matrix.flags }}
+  # test-win-nightly:
+  #   strategy:
+  #     fail-fast: true
+  #     matrix:
+  #       os: [windows]
+  #       python_version: ["3.9", "3.10", "3.11", "3.12"]
+  #       cuda_version: ["12.1"]
+  #       torch_version: ["nightly"]
+  #     include:
+  #       - os: windows
+  #         runner_label: [self-hosted, Windows]
+  #         flags: ""
+  #   runs-on: ${{ matrix.runner_label }}
+  #   steps:
+  #     - name: Test Workflows
+  #       uses: comfy-org/comfy-action@main
+  #       with:
+  #         os: ${{ matrix.os }}
+  #         python_version: ${{ matrix.python_version }}
+  #         torch_version: ${{ matrix.torch_version }}
+  #         google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
+  #         comfyui_flags: ${{ matrix.flags }}
   test-unix-nightly:
     strategy:

View File

@@ -17,7 +17,7 @@ jobs:
          path: "ComfyUI"
      - uses: actions/setup-python@v4
        with:
-          python-version: '3.8'
+          python-version: '3.9'
      - name: Install requirements
        run: |
          python -m pip install --upgrade pip
@@ -28,7 +28,7 @@ jobs:
      - name: Start ComfyUI server
        run: |
          python main.py --cpu 2>&1 | tee console_output.log &
-          wait-for-it --service 127.0.0.1:8188 -t 600
+          wait-for-it --service 127.0.0.1:8188 -t 30
        working-directory: ComfyUI
      - name: Check for unhandled exceptions in server log
        run: |

.github/workflows/update-frontend.yml (new file, 58 lines)
View File

@@ -0,0 +1,58 @@
name: Update Frontend Release
on:
workflow_dispatch:
inputs:
version:
description: "Frontend version to update to (e.g., 1.0.0)"
required: true
type: string
jobs:
update-frontend:
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps:
- name: Checkout ComfyUI
uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install wait-for-it
# Frontend asset will be downloaded to ComfyUI/web_custom_versions/Comfy-Org_ComfyUI_frontend/{version}
- name: Start ComfyUI server
run: |
python main.py --cpu --front-end-version Comfy-Org/ComfyUI_frontend@${{ github.event.inputs.version }} 2>&1 | tee console_output.log &
wait-for-it --service 127.0.0.1:8188 -t 30
- name: Configure Git
run: |
git config --global user.name "GitHub Action"
git config --global user.email "action@github.com"
# Replace existing frontend content with the new version and remove .js.map files
# See https://github.com/Comfy-Org/ComfyUI_frontend/issues/2145 for why we remove .js.map files
- name: Update frontend content
run: |
rm -rf web/
cp -r web_custom_versions/Comfy-Org_ComfyUI_frontend/${{ github.event.inputs.version }} web/
rm web/**/*.js.map
- name: Create Pull Request
uses: peter-evans/create-pull-request@v7
with:
token: ${{ secrets.PR_BOT_PAT }}
commit-message: "Update frontend to v${{ github.event.inputs.version }}"
title: "Frontend Update: v${{ github.event.inputs.version }}"
body: |
Automated PR to update frontend content to version ${{ github.event.inputs.version }}
This PR was created automatically by the frontend update workflow.
branch: release-${{ github.event.inputs.version }}
base: master
labels: Frontend,dependencies

View File

@@ -7,19 +7,19 @@ on:
       description: 'cuda version'
       required: true
       type: string
-      default: "124"
+      default: "126"
     python_minor:
       description: 'python minor version'
       required: true
       type: string
-      default: "12"
+      default: "13"
     python_patch:
       description: 'python patch version'
       required: true
       type: string
-      default: "4"
+      default: "1"
 # push:
 #   branches:
 #     - master

View File

@@ -1,3 +0,0 @@
-[MESSAGES CONTROL]
-disable=all
-enable=eval-used

View File

@@ -1 +1,23 @@
-* @comfyanonymous
+# Admins
+* @comfyanonymous
+# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
+# Inlined the team members for now.
+# Maintainers
+*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+# Python web server
+/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
+/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
+# Frontend assets
+/web/ @huchenlei @webfiltered @pythongosssss @yoland68 @robinjhuang
+# Extra nodes
+/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink

README.md (142 changed lines)
View File

@@ -28,7 +28,7 @@
 [github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
 [github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
-![ComfyUI Screenshot](comfyui_screenshot.png)
+![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe)
 </div>
 This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
@@ -38,8 +38,21 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
-- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
-- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
+- Image Models
+    - SD1.x, SD2.x,
+    - [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
+    - [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
+    - [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
+    - Pixart Alpha and Sigma
+    - [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
+    - [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
+    - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
+- Video Models
+    - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
+    - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
+    - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
+    - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
+- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
 - Asynchronous Queue system
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
 - Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@@ -59,9 +72,6 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
 - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
 - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
-- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
-- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
-- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
 - Latent previews with [TAESD](#how-to-show-high-quality-previews)
 - Starts up very fast.
 - Works fully offline: will never download anything.
@@ -73,37 +83,39 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
 | Keybind | Explanation |
 |------------------------------------|--------------------------------------------------------------------------------------------------------------------|
-| Ctrl + Enter | Queue up current graph for generation |
-| Ctrl + Shift + Enter | Queue up current graph as first for generation |
-| Ctrl + Alt + Enter | Cancel current generation |
-| Ctrl + Z/Ctrl + Y | Undo/Redo |
-| Ctrl + S | Save workflow |
-| Ctrl + O | Load workflow |
-| Ctrl + A | Select all nodes |
-| Alt + C | Collapse/uncollapse selected nodes |
-| Ctrl + M | Mute/unmute selected nodes |
-| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
-| Delete/Backspace | Delete selected nodes |
-| Ctrl + Backspace | Delete the current graph |
-| Space | Move the canvas around when held and moving the cursor |
-| Ctrl/Shift + Click | Add clicked node to selection |
-| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
-| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
-| Shift + Drag | Move multiple selected nodes at the same time |
-| Ctrl + D | Load default graph |
-| Alt + `+` | Canvas Zoom in |
-| Alt + `-` | Canvas Zoom out |
-| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
-| P | Pin/Unpin selected nodes |
-| Ctrl + G | Group selected nodes |
-| Q | Toggle visibility of the queue |
-| H | Toggle visibility of history |
-| R | Refresh graph |
+| `Ctrl` + `Enter` | Queue up current graph for generation |
+| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
+| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
+| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
+| `Ctrl` + `S` | Save workflow |
+| `Ctrl` + `O` | Load workflow |
+| `Ctrl` + `A` | Select all nodes |
+| `Alt `+ `C` | Collapse/uncollapse selected nodes |
+| `Ctrl` + `M` | Mute/unmute selected nodes |
+| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
+| `Delete`/`Backspace` | Delete selected nodes |
+| `Ctrl` + `Backspace` | Delete the current graph |
+| `Space` | Move the canvas around when held and moving the cursor |
+| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
+| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
+| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
+| `Shift` + `Drag` | Move multiple selected nodes at the same time |
+| `Ctrl` + `D` | Load default graph |
+| `Alt` + `+` | Canvas Zoom in |
+| `Alt` + `-` | Canvas Zoom out |
+| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
+| `P` | Pin/Unpin selected nodes |
+| `Ctrl` + `G` | Group selected nodes |
+| `Q` | Toggle visibility of the queue |
+| `H` | Toggle visibility of history |
+| `R` | Refresh graph |
+| `F` | Show/Hide menu |
+| `.` | Fit view to selection (Whole graph when nothing is selected) |
 | Double-Click LMB | Open node quick search palette |
-| Shift + Drag | Move multiple wires at once |
-| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
+| `Shift` + Drag | Move multiple wires at once |
+| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
-Ctrl can also be replaced with Cmd instead for macOS users
+`Ctrl` can also be replaced with `Cmd` instead for macOS users
 # Installing
@@ -139,11 +151,35 @@ Put your VAE in: models/vae
 ### AMD GPUs (Linux only)
 AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
 This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2.4```
+### Intel GPUs (Windows and Linux)
+(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
+1. To install PyTorch nightly, use the following command:
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
+2. Launch ComfyUI by running `python main.py`
+(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
+1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
+```
+conda install libuv
+pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
+```
+For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
+Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
 ### NVIDIA
@@ -153,7 +189,7 @@ Nvidia users should install stable pytorch using this command:
 This is the command to install pytorch nightly instead which might have performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
 #### Troubleshooting
@@ -173,17 +209,6 @@ After this you should have everything installed and can proceed to running Comfy
 ### Others:
-#### Intel GPUs
-Intel GPU support is available for all Intel GPUs supported by Intel's Extension for Pytorch (IPEX) with the support requirements listed in the [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) page. Choose your platform and method of install and follow the instructions. The steps are as follows:
-1. Start by installing the drivers or kernel listed or newer in the Installation page of IPEX linked above for Windows and Linux if needed.
-1. Follow the instructions to install [Intel's oneAPI Basekit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html) for your platform.
-1. Install the packages for IPEX using the instructions provided in the Installation page for your platform.
-1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux and run ComfyUI normally as described above after everything is installed.
-Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
 #### Apple Mac silicon
 You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
@@ -199,6 +224,16 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
 ```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
+#### Ascend NPUs
+For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
+1. Begin by installing the recommended or newer kernel version for Linux as specified in the Installation page of torch-npu, if necessary.
+2. Proceed with the installation of Ascend Basekit, which includes the driver, firmware, and CANN, following the instructions provided for your specific platform.
+3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
+4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
 # Running
 ```python main.py```
@@ -211,6 +246,14 @@ For 6700, 6600 and maybe other RDNA2 or older: ```HSA_OVERRIDE_GFX_VERSION=10.3.
 For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py```
+### AMD ROCm Tips
+You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
+```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
+You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
 # Notes
 Only parts of the graph that have an output with all the correct inputs will be executed.
@@ -296,4 +339,3 @@ This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy
 ### Which GPU should I buy for this?
 [See this page for some recommendations](https://github.com/comfyanonymous/ComfyUI/wiki/Which-GPU-should-I-buy-for-ComfyUI)

View File

@@ -2,6 +2,7 @@ from aiohttp import web
 from typing import Optional
 from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
 from api_server.services.file_service import FileService
+from api_server.services.terminal_service import TerminalService
 import app.logger
 class InternalRoutes:
@@ -9,9 +10,9 @@ class InternalRoutes:
     The top level web router for internal routes: /internal/*
     The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
     Check README.md for more information.
     '''
-    def __init__(self):
+    def __init__(self, prompt_server):
         self.routes: web.RouteTableDef = web.RouteTableDef()
         self._app: Optional[web.Application] = None
         self.file_service = FileService({
@@ -19,6 +20,8 @@ class InternalRoutes:
             "user": user_directory,
             "output": output_directory
         })
+        self.prompt_server = prompt_server
+        self.terminal_service = TerminalService(prompt_server)
     def setup_routes(self):
         @self.routes.get('/files')
@@ -34,7 +37,28 @@ class InternalRoutes:
         @self.routes.get('/logs')
         async def get_logs(request):
-            return web.json_response(app.logger.get_logs())
+            return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
+        @self.routes.get('/logs/raw')
+        async def get_raw_logs(request):
+            self.terminal_service.update_size()
+            return web.json_response({
+                "entries": list(app.logger.get_logs()),
+                "size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
+            })
+        @self.routes.patch('/logs/subscribe')
+        async def subscribe_logs(request):
+            json_data = await request.json()
+            client_id = json_data["clientId"]
+            enabled = json_data["enabled"]
+            if enabled:
+                self.terminal_service.subscribe(client_id)
+            else:
+                self.terminal_service.unsubscribe(client_id)
+            return web.Response(status=200)
         @self.routes.get('/folder_paths')
         async def get_folder_paths(request):
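For orientation, a minimal client-side sketch (not part of the commit) of how the new log endpoints above could be exercised, assuming a local server on 127.0.0.1:8188, the third-party `requests` package, and a made-up client id:

```python
import requests

BASE = "http://127.0.0.1:8188/internal"

# Fetch the buffered log entries plus the currently reported terminal size.
raw = requests.get(f"{BASE}/logs/raw").json()
print(raw["size"], len(raw["entries"]))

# Subscribe (or unsubscribe) a client id to live log pushes over the websocket.
requests.patch(f"{BASE}/logs/subscribe",
               json={"clientId": "my-client-id", "enabled": True})
```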

View File

@@ -0,0 +1,60 @@
from app.logger import on_flush
import os
import shutil
class TerminalService:
def __init__(self, server):
self.server = server
self.cols = None
self.rows = None
self.subscriptions = set()
on_flush(self.send_messages)
def get_terminal_size(self):
try:
size = os.get_terminal_size()
return (size.columns, size.lines)
except OSError:
try:
size = shutil.get_terminal_size()
return (size.columns, size.lines)
except OSError:
return (80, 24) # fallback to 80x24
def update_size(self):
columns, lines = self.get_terminal_size()
changed = False
if columns != self.cols:
self.cols = columns
changed = True
if lines != self.rows:
self.rows = lines
changed = True
if changed:
return {"cols": self.cols, "rows": self.rows}
return None
def subscribe(self, client_id):
self.subscriptions.add(client_id)
def unsubscribe(self, client_id):
self.subscriptions.discard(client_id)
def send_messages(self, entries):
if not len(entries) or not len(self.subscriptions):
return
new_size = self.update_size()
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
if client_id not in self.server.sockets:
# Automatically unsub if the socket has disconnected
self.unsubscribe(client_id)
continue
self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)

View File

@@ -0,0 +1,34 @@
from __future__ import annotations
import os
import folder_paths
import glob
from aiohttp import web
class CustomNodeManager:
"""
Placeholder to refactor the custom node management features from ComfyUI-Manager.
Currently it only contains the custom workflow templates feature.
"""
def add_routes(self, routes, webapp, loadedModules):
@routes.get("/workflow_templates")
async def get_workflow_templates(request):
"""Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
files = [
file
for folder in folder_paths.get_folder_paths("custom_nodes")
for file in glob.glob(os.path.join(folder, '*/example_workflows/*.json'))
]
workflow_templates_dict = {} # custom_nodes folder name -> example workflow names
for file in files:
custom_nodes_name = os.path.basename(os.path.dirname(os.path.dirname(file)))
workflow_name = os.path.splitext(os.path.basename(file))[0]
workflow_templates_dict.setdefault(custom_nodes_name, []).append(workflow_name)
return web.json_response(workflow_templates_dict)
# Serve workflow templates from custom nodes.
for module_name, module_dir in loadedModules:
workflows_dir = os.path.join(module_dir, 'example_workflows')
if os.path.exists(workflows_dir):
webapp.add_routes([web.static('/api/workflow_templates/' + module_name, workflows_dir)])
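For reference, a hypothetical illustration (not part of the commit) of the response shape produced by the `/workflow_templates` route above; the custom node and workflow names are invented:

```python
# Maps each custom_nodes folder name to the example workflow names found under
# its example_workflows/ directory (the .json extension is stripped).
example_response = {
    "ComfyUI-MyCustomNode": ["basic_usage", "advanced_pipeline"],
    "AnotherNodePack": ["quickstart"],
}
```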

View File

@@ -1,31 +1,84 @@
-import logging
-from logging.handlers import MemoryHandler
 from collections import deque
+from datetime import datetime
+import io
+import logging
+import sys
+import threading
 logs = None
-formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+stdout_interceptor = None
+stderr_interceptor = None
+class LogInterceptor(io.TextIOWrapper):
+    def __init__(self, stream, *args, **kwargs):
+        buffer = stream.buffer
+        encoding = stream.encoding
+        super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
+        self._lock = threading.Lock()
+        self._flush_callbacks = []
+        self._logs_since_flush = []
+    def write(self, data):
+        entry = {"t": datetime.now().isoformat(), "m": data}
+        with self._lock:
+            self._logs_since_flush.append(entry)
+            # Simple handling for cr to overwrite the last output if it isnt a full line
+            # else logs just get full of progress messages
+            if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
+                logs.pop()
+            logs.append(entry)
+        super().write(data)
+    def flush(self):
+        super().flush()
+        for cb in self._flush_callbacks:
+            cb(self._logs_since_flush)
+            self._logs_since_flush = []
+    def on_flush(self, callback):
+        self._flush_callbacks.append(callback)
 def get_logs():
-    return "\n".join([formatter.format(x) for x in logs])
+    return logs
-def setup_logger(log_level: str = 'INFO', capacity: int = 300):
+def on_flush(callback):
+    if stdout_interceptor is not None:
+        stdout_interceptor.on_flush(callback)
+    if stderr_interceptor is not None:
+        stderr_interceptor.on_flush(callback)
+def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
     global logs
     if logs:
         return
+    # Override output streams and log to buffer
+    logs = deque(maxlen=capacity)
+    global stdout_interceptor
+    global stderr_interceptor
+    stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
+    stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
     # Setup default global logger
     logger = logging.getLogger()
     logger.setLevel(log_level)
     stream_handler = logging.StreamHandler()
     stream_handler.setFormatter(logging.Formatter("%(message)s"))
-    logger.addHandler(stream_handler)
-    # Create a memory handler with a deque as its buffer
-    logs = deque(maxlen=capacity)
-    memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
-    memory_handler.buffer = logs
-    memory_handler.setFormatter(formatter)
-    logger.addHandler(memory_handler)
+    if use_stdout:
+        # Only errors and critical to stderr
+        stream_handler.addFilter(lambda record: not record.levelno < logging.ERROR)
+        # Lesser to stdout
+        stdout_handler = logging.StreamHandler(sys.stdout)
+        stdout_handler.setFormatter(logging.Formatter("%(message)s"))
+        stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
+        logger.addHandler(stdout_handler)
+    logger.addHandler(stream_handler)
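A minimal usage sketch (not part of the commit) of how the reworked logger is meant to be driven, based only on the functions shown above:

```python
from app.logger import setup_logger, on_flush, get_logs

setup_logger(log_level="INFO", capacity=300)   # replaces sys.stdout/sys.stderr with LogInterceptor
on_flush(lambda entries: None)                 # e.g. forward flushed entries to websocket clients
print("hello")                                 # intercepted and stored as {"t": ..., "m": ...}
assert any("hello" in entry["m"] for entry in get_logs())
```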

app/model_manager.py (new file, 184 lines)
View File

@@ -0,0 +1,184 @@
from __future__ import annotations
import os
import base64
import json
import time
import logging
import folder_paths
import glob
import comfy.utils
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
class ModelFileManager:
def __init__(self) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
return self.cache.get(key, default)
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
self.cache[key] = value
def clear_cache(self):
self.cache.clear()
def add_routes(self, routes):
# NOTE: This is an experiment to replace `/models`
@routes.get("/experiment/models")
async def get_model_folders(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
folder_black_list = ["configs", "custom_nodes"]
output_folders: list[dict] = []
for folder in model_types:
if folder in folder_black_list:
continue
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
return web.json_response(output_folders)
# NOTE: This is an experiment to replace `/models/{folder}`
@routes.get("/experiment/models/{folder}")
async def get_all_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = self.get_model_file_list(folder)
return web.json_response(files)
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request):
folder_name = request.match_info.get("folder", None)
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if not folder_name in folder_paths.folder_names_and_paths:
return web.Response(status=404)
folders = folder_paths.folder_names_and_paths[folder_name]
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
previews = self.get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
try:
with Image.open(default_preview) as img:
img_bytes = BytesIO()
img.save(img_bytes, format="WEBP")
img_bytes.seek(0)
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
except:
return web.Response(status=404)
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
output_list: list[dict] = []
for index, folder in enumerate(folders[0]):
if not os.path.isdir(folder):
continue
out = self.cache_model_file_list_(folder)
if out is None:
out = self.recursive_search_models_(folder, index)
self.set_cache(folder, out)
output_list.extend(out[0])
return output_list
def cache_model_file_list_(self, folder: str):
model_file_list_cache = self.get_cache(folder)
if model_file_list_cache is None:
return None
if not os.path.isdir(folder):
return None
if os.path.getmtime(folder) != model_file_list_cache[1]:
return None
for x in model_file_list_cache[1]:
time_modified = model_file_list_cache[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
return model_file_list_cache
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
if not os.path.isdir(directory):
return [], {}, time.perf_counter()
excluded_dir_names = [".git"]
# TODO use settings
include_hidden_files = False
result: list[str] = []
dirs: dict[str, float] = {}
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
if not include_hidden_files:
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
filenames = [f for f in filenames if not f.startswith(".")]
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
for file_name in filenames:
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
continue
for d in subdirs:
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):
self.clear_cache()
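As a quick reference, hypothetical response shapes (not part of the commit) for the experimental model-listing routes implemented above; the folder and file names are invented:

```python
# GET /experiment/models          -> model folder names and their search paths
# GET /experiment/models/{folder} -> files found for one folder, tagged with the
#                                    index of the search path they were found in
models_response = [
    {"name": "checkpoints", "folders": ["/ComfyUI/models/checkpoints"]},
    {"name": "loras", "folders": ["/ComfyUI/models/loras"]},
]
loras_response = [
    {"name": "subdir/my_lora.safetensors", "pathIndex": 0},
]
```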

View File

@@ -1,28 +1,45 @@
+from __future__ import annotations
 import json
 import os
 import re
 import uuid
 import glob
 import shutil
+import logging
 from aiohttp import web
 from urllib import parse
 from comfy.cli_args import args
 import folder_paths
 from .app_settings import AppSettings
+from typing import TypedDict
 default_user = "default"
+class FileInfo(TypedDict):
+    path: str
+    size: int
+    modified: int
+def get_file_info(path: str, relative_to: str) -> FileInfo:
+    return {
+        "path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
+        "size": os.path.getsize(path),
+        "modified": os.path.getmtime(path)
+    }
 class UserManager():
     def __init__(self):
         user_directory = folder_paths.get_user_directory()
         self.settings = AppSettings(self)
         if not os.path.exists(user_directory):
-            os.mkdir(user_directory)
+            os.makedirs(user_directory, exist_ok=True)
         if not args.multi_user:
-            print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
-            print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
+            logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
+            logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
         if args.multi_user:
             if os.path.isfile(self.get_users_file()):
@@ -154,6 +171,7 @@ class UserManager():
             recurse = request.rel_url.query.get('recurse', '').lower() == "true"
             full_info = request.rel_url.query.get('full_info', '').lower() == "true"
+            split_path = request.rel_url.query.get('split', '').lower() == "true"
             # Use different patterns based on whether we're recursing or not
             if recurse:
@@ -161,26 +179,21 @@ class UserManager():
             else:
                 pattern = os.path.join(glob.escape(path), '*')
-            results = glob.glob(pattern, recursive=recurse)
-            if full_info:
-                results = [
-                    {
-                        'path': os.path.relpath(x, path).replace(os.sep, '/'),
-                        'size': os.path.getsize(x),
-                        'modified': os.path.getmtime(x)
-                    } for x in results if os.path.isfile(x)
-                ]
-            else:
-                results = [
-                    os.path.relpath(x, path).replace(os.sep, '/')
-                    for x in results
-                    if os.path.isfile(x)
-                ]
-            split_path = request.rel_url.query.get('split', '').lower() == "true"
-            if split_path and not full_info:
-                results = [[x] + x.split('/') for x in results]
+            def process_full_path(full_path: str) -> FileInfo | str | list[str]:
+                if full_info:
+                    return get_file_info(full_path, path)
+                rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
+                if split_path:
+                    return [rel_path] + rel_path.split('/')
+                return rel_path
+            results = [
+                process_full_path(full_path)
+                for full_path in glob.glob(pattern, recursive=recurse)
+                if os.path.isfile(full_path)
+            ]
             return web.json_response(results)
@@ -208,20 +221,51 @@ class UserManager():
         @routes.post("/userdata/{file}")
         async def post_userdata(request):
+            """
+            Upload or update a user data file.
+            This endpoint handles file uploads to a user's data directory, with options for
+            controlling overwrite behavior and response format.
+            Query Parameters:
+            - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
+            - full_info (optional): If "true", returns detailed file information (path, size, modified time).
+              If "false", returns only the relative file path.
+            Path Parameters:
+            - file: The target file path (URL encoded if necessary).
+            Returns:
+            - 400: If 'file' parameter is missing.
+            - 403: If the requested path is not allowed.
+            - 409: If overwrite=false and the file already exists.
+            - 200: JSON response with either:
+              - Full file information (if full_info=true)
+              - Relative file path (if full_info=false)
+            The request body should contain the raw file content to be written.
+            """
             path = get_user_data_path(request)
             if not isinstance(path, str):
                 return path
-            overwrite = request.query["overwrite"] != "false"
+            overwrite = request.query.get("overwrite", 'true') != "false"
+            full_info = request.query.get('full_info', 'false').lower() == "true"
             if not overwrite and os.path.exists(path):
-                return web.Response(status=409)
+                return web.Response(status=409, text="File already exists")
             body = await request.read()
             with open(path, "wb") as f:
                 f.write(body)
-            resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
+            user_path = self.get_request_user_filepath(request, None)
+            if full_info:
+                resp = get_file_info(path, user_path)
+            else:
+                resp = os.path.relpath(path, user_path)
             return web.json_response(resp)
         @routes.delete("/userdata/{file}")
@@ -236,6 +280,30 @@ class UserManager():
         @routes.post("/userdata/{file}/move/{dest}")
         async def move_userdata(request):
+            """
+            Move or rename a user data file.
+            This endpoint handles moving or renaming files within a user's data directory, with options for
+            controlling overwrite behavior and response format.
+            Path Parameters:
+            - file: The source file path (URL encoded if necessary)
+            - dest: The destination file path (URL encoded if necessary)
+            Query Parameters:
+            - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
+            - full_info (optional): If "true", returns detailed file information (path, size, modified time).
+              If "false", returns only the relative file path.
+            Returns:
+            - 400: If either 'file' or 'dest' parameter is missing
+            - 403: If either requested path is not allowed
+            - 404: If the source file does not exist
+            - 409: If overwrite=false and the destination file already exists
+            - 200: JSON response with either:
+              - Full file information (if full_info=true)
+              - Relative file path (if full_info=false)
+            """
             source = get_user_data_path(request, check_exists=True)
             if not isinstance(source, str):
                 return source
@@ -244,12 +312,19 @@ class UserManager():
             if not isinstance(source, str):
                 return dest
-            overwrite = request.query["overwrite"] != "false"
-            if not overwrite and os.path.exists(dest):
-                return web.Response(status=409)
-            print(f"moving '{source}' -> '{dest}'")
+            overwrite = request.query.get("overwrite", 'true') != "false"
+            full_info = request.query.get('full_info', 'false').lower() == "true"
+            if not overwrite and os.path.exists(dest):
+                return web.Response(status=409, text="File already exists")
+            logging.info(f"moving '{source}' -> '{dest}'")
             shutil.move(source, dest)
-            resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
+            user_path = self.get_request_user_filepath(request, None)
+            if full_info:
+                resp = get_file_info(dest, user_path)
+            else:
+                resp = os.path.relpath(dest, user_path)
             return web.json_response(resp)
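To make the documented query parameters concrete, here is a minimal usage sketch (not part of the commit) of the `/userdata` upload and move endpoints, assuming a local server on 127.0.0.1:8188, the third-party `requests` package, and made-up file names:

```python
import requests

BASE = "http://127.0.0.1:8188"

# Upload without overwriting an existing file; request full file info back.
resp = requests.post(
    f"{BASE}/userdata/notes.txt",
    params={"overwrite": "false", "full_info": "true"},
    data=b"hello",
)
if resp.status_code == 409:
    print("File already exists")
else:
    print(resp.json())  # e.g. {"path": ..., "size": ..., "modified": ...}

# Move/rename the file; the destination path segment is URL encoded.
moved = requests.post(
    f"{BASE}/userdata/notes.txt/move/archive%2Fnotes.txt",
    params={"full_info": "true"},
)
print(moved.status_code)
```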

View File

@@ -2,11 +2,9 @@
 #and modified
 import torch
-import torch as th
 import torch.nn as nn
 from ..ldm.modules.diffusionmodules.util import (
-    zero_module,
     timestep_embedding,
 )
@@ -162,7 +160,6 @@ class ControlNet(nn.Module):
         if isinstance(self.num_classes, int):
             self.label_emb = nn.Embedding(num_classes, time_embed_dim)
         elif self.num_classes == "continuous":
-            print("setting up linear c_adm embedding layer")
             self.label_emb = nn.Linear(1, time_embed_dim)
         elif self.num_classes == "sequential":
             assert adm_in_channels is not None
@@ -415,7 +412,6 @@ class ControlNet(nn.Module):
         out_output = []
         out_middle = []
-        hs = []
         if self.num_classes is not None:
             assert y.shape[0] == x.shape[0]
             emb = emb + self.label_emb(y)

comfy/cldm/dit_embedder.py (new file, 120 lines)
View File

@@ -0,0 +1,120 @@
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
class ControlNetEmbedder(nn.Module):
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
attention_head_dim: int,
num_attention_heads: int,
adm_in_channels: int,
num_layers: int,
main_model_double: int,
double_y_emb: bool,
device: torch.device,
dtype: torch.dtype,
pos_embed_max_size: Optional[int] = None,
operations = None,
):
super().__init__()
self.main_model_double = main_model_double
self.dtype = dtype
self.hidden_size = num_attention_heads * attention_head_dim
self.patch_size = patch_size
self.x_embedder = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=pos_embed_max_size is None,
device=device,
dtype=dtype,
operations=operations,
)
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
self.double_y_emb = double_y_emb
if self.double_y_emb:
self.orig_y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.y_embedder = VectorEmbedder(
self.hidden_size, self.hidden_size, dtype, device, operations=operations
)
else:
self.y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.transformer_blocks = nn.ModuleList(
DismantledBlock(
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(num_layers)
)
# self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
# TODO double check this logic when 8b
self.use_y_embedder = True
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=False,
device=device,
dtype=dtype,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
hint = None,
) -> Tuple[Tensor, List[Tensor]]:
x_shape = list(x.shape)
x = self.x_embedder(x)
if not self.double_y_emb:
h = (x_shape[-2] + 1) // self.patch_size
w = (x_shape[-1] + 1) // self.patch_size
x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
c = self.t_embedder(timesteps, dtype=x.dtype)
if y is not None and self.y_embedder is not None:
if self.double_y_emb:
y = self.orig_y_embedder(y)
y = self.y_embedder(y)
c = c + y
x = x + self.pos_embed_input(hint)
block_out = ()
repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
for i in range(len(self.transformer_blocks)):
out = self.transformer_blocks[i](x, c)
if not self.double_y_emb:
x = out
block_out += (self.controlnet_blocks[i](out),) * repeat
return {"output": block_out}

View File

@@ -1,5 +1,5 @@
 import torch
-from typing import Dict, Optional
+from typing import Optional
 import comfy.ldm.modules.diffusionmodules.mmdit
 class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):

View File

@@ -60,8 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
 fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
 fpunet_group = parser.add_mutually_exclusive_group()
-fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
-fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
+fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
+fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
+fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
+fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
 fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
 fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
@@ -82,7 +84,8 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
-parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
+parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
+parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
 class LatentPreviewMethod(enum.Enum):
     NoPreviews = "none"
@@ -102,6 +105,7 @@ attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
+attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
@@ -118,7 +122,7 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
 vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
 vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
-parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
+parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
@@ -137,6 +141,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
 parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
 parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
+parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
 # The default built-in provider hosted under web/
 DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"

View File

@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a), ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu, "gelu": torch.nn.functional.gelu,
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
} }
class CLIPMLP(torch.nn.Module): class CLIPMLP(torch.nn.Module):
@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
class CLIPVisionEmbeddings(torch.nn.Module): class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None): def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
num_patches = (image_size // patch_size) ** 2
if model_type == "siglip_vision_model":
self.class_embedding = None
patch_bias = True
else:
num_patches = num_patches + 1
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
patch_bias = False
self.patch_embedding = operations.Conv2d( self.patch_embedding = operations.Conv2d(
in_channels=num_channels, in_channels=num_channels,
out_channels=embed_dim, out_channels=embed_dim,
kernel_size=patch_size, kernel_size=patch_size,
stride=patch_size, stride=patch_size,
bias=False, bias=patch_bias,
dtype=dtype, dtype=dtype,
device=device device=device
) )
num_patches = (image_size // patch_size) ** 2 self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
num_positions = num_patches + 1
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, pixel_values): def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2) embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds) if self.class_embedding is not None:
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
class CLIPVision(torch.nn.Module): class CLIPVision(torch.nn.Module):
@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
heads = config_dict["num_attention_heads"] heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"] intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"] intermediate_activation = config_dict["hidden_act"]
model_type = config_dict["model_type"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations) self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
self.pre_layrnorm = operations.LayerNorm(embed_dim) if model_type == "siglip_vision_model":
self.pre_layrnorm = lambda a: a
self.output_layernorm = True
else:
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.output_layernorm = False
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim) self.post_layernorm = operations.LayerNorm(embed_dim)
@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
x = self.pre_layrnorm(x) x = self.pre_layrnorm(x)
#TODO: attention_mask? #TODO: attention_mask?
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output) x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
pooled_output = self.post_layernorm(x[:, 0, :]) if self.output_layernorm:
x = self.post_layernorm(x)
pooled_output = x
else:
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output return x, i, pooled_output
class CLIPVisionModelProjection(torch.nn.Module): class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
self.vision_model = CLIPVision(config_dict, dtype, device, operations) self.vision_model = CLIPVision(config_dict, dtype, device, operations)
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False) if "projection_dim" in config_dict:
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
else:
self.visual_projection = lambda a: a
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs) x = self.vision_model(*args, **kwargs)

View File

@ -16,13 +16,18 @@ class Output:
def __setitem__(self, key, item): def __setitem__(self, key, item):
setattr(self, key, item) setattr(self, key, item)
def clip_preprocess(image, size=224): def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1) image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size): if not (image.shape[2] == size and image.shape[3] == size):
scale = (size / min(image.shape[2], image.shape[3])) if crop:
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True) scale = (size / min(image.shape[2], image.shape[3]))
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2 h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2 w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size] image = image[:,:,h:h+size,w:w+size]
@ -35,6 +40,8 @@ class ClipVisionModel():
config = json.load(f) config = json.load(f)
self.image_size = config.get("image_size", 224) self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
self.load_device = comfy.model_management.text_encoder_device() self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device() offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@ -49,9 +56,9 @@ class ClipVisionModel():
def get_sd(self): def get_sd(self):
return self.model.state_dict() return self.model.state_dict()
def encode_image(self, image): def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher) comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float() pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2) out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output() outputs = Output()
@ -94,7 +101,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else: else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")

View File

@ -0,0 +1,13 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@ -0,0 +1,43 @@
# Comfy Typing
## Type hinting for ComfyUI Node development
This module provides type hinting and concrete convenience types for node developers.
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
```python
from comfy.comfy_types import IO, ComfyNodeABC, CheckLazyMixin
class ExampleNode(ComfyNodeABC):
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {"required": {}}
```
Full example is in [examples/example_nodes.py](examples/example_nodes.py).
# Types
A few primary types are documented below. More complete information is available via the docstrings on each type.
## `IO`
A string enum of built-in and a few custom data types. Includes the following special types and their requisite plumbing:
- `ANY`: `"*"`
- `NUMBER`: `"FLOAT,INT"`
- `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"`
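
The wildcard and union matching above comes from `IO` overriding `__ne__`: `"*"` compares as matching everything, and comma-separated type strings match when one side is a subset of the other. A minimal sketch of that behaviour (assuming ComfyUI itself is importable, e.g. from inside a custom node):

```python
from comfy.comfy_types import IO

# "*" (IO.ANY) compares as matching any other type string.
assert not (IO.ANY != "MODEL")

# Comma-separated unions match when one side is a subset of the other.
assert not (IO.NUMBER != IO.INT)   # "FLOAT,INT" accepts "INT"
assert IO.NUMBER != IO.STRING      # disjoint sets do not match
```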
## `ComfyNodeABC`
An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings.
### Type hinting for `INPUT_TYPES`
![INPUT_TYPES auto-completion in Visual Studio Code](examples/input_types.png)
### `INPUT_TYPES` return dict
![INPUT_TYPES return value type hinting in Visual Studio Code](examples/required_hint.png)
### Options for individual inputs
![INPUT_TYPES return value option auto-completion in Visual Studio Code](examples/input_options.png)
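
As a further illustration, a hypothetical node (the class and input names below are made up for this README, not part of ComfyUI) showing the per-input options that `InputTypeOptions` hints, such as `default`, `min`, `max`, `step`, `round`, and `tooltip`:

```python
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict


class ScaleIntExample(ComfyNodeABC):
    """Hypothetical example node: multiplies an integer by a float factor."""

    CATEGORY = "examples"

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                # The option keys below are the ones hinted by InputTypeOptions.
                "value": (IO.INT, {"default": 0, "min": 0, "max": 100, "step": 1,
                                   "tooltip": "Integer to scale"}),
                "factor": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 10.0, "round": 0.01}),
            }
        }

    RETURN_TYPES = (IO.FLOAT,)
    RETURN_NAMES = ("scaled",)
    FUNCTION = "execute"

    def execute(self, value: int, factor: float):
        return (value * factor,)
```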

View File

@ -1,5 +1,6 @@
import torch import torch
from typing import Callable, Protocol, TypedDict, Optional, List from typing import Callable, Protocol, TypedDict, Optional, List
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
class UnetApplyFunction(Protocol): class UnetApplyFunction(Protocol):
@ -30,3 +31,15 @@ class UnetParams(TypedDict):
UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor] UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
__all__ = [
"UnetWrapperFunction",
UnetApplyConds.__name__,
UnetParams.__name__,
UnetApplyFunction.__name__,
IO.__name__,
InputTypeDict.__name__,
ComfyNodeABC.__name__,
CheckLazyMixin.__name__,
]

View File

@ -0,0 +1,28 @@
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
from inspect import cleandoc
class ExampleNode(ComfyNodeABC):
"""An example node that just adds 1 to an input integer.
* Requires a modern IDE to provide any benefit (i.e. an IDE configured with analysis paths, etc.).
* This node is intended as an example for developers only.
"""
DESCRIPTION = cleandoc(__doc__)
CATEGORY = "examples"
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"input_int": (IO.INT, {"defaultInput": True}),
}
}
RETURN_TYPES = (IO.INT,)
RETURN_NAMES = ("input_plus_one",)
FUNCTION = "execute"
def execute(self, input_int: int):
return (input_int + 1,)

[Binary files not shown: three example screenshots (19 KiB, 16 KiB, 19 KiB) referenced by the README above.]

View File

@ -0,0 +1,274 @@
"""Comfy-specific type hinting"""
from __future__ import annotations
from typing import Literal, TypedDict
from abc import ABC, abstractmethod
from enum import Enum
class StrEnum(str, Enum):
"""Base class for string enums. Python's StrEnum is not available until 3.11."""
def __str__(self) -> str:
return self.value
class IO(StrEnum):
"""Node input/output data types.
Includes functionality for ``"*"`` (`ANY`) and ``"MULTI,TYPES"``.
"""
STRING = "STRING"
IMAGE = "IMAGE"
MASK = "MASK"
LATENT = "LATENT"
BOOLEAN = "BOOLEAN"
INT = "INT"
FLOAT = "FLOAT"
CONDITIONING = "CONDITIONING"
SAMPLER = "SAMPLER"
SIGMAS = "SIGMAS"
GUIDER = "GUIDER"
NOISE = "NOISE"
CLIP = "CLIP"
CONTROL_NET = "CONTROL_NET"
VAE = "VAE"
MODEL = "MODEL"
CLIP_VISION = "CLIP_VISION"
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
STYLE_MODEL = "STYLE_MODEL"
GLIGEN = "GLIGEN"
UPSCALE_MODEL = "UPSCALE_MODEL"
AUDIO = "AUDIO"
WEBCAM = "WEBCAM"
POINT = "POINT"
FACE_ANALYSIS = "FACE_ANALYSIS"
BBOX = "BBOX"
SEGS = "SEGS"
ANY = "*"
"""Always matches any type, but at a price.
Causes some functionality issues (e.g. reroutes, link types), and should be avoided whenever possible.
"""
NUMBER = "FLOAT,INT"
"""A float or an int - could be either"""
PRIMITIVE = "STRING,FLOAT,INT,BOOLEAN"
"""Could be any of: string, float, int, or bool"""
def __ne__(self, value: object) -> bool:
if self == "*" or value == "*":
return False
if not isinstance(value, str):
return True
a = frozenset(self.split(","))
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class InputTypeOptions(TypedDict):
"""Provides type hinting for the return type of the INPUT_TYPES node function.
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
"""
default: bool | str | float | int | list | tuple
"""The default value of the widget"""
defaultInput: bool
"""Defaults to an input slot rather than a widget"""
forceInput: bool
"""`defaultInput` and also don't allow converting to a widget"""
lazy: bool
"""Declares that this input uses lazy evaluation"""
rawLink: bool
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
tooltip: str
"""Tooltip for the input (or widget), shown on pointer hover"""
# class InputTypeNumber(InputTypeOptions):
# default: float | int
min: float
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
max: float
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
step: float
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
round: float
"""Floats are rounded by this value (``FLOAT``)"""
# class InputTypeBoolean(InputTypeOptions):
# default: bool
label_on: str
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
label_off: str
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
# class InputTypeString(InputTypeOptions):
# default: str
multiline: bool
"""Use a multiline text box (``STRING``)"""
placeholder: str
"""Placeholder text to display in the UI when empty (``STRING``)"""
# Deprecated:
# defaultVal: str
dynamicPrompts: bool
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
class HiddenInputTypeDict(TypedDict):
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
node_id: Literal["UNIQUE_ID"]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
unique_id: Literal["UNIQUE_ID"]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
prompt: Literal["PROMPT"]
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
extra_pnginfo: Literal["EXTRA_PNGINFO"]
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
dynprompt: Literal["DYNPROMPT"]
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
class InputTypeDict(TypedDict):
"""Provides type hinting for node INPUT_TYPES.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
"""
required: dict[str, tuple[IO, InputTypeOptions]]
"""Describes all inputs that must be connected for the node to execute."""
optional: dict[str, tuple[IO, InputTypeOptions]]
"""Describes inputs which do not need to be connected."""
hidden: HiddenInputTypeDict
"""Offers advanced functionality and server-client communication.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
"""
class ComfyNodeABC(ABC):
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
"""
DESCRIPTION: str
"""Node description, shown as a tooltip when hovering over the node.
Usage::
# Explicitly define the description
DESCRIPTION = "Example description here."
# Use the docstring of the node class.
DESCRIPTION = cleandoc(__doc__)
"""
CATEGORY: str
"""The category of the node, as per the "Add Node" menu.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
"""
EXPERIMENTAL: bool
"""Flags a node as experimental, informing users that it may change or not work as expected."""
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
@classmethod
@abstractmethod
def INPUT_TYPES(s) -> InputTypeDict:
"""Defines node inputs.
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
* The ``optional`` key can be added to describe inputs which do not need to be connected.
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
"""
return {"required": {}}
OUTPUT_NODE: bool
"""Flags this node as an output node, causing any inputs it requires to be executed.
If a node is not connected to any output nodes, that node will not be executed. Usage::
OUTPUT_NODE = True
From the docs:
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
"""
INPUT_IS_LIST: bool
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
From the docs:
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
"""
OUTPUT_IS_LIST: tuple[bool]
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
A ``tuple[bool]``, where the items match those in `RETURN_TYPES`::
RETURN_TYPES = (IO.INT, IO.INT, IO.STRING)
OUTPUT_IS_LIST = (True, True, False) # The string output will be handled normally
From the docs:
In order to tell Comfy that the list being returned should not be wrapped, but treated as a series of data for sequential processing,
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
specifying which outputs should be so treated.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
"""
RETURN_TYPES: tuple[IO]
"""A tuple representing the outputs of this node.
Usage::
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
"""
RETURN_NAMES: tuple[str]
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
"""
OUTPUT_TOOLTIPS: tuple[str]
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
FUNCTION: str
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
"""
class CheckLazyMixin:
"""Provides a basic check_lazy_status implementation and type hinting for nodes that use lazy inputs."""
def check_lazy_status(self, **kwargs) -> list[str]:
"""Returns a list of input names that should be evaluated.
This basic mixin implementation requires all inputs.
:kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \
When using ``INPUT_IS_LIST = True``, unevaluated inputs will instead be ``(None,)``.
Params should match the node's execution ``FUNCTION`` (self, and all inputs by name).
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
"""
need = [name for name in kwargs if kwargs[name] is None]
return need

View File

@ -35,6 +35,10 @@ import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet import comfy.ldm.flux.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.hooks import HookGroup
def broadcast_image_to(tensor, target_batch_size, batched_number): def broadcast_image_to(tensor, target_batch_size, batched_number):
@ -78,6 +82,8 @@ class ControlBase:
self.concat_mask = False self.concat_mask = False
self.extra_concat_orig = [] self.extra_concat_orig = []
self.extra_concat = None self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]): def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint self.cond_hint_original = cond_hint
@ -115,6 +121,14 @@ class ControlBase:
out += self.previous_controlnet.get_models() out += self.previous_controlnet.get_models()
return out return out
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
out.append(self.extra_hooks)
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c): def copy_to(self, c):
c.cond_hint_original = self.cond_hint_original c.cond_hint_original = self.cond_hint_original
c.strength = self.strength c.strength = self.strength
@ -129,6 +143,8 @@ class ControlBase:
c.strength_type = self.strength_type c.strength_type = self.strength_type
c.concat_mask = self.concat_mask c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy() c.extra_concat_orig = self.extra_concat_orig.copy()
c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
c.preprocess_image = self.preprocess_image
def inference_memory_requirements(self, dtype): def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
@ -181,7 +197,7 @@ class ControlBase:
class ControlNet(ControlBase): class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False): def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
super().__init__() super().__init__()
self.control_model = control_model self.control_model = control_model
self.load_device = load_device self.load_device = load_device
@ -196,11 +212,12 @@ class ControlNet(ControlBase):
self.extra_conds += extra_conds self.extra_conds += extra_conds
self.strength_type = strength_type self.strength_type = strength_type
self.concat_mask = concat_mask self.concat_mask = concat_mask
self.preprocess_image = preprocess_image
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
control_prev = None control_prev = None
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
if self.timestep_range is not None: if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
@ -224,6 +241,7 @@ class ControlNet(ControlBase):
if self.latent_format is not None: if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.") raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
self.cond_hint = self.preprocess_image(self.cond_hint)
if self.vae is not None: if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True) loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1)) self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
@ -279,7 +297,6 @@ class ControlLoraOps:
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp): class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True, def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None: device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__() super().__init__()
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
@ -364,7 +381,6 @@ class ControlLora(ControlNet):
self.control_model.to(comfy.model_management.get_torch_device()) self.control_model.to(comfy.model_management.get_torch_device())
diffusion_model = model.diffusion_model diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict() sd = diffusion_model.state_dict()
cm = self.control_model.state_dict()
for k in sd: for k in sd:
weight = sd[k] weight = sd[k]
@ -427,6 +443,7 @@ def controlnet_load_state_dict(control_model, sd):
logging.debug("unexpected controlnet keys: {}".format(unexpected)) logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model return control_model
def load_controlnet_mmdit(sd, model_options={}): def load_controlnet_mmdit(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options) model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
@ -448,6 +465,82 @@ def load_controlnet_mmdit(sd, model_options={}):
return control return control
class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)
def copy(self):
c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
def load_controlnet_sd35(sd, model_options={}):
control_type = -1
if "control_type" in sd:
control_type = round(sd.pop("control_type").item())
# blur_cnet = control_type == 0
canny_cnet = control_type == 1
depth_cnet = control_type == 2
new_sd = {}
for k in comfy.utils.MMDIT_MAP_BASIC:
if k[1] in sd:
new_sd[k[0]] = sd.pop(k[1])
for k in sd:
new_sd[k] = sd[k]
sd = new_sd
y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
depth = y_emb_shape[0] // 64
hidden_size = 64 * depth
num_heads = depth
head_dim = hidden_size // num_heads
num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.unet_offload_device()
unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
patch_size=2,
in_chans=16,
num_layers=num_blocks,
main_model_double=depth,
double_y_emb=y_emb_shape[0] == y_emb_shape[1],
attention_head_dim=head_dim,
num_attention_heads=num_heads,
adm_in_channels=2048,
device=offload_device,
dtype=unet_dtype,
operations=operations)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = comfy.latent_formats.SD3()
preprocess_image = lambda a: a
if canny_cnet:
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
elif depth_cnet:
preprocess_image = lambda a: 1.0 - a
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
return control
def load_controlnet_hunyuandit(controlnet_data, model_options={}): def load_controlnet_hunyuandit(controlnet_data, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options) model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
@ -560,7 +653,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data: if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options) return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data: elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
else:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data: elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options) return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
@ -674,10 +770,10 @@ class T2IAdapter(ControlBase):
height = math.ceil(height / unshuffle_amount) * unshuffle_amount height = math.ceil(height / unshuffle_amount) * unshuffle_amount
return width, height return width, height
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
control_prev = None control_prev = None
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
if self.timestep_range is not None: if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
@ -725,7 +821,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
for i in range(4): for i in range(4):
for j in range(2): for j in range(2):
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j) prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2) prefix_replace["adapter.body.{}.".format(i, )] = "body.{}.".format(i * 2)
prefix_replace["adapter."] = "" prefix_replace["adapter."] = ""
t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace) t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
keys = t2i_data.keys() keys = t2i_data.keys()

View File

@ -157,16 +157,23 @@ vae_conversion_map_attn = [
] ]
def reshape_weight_for_sd(w): def reshape_weight_for_sd(w, conv3d=False):
# convert HF linear weights to SD conv2d weights # convert HF linear weights to SD conv2d weights
return w.reshape(*w.shape, 1, 1) if conv3d:
return w.reshape(*w.shape, 1, 1, 1)
else:
return w.reshape(*w.shape, 1, 1)
def convert_vae_state_dict(vae_state_dict): def convert_vae_state_dict(vae_state_dict):
mapping = {k: k for k in vae_state_dict.keys()} mapping = {k: k for k in vae_state_dict.keys()}
conv3d = False
for k, v in mapping.items(): for k, v in mapping.items():
for sd_part, hf_part in vae_conversion_map: for sd_part, hf_part in vae_conversion_map:
v = v.replace(hf_part, sd_part) v = v.replace(hf_part, sd_part)
if v.endswith(".conv.weight"):
if not conv3d and vae_state_dict[k].ndim == 5:
conv3d = True
mapping[k] = v mapping[k] = v
for k, v in mapping.items(): for k, v in mapping.items():
if "attentions" in k: if "attentions" in k:
@ -179,7 +186,7 @@ def convert_vae_state_dict(vae_state_dict):
for weight_name in weights_to_convert: for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k: if f"mid.attn_1.{weight_name}.weight" in k:
logging.debug(f"Reshaping {k} for SD format") logging.debug(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v) new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
return new_state_dict return new_state_dict

View File

@ -1,10 +1,10 @@
#code taken from: https://github.com/wl-zhao/UniPC and modified #code taken from: https://github.com/wl-zhao/UniPC and modified
import torch import torch
import torch.nn.functional as F
import math import math
import logging
from tqdm.auto import trange, tqdm from tqdm.auto import trange
class NoiseScheduleVP: class NoiseScheduleVP:
@ -16,7 +16,7 @@ class NoiseScheduleVP:
continuous_beta_0=0.1, continuous_beta_0=0.1,
continuous_beta_1=20., continuous_beta_1=20.,
): ):
"""Create a wrapper class for the forward SDE (VP type). r"""Create a wrapper class for the forward SDE (VP type).
*** ***
Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
@ -475,7 +475,7 @@ class UniPC:
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') logging.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
ns = self.noise_schedule ns = self.noise_schedule
assert order <= len(model_prev_list) assert order <= len(model_prev_list)
@ -519,7 +519,6 @@ class UniPC:
A_p = C_inv_p A_p = C_inv_p
if use_corrector: if use_corrector:
print('using corrector')
C_inv = torch.linalg.inv(C) C_inv = torch.linalg.inv(C)
A_c = C_inv A_c = C_inv
@ -704,7 +703,6 @@ class UniPC:
): ):
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end # t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
# t_T = self.noise_schedule.T if t_start is None else t_start # t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
steps = len(timesteps) - 1 steps = len(timesteps) - 1
if method == 'multistep': if method == 'multistep':
assert steps >= order assert steps >= order

View File

@ -1,3 +1,4 @@
import math
import torch import torch
from torch import nn from torch import nn
from .ldm.modules.attention import CrossAttention from .ldm.modules.attention import CrossAttention

comfy/hooks.py (new file, 704 lines added)
View File

@ -0,0 +1,704 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable
import enum
import math
import torch
import numpy as np
import itertools
import logging
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher, PatcherInjection
from comfy.model_base import BaseModel
from comfy.sd import CLIP
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from node_helpers import conditioning_set_values
class EnumHookMode(enum.Enum):
MinVram = "minvram"
MaxSpeed = "maxspeed"
class EnumHookType(enum.Enum):
Weight = "weight"
Patch = "patch"
ObjectPatch = "object_patch"
AddModels = "add_models"
Callbacks = "callbacks"
Wrappers = "wrappers"
SetInjections = "add_injections"
class EnumWeightTarget(enum.Enum):
Model = "model"
Clip = "clip"
class _HookRef:
pass
# NOTE: this is an example of how the should_register function should look
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return True
class Hook:
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
hook_keyframe: 'HookKeyframeGroup'=None):
self.hook_type = hook_type
self.hook_ref = hook_ref if hook_ref else _HookRef()
self.hook_id = hook_id
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
self.custom_should_register = default_should_register
self.auto_apply_to_nonpositive = False
@property
def strength(self):
return self.hook_keyframe.strength
def initialize_timesteps(self, model: 'BaseModel'):
self.reset()
self.hook_keyframe.initialize_timesteps(model)
def reset(self):
self.hook_keyframe.reset()
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: Hook = subtype()
c.hook_type = self.hook_type
c.hook_ref = self.hook_ref
c.hook_id = self.hook_id
c.hook_keyframe = self.hook_keyframe
c.custom_should_register = self.custom_should_register
# TODO: make this do something
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
return c
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return self.custom_should_register(self, model, model_options, target, registered)
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def __eq__(self, other: 'Hook'):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
def __hash__(self):
return hash(self.hook_ref)
class WeightHook(Hook):
def __init__(self, strength_model=1.0, strength_clip=1.0):
super().__init__(hook_type=EnumHookType.Weight)
self.weights: dict = None
self.weights_clip: dict = None
self.need_weight_init = True
self._strength_model = strength_model
self._strength_clip = strength_clip
@property
def strength_model(self):
return self._strength_model * self.strength
@property
def strength_clip(self):
return self._strength_clip * self.strength
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
return False
weights = None
if target == EnumWeightTarget.Model:
strength = self._strength_model
else:
strength = self._strength_clip
if self.need_weight_init:
key_map = {}
if target == EnumWeightTarget.Model:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Model:
weights = self.weights
else:
weights = self.weights_clip
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self)
return True
# TODO: add logs about any keys that were not applied
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WeightHook = super().clone(subtype)
c.weights = self.weights
c.weights_clip = self.weights_clip
c.need_weight_init = self.need_weight_init
c._strength_model = self._strength_model
c._strength_clip = self._strength_clip
return c
class PatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.Patch)
self.patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: PatchHook = super().clone(subtype)
c.patches = self.patches
return c
# TODO: add functionality
class ObjectPatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: ObjectPatchHook = super().clone(subtype)
c.object_patches = self.object_patches
return c
# TODO: add functionality
class AddModelsHook(Hook):
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
super().__init__(hook_type=EnumHookType.AddModels)
self.key = key
self.models = models
self.append_when_same = True
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddModelsHook = super().clone(subtype)
c.key = self.key
c.models = self.models.copy() if self.models else self.models
c.append_when_same = self.append_when_same
return c
# TODO: add functionality
class CallbackHook(Hook):
def __init__(self, key: str=None, callback: Callable=None):
super().__init__(hook_type=EnumHookType.Callbacks)
self.key = key
self.callback = callback
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: CallbackHook = super().clone(subtype)
c.key = self.key
c.callback = self.callback
return c
# TODO: add functionality
class WrapperHook(Hook):
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
super().__init__(hook_type=EnumHookType.Wrappers)
self.wrappers_dict = wrappers_dict
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WrapperHook = super().clone(subtype)
c.wrappers_dict = self.wrappers_dict
return c
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
return False
add_model_options = {"transformer_options": self.wrappers_dict}
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self)
return True
class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
super().__init__(hook_type=EnumHookType.SetInjections)
self.key = key
self.injections = injections
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: SetInjectionsHook = super().clone(subtype)
c.key = self.key
c.injections = self.injections.copy() if self.injections else self.injections
return c
def add_hook_injections(self, model: 'ModelPatcher'):
# TODO: add functionality
pass
class HookGroup:
def __init__(self):
self.hooks: list[Hook] = []
def add(self, hook: Hook):
if hook not in self.hooks:
self.hooks.append(hook)
def contains(self, hook: Hook):
return hook in self.hooks
def clone(self):
c = HookGroup()
for hook in self.hooks:
c.add(hook.clone())
return c
def clone_and_combine(self, other: 'HookGroup'):
c = self.clone()
if other is not None:
for hook in other.hooks:
c.add(hook.clone())
return c
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
if hook_kf is None:
hook_kf = HookKeyframeGroup()
else:
hook_kf = hook_kf.clone()
for hook in self.hooks:
hook.hook_keyframe = hook_kf
def get_dict_repr(self):
d: dict[EnumHookType, dict[Hook, None]] = {}
for hook in self.hooks:
with_type = d.setdefault(hook.hook_type, {})
with_type[hook] = None
return d
def get_hooks_for_clip_schedule(self):
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
for hook in self.hooks:
# only care about WeightHooks, for now
if hook.hook_type == EnumHookType.Weight:
hook_schedule = []
# if no hook keyframes, assign default value
if len(hook.hook_keyframe.keyframes) == 0:
hook_schedule.append(((0.0, 1.0), None))
scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
scheduled_hooks[hook] = hook_schedule
# hooks should now have their schedules in a list of tuples
all_ranges: list[tuple[float, float]] = []
for range_kfs in scheduled_hooks.values():
for t_range, keyframe in range_kfs:
all_ranges.append(t_range)
# turn list of ranges into boundaries
boundaries_set = set(itertools.chain.from_iterable(all_ranges))
boundaries_set.add(0.0)
boundaries = sorted(boundaries_set)
real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
# with real ranges defined, give appropriate hooks w/ keyframes for each range
scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
for t_range in real_ranges:
hooks_schedule = []
for hook, val in scheduled_hooks.items():
keyframe = None
# check if is a keyframe that works for the current t_range
for stored_range, stored_kf in val:
# if stored start is less than current end, then fits - give it assigned keyframe
if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
keyframe = stored_kf
break
hooks_schedule.append((hook, keyframe))
scheduled_keyframes.append((t_range, hooks_schedule))
return scheduled_keyframes
def reset(self):
for hook in self.hooks:
hook.reset()
@staticmethod
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
actual: list[HookGroup] = []
for group in hooks_list:
if group is not None:
actual.append(group)
if len(actual) < require_count:
raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
# if no hooks, then return None
if len(actual) == 0:
return None
# if only 1 hook, just return itself without cloning
elif len(actual) == 1:
return actual[0]
final_hook: HookGroup = None
for hook in actual:
if final_hook is None:
final_hook = hook.clone()
else:
final_hook = final_hook.clone_and_combine(hook)
return final_hook
class HookKeyframe:
def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
self.strength = strength
# scheduling
self.start_percent = float(start_percent)
self.start_t = 999999999.9
self.guarantee_steps = guarantee_steps
def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
'''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
if self.start_t > max_sigma:
return 0
return self.guarantee_steps
def clone(self):
c = HookKeyframe(strength=self.strength,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
c.start_t = self.start_t
return c
class HookKeyframeGroup:
def __init__(self):
self.keyframes: list[HookKeyframe] = []
self._current_keyframe: HookKeyframe = None
self._current_used_steps = 0
self._current_index = 0
self._current_strength = None
self._curr_t = -1.
# properties shadow those of HookWeightsKeyframe
@property
def strength(self):
if self._current_keyframe is not None:
return self._current_keyframe.strength
return 1.0
def reset(self):
self._current_keyframe = None
self._current_used_steps = 0
self._current_index = 0
self._current_strength = None
self._curr_t = -1.
self._set_first_as_current()
def add(self, keyframe: HookKeyframe):
# add to end of list, then sort
self.keyframes.append(keyframe)
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
self._set_first_as_current()
def _set_first_as_current(self):
if len(self.keyframes) > 0:
self._current_keyframe = self.keyframes[0]
else:
self._current_keyframe = None
def has_guarantee_steps(self):
for kf in self.keyframes:
if kf.guarantee_steps > 0:
return True
return False
def has_index(self, index: int):
return index >= 0 and index < len(self.keyframes)
def is_empty(self):
return len(self.keyframes) == 0
def clone(self):
c = HookKeyframeGroup()
for keyframe in self.keyframes:
c.keyframes.append(keyframe.clone())
c._set_first_as_current()
return c
def initialize_timesteps(self, model: 'BaseModel'):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool:
if self.is_empty():
return False
if curr_t == self._curr_t:
return False
max_sigma = torch.max(transformer_options["sample_sigmas"])
prev_index = self._current_index
prev_strength = self._current_strength
# if met guaranteed steps, look for next keyframe in case need to switch
if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma):
# if has next index, loop through and see if need to switch
if self.has_index(self._current_index+1):
for i in range(self._current_index+1, len(self.keyframes)):
eval_c = self.keyframes[i]
# check if start_t is greater or equal to curr_t
# NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
if eval_c.start_t >= curr_t:
self._current_index = i
self._current_strength = eval_c.strength
self._current_keyframe = eval_c
self._current_used_steps = 0
# if guarantee_steps greater than zero, stop searching for other keyframes
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
break
# if eval_c is outside the percent range, stop looking further
else: break
# update steps current context is used
self._current_used_steps += 1
# update current timestep this was performed on
self._curr_t = curr_t
# return True if keyframe changed, False if no change
return prev_index != self._current_index and prev_strength != self._current_strength
class InterpolationMethod:
LINEAR = "linear"
EASE_IN = "ease_in"
EASE_OUT = "ease_out"
EASE_IN_OUT = "ease_in_out"
_LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
@classmethod
def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
diff = num_to - num_from
if method == cls.LINEAR:
weights = torch.linspace(num_from, num_to, length)
elif method == cls.EASE_IN:
index = torch.linspace(0, 1, length)
weights = diff * np.power(index, 2) + num_from
elif method == cls.EASE_OUT:
index = torch.linspace(0, 1, length)
weights = diff * (1 - np.power(1 - index, 2)) + num_from
elif method == cls.EASE_IN_OUT:
index = torch.linspace(0, 1, length)
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
else:
raise ValueError(f"Unrecognized interpolation method '{method}'.")
if reverse:
weights = weights.flip(dims=(0,))
return weights
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
if not objects:
return objects
elif len(objects) <= 1:
return [x for x in objects]
# now that we know we have to sort, do it following these rules:
# a) if objects have same value of attribute, maintain their relative order
# b) perform sorting of the groups of objects with same attributes
unique_attrs = {}
for o in objects:
val_attr = getattr(o, attr)
attr_list: list = unique_attrs.get(val_attr, list())
attr_list.append(o)
if val_attr not in unique_attrs:
unique_attrs[val_attr] = attr_list
# now that we have the unique attr values grouped together in relative order, sort them by key
sorted_attrs = dict(sorted(unique_attrs.items()))
# now flatten out the dict into a list to return
sorted_list = []
for object_list in sorted_attrs.values():
sorted_list.extend(object_list)
return sorted_list
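# Illustrative usage sketch (not part of the original module): get_sorted_list_via_attr
# groups objects by an attribute, sorts the groups by that value, and keeps the
# relative order of objects within each group (a stable grouped sort).
def _example_sorted_list_via_attr():
    class Item:
        def __init__(self, name, start_percent):
            self.name = name
            self.start_percent = start_percent
    items = [Item("b", 0.5), Item("a", 0.0), Item("c", 0.5)]
    ordered = get_sorted_list_via_attr(items, "start_percent")
    # -> ["a", "b", "c"]; "b" stays ahead of "c" because both share start_percent 0.5
    return [x.name for x in ordered]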
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
hook_group.add(hook)
hook.weights = lora
return hook_group
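# Illustrative usage sketch (not part of the original module): wrapping a LoRA state
# dict in a HookGroup. The file path is a placeholder; comfy.utils.load_torch_file is
# assumed to be available as elsewhere in ComfyUI.
def _example_create_hook_lora():
    import comfy.utils
    lora_sd = comfy.utils.load_torch_file("loras/example_lora.safetensors", safe_load=True)
    return create_hook_lora(lora_sd, strength_model=1.0, strength_clip=1.0)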
def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
hook_group.add(hook)
patches_model = None
patches_clip = None
if weights_model is not None:
patches_model = {}
for key in weights_model:
patches_model[key] = ("model_as_lora", (weights_model[key],))
if weights_clip is not None:
patches_clip = {}
for key in weights_clip:
patches_clip[key] = ("model_as_lora", (weights_clip[key],))
hook.weights = patches_model
hook.weights_clip = patches_clip
hook.need_weight_init = False
return hook_group
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
if model is None:
return None
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
if discard_model_sampling:
# do not include ANY model_sampling components of the model that is meant to act as a patch
for key in list(patches_model.keys()):
if key.startswith("model_sampling"):
patches_model.pop(key, None)
return patches_model
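# Illustrative usage sketch (not part of the original module): using the two helpers
# above to treat another loaded model's weights as a "model as lora" hook.
# `other_model` (a ModelPatcher) and `other_clip` (a CLIP with a .patcher) are placeholders.
def _example_create_hook_model_as_lora(other_model, other_clip):
    weights_model = get_patch_weights_from_model(other_model)
    weights_clip = get_patch_weights_from_model(other_clip.patcher) if other_clip is not None else None
    return create_hook_model_as_lora(weights_model, weights_clip,
                                     strength_model=1.0, strength_clip=1.0)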
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
strength_model: float, strength_clip: float):
key_map = {}
if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
hook_group = HookGroup()
hook = WeightHook()
hook_group.add(hook)
loaded: dict[str] = comfy.lora.load_lora(lora, key_map)
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)
else:
k = ()
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
else:
k1 = ()
new_clip = None
k = set(k)
k1 = set(k1)
for x in loaded:
if (x not in k) and (x not in k1):
logging.warning(f"NOT LOADED {x}")
return (new_modelpatcher, new_clip, hook_group)
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
hooks_key = 'hooks'
# if hooks exist in only one of the dicts, make sure they end up in c_dict
if hooks_key not in values:
return
if hooks_key not in c_dict:
hooks_value = values.get(hooks_key, None)
if hooks_value is not None:
c_dict[hooks_key] = hooks_value
return
# otherwise, need to combine with minimum duplication via cache
hooks_tuple = (c_dict[hooks_key], values[hooks_key])
cached_hooks = cache.get(hooks_tuple, None)
if cached_hooks is None:
new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
cache[hooks_tuple] = new_hooks
c_dict[hooks_key] = new_hooks
else:
c_dict[hooks_key] = cache[hooks_tuple]
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
c = []
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
if append_hooks and k == 'hooks':
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
else:
n[1][k] = values[k]
c.append(n)
return c
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
if hooks is None:
return cond
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
if timestep_range is None:
return cond
return conditioning_set_values(cond, {"start_percent": timestep_range[0],
"end_percent": timestep_range[1]})
def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
if mask is None:
return cond
set_area_to_bounds = False
if set_cond_area != 'default':
set_area_to_bounds = True
if len(mask.shape) < 3:
mask = mask.unsqueeze(0)
return conditioning_set_values(cond, {'mask': mask,
'set_area_to_bounds': set_area_to_bounds,
'mask_strength': strength})
def combine_conditioning(conds: list):
combined_conds = []
for cond in conds:
combined_conds.extend(cond)
return combined_conds
def combine_with_new_conds(conds: list, new_conds: list):
combined_conds = []
for c, new_c in zip(conds, new_conds):
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds
def set_conds_props(conds: list, strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
final_conds = []
for c in conds:
# first, apply lora_hook to conditioning, if provided
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
# next, apply mask to conditioning
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
# apply timesteps, if present
c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range)
# finally, store the fully updated conditioning
final_conds.append(c)
return final_conds
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
for c, masked_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
# next, apply mask to new conditioning, if provided
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
# apply timesteps, if present
masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, masked_c]))
return combined_conds
def set_default_conds_and_combine(conds: list, new_conds: list,
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
for c, new_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
# next, set the 'default' key on the cond so it can be identified as a default cond during sampling
new_c = conditioning_set_values(new_c, {'default': True})
# apply timesteps, if present
new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds
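# Illustrative usage sketch (not part of the original module): a typical way the helpers
# above fit together -- attach a hook group, a mask and a timestep range to existing
# conditioning. `positive_cond` is a placeholder for a standard ComfyUI conditioning list
# of [tensor, options-dict] pairs; `hook_group` and `mask` are placeholders as well.
def _example_set_conds_props(positive_cond, hook_group, mask):
    (conditioned,) = set_conds_props(conds=[positive_cond],
                                     strength=1.0,
                                     set_cond_area="default",
                                     mask=mask,
                                     hooks=hook_group,
                                     timesteps_range=(0.0, 0.5))
    return conditioned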

View File

@ -11,7 +11,6 @@ import numpy as np
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS. # Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80): def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1) vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d

View File

@ -72,8 +72,14 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
return sigma_down, sigma_up return sigma_down, sigma_up
def default_noise_sampler(x): def default_noise_sampler(x, seed=None):
return lambda sigma, sigma_next: torch.randn_like(x) if seed is not None:
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
else:
generator = None
return lambda sigma, sigma_next: torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator)
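As a side note (not part of the diff): with the seeded default_noise_sampler above, two samplers built with the same seed produce identical noise streams, which is what makes the ancestral samplers reproducible when a "seed" key is passed through extra_args. A minimal sketch, assuming the function as defined above:
import torch

x = torch.zeros(1, 4, 64, 64)                # placeholder latent batch
ns1 = default_noise_sampler(x, seed=1234)
ns2 = default_noise_sampler(x, seed=1234)
# sigma / sigma_next are ignored by this sampler, so any values work here
assert torch.equal(ns1(None, None), ns2(None, None))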
class BatchedBrownianTree: class BatchedBrownianTree:
@ -170,43 +176,50 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with Euler method steps.""" """Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args) denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None: if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
# Euler method if sigma_down == 0:
dt = sigma_down - sigmas[i] x = denoised
x = x + d * dt else:
if sigmas[i + 1] > 0: d = to_d(x, sigmas[i], denoised)
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up # Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x return x
@torch.no_grad() @torch.no_grad()
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None): def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps.""" """Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args) denoised = model(x, sigmas[i] * s_in, **extra_args)
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
if callback is not None: if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
# Euler method if sigmas[i + 1] == 0:
sigma_down_i_ratio = sigma_down / sigmas[i] x = denoised
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised else:
if sigmas[i + 1] > 0 and eta > 0: downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff sigma_down = sigmas[i + 1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i + 1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
# Euler method
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
if eta > 0:
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x return x
@torch.no_grad() @torch.no_grad()
@ -282,9 +295,13 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
@torch.no_grad() @torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with DPM-Solver second-order steps.""" """Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args) denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -308,6 +325,39 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x return x
@torch.no_grad()
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
if sigma_down == 0:
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver-2
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
dt_1 = sigma_mid - sigmas[i]
dt_2 = sigma_down - sigmas[i]
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
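For orientation (not part of the diff): in the rectified-flow ancestral steps above, renoise_coeff is chosen so that, after scaling the partially denoised x by alpha_ip1/alpha_down and adding fresh noise, the overall noise level lands back on sigmas[i + 1]. A quick numeric sanity check of that identity, with placeholder sigma values:
import torch

sigma_i, sigma_ip1, eta = torch.tensor(0.6), torch.tensor(0.4), 1.0
downstep_ratio = 1 + (sigma_ip1 / sigma_i - 1) * eta
sigma_down = sigma_ip1 * downstep_ratio
alpha_ip1, alpha_down = 1 - sigma_ip1, 1 - sigma_down
renoise_coeff = (sigma_ip1**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
# carried-over noise at level sigma_down plus fresh noise recombine to sigma_ip1
total = ((alpha_ip1 / alpha_down * sigma_down)**2 + renoise_coeff**2)**0.5
assert torch.allclose(total, sigma_ip1)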
def linear_multistep_coeff(order, t, i, j): def linear_multistep_coeff(order, t, i, j):
if order - 1 > i: if order - 1 > i:
@ -553,7 +603,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
"""Ancestral sampling with DPM-Solver++(2S) second-order steps.""" """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp() sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg() t_fn = lambda sigma: sigma.log().neg()
@ -587,7 +638,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps.""" """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1 sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log() lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
@ -1221,7 +1273,8 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps.""" """Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
temp = [0] temp = [0]
def post_cfg_function(args): def post_cfg_function(args):
@ -1247,7 +1300,8 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps.""" """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
temp = [0] temp = [0]
def post_cfg_function(args): def post_cfg_function(args):

View File

@ -3,6 +3,7 @@ import torch
class LatentFormat: class LatentFormat:
scale_factor = 1.0 scale_factor = 1.0
latent_channels = 4 latent_channels = 4
latent_dimensions = 2
latent_rgb_factors = None latent_rgb_factors = None
latent_rgb_factors_bias = None latent_rgb_factors_bias = None
taesd_decoder_name = None taesd_decoder_name = None
@ -143,6 +144,7 @@ class SD3(LatentFormat):
class StableAudio1(LatentFormat): class StableAudio1(LatentFormat):
latent_channels = 64 latent_channels = 64
latent_dimensions = 1
class Flux(SD3): class Flux(SD3):
latent_channels = 16 latent_channels = 16
@ -178,6 +180,7 @@ class Flux(SD3):
class Mochi(LatentFormat): class Mochi(LatentFormat):
latent_channels = 12 latent_channels = 12
latent_dimensions = 3
def __init__(self): def __init__(self):
self.scale_factor = 1.0 self.scale_factor = 1.0
@ -190,7 +193,21 @@ class Mochi(LatentFormat):
0.9294154431013696, 1.3720942357788521, 0.881393668867029, 0.9294154431013696, 1.3720942357788521, 0.881393668867029,
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1) 0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
self.latent_rgb_factors = None #TODO self.latent_rgb_factors =[
[-0.0069, -0.0045, 0.0018],
[ 0.0154, -0.0692, -0.0274],
[ 0.0333, 0.0019, 0.0206],
[-0.1390, 0.0628, 0.1678],
[-0.0725, 0.0134, -0.1898],
[ 0.0074, -0.0270, -0.0209],
[-0.0176, -0.0277, -0.0221],
[ 0.5294, 0.5204, 0.3852],
[-0.0326, -0.0446, -0.0143],
[-0.0659, 0.0153, -0.0153],
[ 0.0185, -0.0217, 0.0014],
[-0.0396, -0.0495, -0.0281]
]
self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
self.taesd_decoder_name = None #TODO self.taesd_decoder_name = None #TODO
def process_in(self, latent): def process_in(self, latent):
@ -202,3 +219,166 @@ class Mochi(LatentFormat):
latents_mean = self.latents_mean.to(latent.device, latent.dtype) latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype) latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean return latent * latents_std / self.scale_factor + latents_mean
class LTXV(LatentFormat):
latent_channels = 128
latent_dimensions = 3
def __init__(self):
self.latent_rgb_factors = [
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
[ 8.6031e-02, 6.5813e-02, 9.5409e-04],
[-1.2576e-02, -7.5734e-03, -4.0528e-03],
[ 9.4063e-03, -2.1688e-03, 2.6093e-03],
[ 3.7636e-03, 1.2765e-02, 9.1548e-03],
[ 2.1024e-02, -5.2973e-03, 3.4373e-03],
[-8.8896e-03, -1.9703e-02, -1.8761e-02],
[-1.3160e-02, -1.0523e-02, 1.9709e-03],
[-1.5152e-03, -6.9891e-03, -7.5810e-03],
[-1.7247e-03, 4.6560e-04, -3.3839e-03],
[ 1.3617e-02, 4.7077e-03, -2.0045e-03],
[ 1.0256e-02, 7.7318e-03, 1.3948e-02],
[-1.6108e-02, -6.2151e-03, 1.1561e-03],
[ 7.3407e-03, 1.5628e-02, 4.4865e-04],
[ 9.5357e-04, -2.9518e-03, -1.4760e-02],
[ 1.9143e-02, 1.0868e-02, 1.2264e-02],
[ 4.4575e-03, 3.6682e-05, -6.8508e-03],
[-4.5681e-04, 3.2570e-03, 7.7929e-03],
[ 3.3902e-02, 3.3405e-02, 3.7454e-02],
[-2.3001e-02, -2.4877e-03, -3.1033e-03],
[ 5.0265e-02, 3.8841e-02, 3.3539e-02],
[-4.1018e-03, -1.1095e-03, 1.5859e-03],
[-1.2689e-01, -1.3107e-01, -2.1005e-01],
[ 2.6276e-02, 1.4189e-02, -3.5963e-03],
[-4.8679e-03, 8.8486e-03, 7.8029e-03],
[-1.6610e-03, -4.8597e-03, -5.2060e-03],
[-2.1010e-03, 2.3610e-03, 9.3796e-03],
[-2.2482e-02, -2.1305e-02, -1.5087e-02],
[-1.5753e-02, -1.0646e-02, -6.5083e-03],
[-4.6975e-03, 5.0288e-03, -6.7390e-03],
[ 1.1951e-02, 2.0712e-02, 1.6191e-02],
[-6.3704e-03, -8.4827e-03, -9.5483e-03],
[ 7.2610e-03, -9.9326e-03, -2.2978e-02],
[-9.1904e-04, 6.2882e-03, 9.5720e-03],
[-3.7178e-02, -3.7123e-02, -5.6713e-02],
[-1.3373e-01, -1.0720e-01, -5.3801e-02],
[-5.3702e-03, 8.1256e-03, 8.8397e-03],
[-1.5247e-01, -2.1437e-01, -2.1843e-01],
[ 3.1441e-02, 7.0335e-03, -9.7541e-03],
[ 2.1528e-03, -8.9817e-03, -2.1023e-02],
[ 3.8461e-03, -5.8957e-03, -1.5014e-02],
[-4.3470e-03, -1.2940e-02, -1.5972e-02],
[-5.4781e-03, -1.0842e-02, -3.0204e-03],
[-6.5347e-03, 3.0806e-03, -1.0163e-02],
[-5.0414e-03, -7.1503e-03, -8.9686e-04],
[-8.5851e-03, -2.4351e-03, 1.0674e-03],
[-9.0016e-03, -9.6493e-03, 1.5692e-03],
[ 5.0914e-03, 1.2099e-02, 1.9968e-02],
[ 1.3758e-02, 1.1669e-02, 8.1958e-03],
[-1.0518e-02, -1.1575e-02, -4.1307e-03],
[-2.8410e-02, -3.1266e-02, -2.2149e-02],
[ 2.9336e-03, 3.6511e-02, 1.8717e-02],
[-1.6703e-02, -1.6696e-02, -4.4529e-03],
[ 4.8818e-02, 4.0063e-02, 8.7410e-03],
[-1.5066e-02, -5.7328e-04, 2.9785e-03],
[-1.7613e-02, -8.1034e-03, 1.3086e-02],
[-9.2633e-03, 1.0803e-02, -6.3489e-03],
[ 3.0851e-03, 4.7750e-04, 1.2347e-02],
[-2.2785e-02, -2.3043e-02, -2.6005e-02],
[-2.4787e-02, -1.5389e-02, -2.2104e-02],
[-2.3572e-02, 1.0544e-03, 1.2361e-02],
[-7.8915e-03, -1.2271e-03, -6.0968e-03],
[-1.1478e-02, -1.2543e-03, 6.2679e-03],
[-5.4229e-02, 2.6644e-02, 6.3394e-03],
[ 4.4216e-03, -7.3338e-03, -1.0464e-02],
[-4.5013e-03, 1.6082e-03, 1.4420e-02],
[ 1.3673e-02, 8.8877e-03, 4.1253e-03],
[-1.0145e-02, 9.0072e-03, 1.5695e-02],
[-5.6234e-03, 1.1847e-03, 8.1261e-03],
[-3.7171e-03, -5.3538e-03, 1.2590e-03],
[ 2.9476e-02, 2.1424e-02, 3.0424e-02],
[-3.4925e-02, -2.4340e-02, -2.5316e-02],
[-3.4127e-02, -2.2406e-02, -1.0589e-02],
[-1.7342e-02, -1.3249e-02, -1.0719e-02],
[-2.1478e-03, -8.6051e-03, -2.9878e-03],
[ 1.2089e-03, -4.2391e-03, -6.8569e-03],
[ 9.0411e-04, -6.6886e-03, -6.7547e-05],
[ 1.6048e-02, -1.0057e-02, -2.8929e-02],
[ 1.2290e-03, 1.0163e-02, 1.8861e-02],
[ 1.7264e-02, 2.7257e-04, 1.3785e-02],
[-1.3482e-02, -3.6427e-03, 6.7481e-04],
[ 4.6782e-03, -5.2423e-03, 2.4467e-03],
[-5.9113e-03, -6.2244e-03, -1.8162e-03],
[ 1.5496e-02, 1.4582e-02, 1.9514e-03],
[ 7.4958e-03, 1.5886e-03, -8.2305e-03],
[ 1.9086e-02, 1.6360e-03, -3.9674e-03],
[-5.7021e-03, -2.7307e-03, -4.1066e-03],
[ 1.7450e-03, 1.4602e-02, 2.5794e-02],
[-8.2788e-04, 2.2902e-03, 4.5161e-03],
[ 1.1632e-02, 8.9193e-03, -7.2813e-03],
[ 7.5721e-03, 2.6784e-03, 1.1393e-02],
[ 5.1939e-03, 3.6903e-03, 1.4049e-02],
[-1.8383e-02, -2.2529e-02, -2.4477e-02],
[ 5.8842e-04, -5.7874e-03, -1.4770e-02],
[-1.6125e-02, -8.6101e-03, -1.4533e-02],
[ 2.0540e-02, 2.0729e-02, 6.4338e-03],
[ 3.3587e-03, -1.1226e-02, -1.6444e-02],
[-1.4742e-03, -1.0489e-02, 1.7097e-03],
[ 2.8130e-02, 2.3546e-02, 3.2791e-02],
[-1.8532e-02, -1.2842e-02, -8.7756e-03],
[-8.0533e-03, -1.0771e-02, -1.7536e-02],
[-3.9009e-03, 1.6150e-02, 3.3359e-02],
[-7.4554e-03, -1.4154e-02, -6.1910e-03],
[ 3.4734e-03, -1.1370e-02, -1.0581e-02],
[ 1.1476e-02, 3.9281e-03, 2.8231e-03],
[ 7.1639e-03, -1.4741e-03, -3.8066e-03],
[ 2.2250e-03, -8.7552e-03, -9.5719e-03],
[ 2.4146e-02, 2.1696e-02, 2.8056e-02],
[-5.4365e-03, -2.4291e-02, -1.7802e-02],
[ 7.4263e-03, 1.0510e-02, 1.2705e-02],
[ 6.2669e-03, 6.2658e-03, 1.9211e-02],
[ 1.6378e-02, 9.4933e-03, 6.6971e-03],
[ 1.7173e-02, 2.3601e-02, 2.3296e-02],
[-1.4568e-02, -9.8279e-03, -1.1556e-02],
[ 1.4431e-02, 1.4430e-02, 6.6362e-03],
[-6.8230e-03, 1.8863e-02, 1.4555e-02],
[ 6.1156e-03, 3.4700e-03, -2.6662e-03],
[-2.6983e-03, -5.9402e-03, -9.2276e-03],
[ 1.0235e-02, 7.4173e-03, -7.6243e-03],
[-1.3255e-02, 1.9322e-02, -9.2153e-04],
[ 2.4222e-03, -4.8039e-03, -1.5759e-02],
[ 2.6244e-02, 2.5951e-02, 2.0249e-02],
[ 1.5711e-02, 1.8498e-02, 2.7407e-03],
[-2.1714e-03, 4.7214e-03, -2.2443e-02],
[-7.4747e-03, 7.4166e-03, 1.4430e-02],
[-8.3906e-03, -7.9776e-03, 9.7927e-03],
[ 3.8321e-02, 9.6622e-03, -1.9268e-02],
[-1.4605e-02, -6.7032e-03, 3.9675e-03]
]
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
class HunyuanVideo(LatentFormat):
latent_channels = 16
latent_dimensions = 3
scale_factor = 0.476986
latent_rgb_factors = [
[-0.0395, -0.0331, 0.0445],
[ 0.0696, 0.0795, 0.0518],
[ 0.0135, -0.0945, -0.0282],
[ 0.0108, -0.0250, -0.0765],
[-0.0209, 0.0032, 0.0224],
[-0.0804, -0.0254, -0.0639],
[-0.0991, 0.0271, -0.0669],
[-0.0646, -0.0422, -0.0400],
[-0.0696, -0.0595, -0.0894],
[-0.0799, -0.0208, -0.0375],
[ 0.1166, 0.1627, 0.0962],
[ 0.1165, 0.0432, 0.0407],
[-0.2315, -0.1920, -0.1355],
[-0.0270, 0.0401, -0.0821],
[-0.0616, -0.0997, -0.0727],
[ 0.0249, -0.0469, -0.1703]
]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
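As a hedged aside (not part of the diff): latent_rgb_factors and latent_rgb_factors_bias are typically consumed as a simple linear projection from latent channels to RGB for previews. A sketch of that idea for a single [C, H, W] latent frame; the helper name is invented for illustration:
import torch

def latent_frame_to_rgb(latent_chw, factors, bias):
    # factors: list shaped [C][3], bias: list of [3]; returns an RGB image [3, H, W]
    f = torch.tensor(factors, dtype=latent_chw.dtype, device=latent_chw.device)
    b = torch.tensor(bias, dtype=latent_chw.dtype, device=latent_chw.device)
    return torch.einsum("chw,cr->rhw", latent_chw, f) + b[:, None, None]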

View File

@ -2,7 +2,7 @@
import torch import torch
from torch import nn from torch import nn
from typing import Literal, Dict, Any from typing import Literal
import math import math
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
@ -97,7 +97,7 @@ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False,
raise ValueError(f"Unknown activation {activation}") raise ValueError(f"Unknown activation {activation}")
if antialias: if antialias:
act = Activation1d(act) act = Activation1d(act) # noqa: F821 Activation1d is not defined
return act return act

View File

@ -158,7 +158,6 @@ class RotaryEmbedding(nn.Module):
def forward(self, t): def forward(self, t):
# device = self.inv_freq.device # device = self.inv_freq.device
device = t.device device = t.device
dtype = t.dtype
# t = t.to(torch.float32) # t = t.to(torch.float32)
@ -170,7 +169,7 @@ class RotaryEmbedding(nn.Module):
if self.scale is None: if self.scale is None:
return freqs, 1. return freqs, 1.
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base # noqa: F821 seq_len is not defined
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1') scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1) scale = torch.cat((scale, scale), dim = -1)
@ -229,9 +228,9 @@ class FeedForward(nn.Module):
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations) linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
else: else:
linear_in = nn.Sequential( linear_in = nn.Sequential(
Rearrange('b n d -> b d n') if use_conv else nn.Identity(), rearrange('b n d -> b d n') if use_conv else nn.Identity(),
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device), operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
Rearrange('b n d -> b d n') if use_conv else nn.Identity(), rearrange('b n d -> b d n') if use_conv else nn.Identity(),
activation activation
) )
@ -246,9 +245,9 @@ class FeedForward(nn.Module):
self.ff = nn.Sequential( self.ff = nn.Sequential(
linear_in, linear_in,
Rearrange('b d n -> b n d') if use_conv else nn.Identity(), rearrange('b d n -> b n d') if use_conv else nn.Identity(),
linear_out, linear_out,
Rearrange('b n d -> b d n') if use_conv else nn.Identity(), rearrange('b n d -> b d n') if use_conv else nn.Identity(),
) )
def forward(self, x): def forward(self, x):
@ -346,18 +345,13 @@ class Attention(nn.Module):
# determine masking # determine masking
masks = [] masks = []
final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
if input_mask is not None: if input_mask is not None:
input_mask = rearrange(input_mask, 'b j -> b 1 1 j') input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
masks.append(~input_mask) masks.append(~input_mask)
# Other masks will be added here later # Other masks will be added here later
n = q.shape[-2]
if len(masks) > 0:
final_attn_mask = ~or_reduce(masks)
n, device = q.shape[-2], q.device
causal = self.causal if causal is None else causal causal = self.causal if causal is None else causal
@ -612,7 +606,9 @@ class ContinuousTransformer(nn.Module):
return_info = False, return_info = False,
**kwargs **kwargs
): ):
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]
info = { info = {
"hidden_states": [], "hidden_states": [],
@ -643,9 +639,19 @@ class ContinuousTransformer(nn.Module):
if self.use_sinusoidal_emb or self.use_abs_pos_emb: if self.use_sinusoidal_emb or self.use_abs_pos_emb:
x = x + self.pos_emb(x) x = x + self.pos_emb(x)
blocks_replace = patches_replace.get("dit", {})
# Iterate over the transformer layers # Iterate over the transformer layers
for layer in self.layers: for i, layer in enumerate(self.layers):
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info: if return_info:
@ -874,7 +880,6 @@ class AudioDiffusionTransformer(nn.Module):
mask=None, mask=None,
return_info=False, return_info=False,
control=None, control=None,
transformer_options={},
**kwargs): **kwargs):
return self._forward( return self._forward(
x, x,

View File

@ -2,8 +2,8 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor, einsum from torch import Tensor
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union from typing import List, Union
from einops import rearrange from einops import rearrange
import math import math
import comfy.ops import comfy.ops

View File

@ -147,7 +147,6 @@ class DoubleAttention(nn.Module):
bsz, seqlen1, _ = c.shape bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape bsz, seqlen2, _ = x.shape
seqlen = seqlen1 + seqlen2
cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c) cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim) cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
@ -382,7 +381,6 @@ class MMDiT(nn.Module):
pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1) pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
self.positional_encoding.data = pe_new.unsqueeze(0).contiguous() self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
self.h_max, self.w_max = target_dim self.h_max, self.w_max = target_dim
print("PE extended to", target_dim)
def pe_selection_index_based_on_dim(self, h, w): def pe_selection_index_based_on_dim(self, h, w):
h_p, w_p = h // self.patch_size, w // self.patch_size h_p, w_p = h // self.patch_size, w // self.patch_size
@ -437,7 +435,8 @@ class MMDiT(nn.Module):
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w] pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1]) return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
def forward(self, x, timestep, context, **kwargs): def forward(self, x, timestep, context, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
# patchify x, add PE # patchify x, add PE
b, c, h, w = x.shape b, c, h, w = x.shape
@ -458,15 +457,36 @@ class MMDiT(nn.Module):
global_cond = self.t_embedder(t, x.dtype) # B, D global_cond = self.t_embedder(t, x.dtype) # B, D
blocks_replace = patches_replace.get("dit", {})
if len(self.double_layers) > 0: if len(self.double_layers) > 0:
for layer in self.double_layers: for i, layer in enumerate(self.double_layers):
c, x = layer(c, x, global_cond, **kwargs) if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
args["vec"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
c, x = layer(c, x, global_cond, **kwargs)
if len(self.single_layers) > 0: if len(self.single_layers) > 0:
c_len = c.size(1) c_len = c.size(1)
cx = torch.cat([c, x], dim=1) cx = torch.cat([c, x], dim=1)
for layer in self.single_layers: for i, layer in enumerate(self.single_layers):
cx = layer(cx, global_cond, **kwargs) if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], args["vec"])
return out
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
cx = out["img"]
else:
cx = layer(cx, global_cond, **kwargs)
x = cx[:, c_len:] x = cx[:, c_len:]

View File

@ -16,7 +16,6 @@
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
import torch
import torchvision import torchvision
from torch import nn from torch import nn
from .common import LayerNorm2d_op from .common import LayerNorm2d_op

View File

@ -2,11 +2,14 @@ import torch
import comfy.ops import comfy.ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting(): if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
padding_mode = "reflect" padding_mode = "reflect"
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1] pad = ()
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode) for i in range(img.ndim - 2):
pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad
return torch.nn.functional.pad(img, pad, mode=padding_mode)
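Worth noting (not part of the diff): the rewritten pad_to_patch_size builds the pad tuple right-to-left over all trailing dimensions, so it now also handles 5-D video latents. A small worked example under that reading, with placeholder shapes:
import torch

img = torch.zeros(1, 16, 3, 30, 45)               # [B, C, T, H, W]
padded = pad_to_patch_size(img, patch_size=(1, 2, 2))
# T stays 3 (patch 1), H stays 30 (already divisible), W pads 45 -> 46
assert padded.shape == (1, 16, 3, 30, 46)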
try: try:
rms_norm_torch = torch.nn.functional.rms_norm rms_norm_torch = torch.nn.functional.rms_norm

View File

@ -6,9 +6,7 @@ import math
from torch import Tensor, nn from torch import Tensor, nn
from einops import rearrange, repeat from einops import rearrange, repeat
from .layers import (DoubleStreamBlock, EmbedND, LastLayer, from .layers import (timestep_embedding)
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from .model import Flux from .model import Flux
import comfy.ldm.common_dit import comfy.ldm.common_dit

View File

@ -114,7 +114,7 @@ class Modulation(nn.Module):
class DoubleStreamBlock(nn.Module): class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None): def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio) mlp_hidden_dim = int(hidden_size * mlp_ratio)
@ -141,8 +141,9 @@ class DoubleStreamBlock(nn.Module):
nn.GELU(approximate="tanh"), nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
) )
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
img_mod1, img_mod2 = self.img_mod(vec) img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec)
@ -160,12 +161,22 @@ class DoubleStreamBlock(nn.Module):
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention if self.flipped_img_txt:
attn = attention(torch.cat((txt_q, img_q), dim=2), # run actual attention
torch.cat((txt_k, img_k), dim=2), attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((txt_v, img_v), dim=2), pe=pe) torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img blocks # calculate the img blocks
img = img + img_mod1.gate * self.img_attn.proj(img_attn) img = img + img_mod1.gate * self.img_attn.proj(img_attn)
@ -217,7 +228,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh") self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
mod, _ = self.modulation(vec) mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@ -226,7 +237,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v) q, k = self.norm(q, k, v)
# compute attention # compute attention
attn = attention(q, k, v, pe=pe) attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer # compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output x += mod.gate * output

View File

@ -1,14 +1,15 @@
import torch import torch
from einops import rearrange from einops import rearrange
from torch import Tensor from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q, k = apply_rope(q, k, pe) q, k = apply_rope(q, k, pe)
heads = q.shape[1] heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True) x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
return x return x
@ -33,3 +34,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@ -4,6 +4,8 @@ from dataclasses import dataclass
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
from .layers import ( from .layers import (
DoubleStreamBlock, DoubleStreamBlock,
@ -14,12 +16,10 @@ from .layers import (
timestep_embedding, timestep_embedding,
) )
from einops import rearrange, repeat
import comfy.ldm.common_dit
@dataclass @dataclass
class FluxParams: class FluxParams:
in_channels: int in_channels: int
out_channels: int
vec_in_dim: int vec_in_dim: int
context_in_dim: int context_in_dim: int
hidden_size: int hidden_size: int
@ -29,6 +29,7 @@ class FluxParams:
depth_single_blocks: int depth_single_blocks: int
axes_dim: list axes_dim: list
theta: int theta: int
patch_size: int
qkv_bias: bool qkv_bias: bool
guidance_embed: bool guidance_embed: bool
@ -43,8 +44,9 @@ class Flux(nn.Module):
self.dtype = dtype self.dtype = dtype
params = FluxParams(**kwargs) params = FluxParams(**kwargs)
self.params = params self.params = params
self.in_channels = params.in_channels * 2 * 2 self.patch_size = params.patch_size
self.out_channels = self.in_channels self.in_channels = params.in_channels * params.patch_size * params.patch_size
self.out_channels = params.out_channels * params.patch_size * params.patch_size
if params.hidden_size % params.num_heads != 0: if params.hidden_size % params.num_heads != 0:
raise ValueError( raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@ -95,8 +97,11 @@ class Flux(nn.Module):
timesteps: Tensor, timesteps: Tensor,
y: Tensor, y: Tensor,
guidance: Tensor = None, guidance: Tensor = None,
control=None, control = None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor: ) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3: if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.") raise ValueError("Input img and txt tensors must have 3 dimensions.")
@ -114,8 +119,32 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1) ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids) pe = self.pe_embedder(ids)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks): for i, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe) if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
if control is not None: # Controlnet if control is not None: # Controlnet
control_i = control.get("input") control_i = control.get("input")
@ -127,7 +156,23 @@ class Flux(nn.Module):
img = torch.cat((txt, img), 1) img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks): for i, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe) if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet if control is not None: # Controlnet
control_o = control.get("output") control_o = control.get("output")
@ -141,9 +186,9 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img return img
def forward(self, x, timestep, context, y, guidance, control=None, **kwargs): def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape bs, c, h, w = x.shape
patch_size = 2 patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
@ -151,10 +196,10 @@ class Flux(nn.Module):
h_len = ((h + (patch_size // 2)) // patch_size) h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size) w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
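For orientation (not part of the diff): the blocks_replace lookups added above mean a callable registered under transformer_options["patches_replace"]["dit"] can wrap or replace an individual block. A hedged sketch of what such a patch might look like for one Flux double block; the block index and the scaling are placeholders:
def my_double_block_patch(args, extra):
    # args carries "img", "txt", "vec", "pe" (and "attn_mask" for Flux);
    # extra["original_block"] runs the untouched block on those tensors.
    out = extra["original_block"](args)
    out["img"] = out["img"] * 1.0   # placeholder: modify the image stream here
    return out

# registration sketch:
# transformer_options["patches_replace"]["dit"][("double_block", 3)] = my_double_block_patch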

comfy/ldm/flux/redux.py (new file, 25 lines added)
View File

@ -0,0 +1,25 @@
import torch
import comfy.ops
ops = comfy.ops.manual_cast
class ReduxImageEncoder(torch.nn.Module):
def __init__(
self,
redux_dim: int = 1152,
txt_in_features: int = 4096,
device=None,
dtype=None,
) -> None:
super().__init__()
self.redux_dim = redux_dim
self.device = device
self.dtype = dtype
self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
def forward(self, sigclip_embeds) -> torch.Tensor:
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
return projected_x
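A brief, hedged usage sketch for the new ReduxImageEncoder; the token count of 729 is an assumption (27x27 SigLIP-style patch embeddings), and in real use the weights would come from a checkpoint rather than being left uninitialized:
import torch

encoder = ReduxImageEncoder(dtype=torch.float32)
sigclip_embeds = torch.randn(1, 729, 1152)   # [batch, tokens, redux_dim]
tokens = encoder(sigclip_embeds)             # -> [1, 729, 4096], sized to sit alongside T5 tokens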

View File

@ -461,8 +461,6 @@ class AsymmDiTJoint(nn.Module):
pH, pW = H // self.patch_size, W // self.patch_size pH, pW = H // self.patch_size, W // self.patch_size
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2 x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
assert x.ndim == 3 assert x.ndim == 3
B = x.size(0)
pH, pW = H // self.patch_size, W // self.patch_size pH, pW = H // self.patch_size, W // self.patch_size
N = T * pH * pW N = T * pH * pW
@ -494,8 +492,9 @@ class AsymmDiTJoint(nn.Module):
packed_indices: Dict[str, torch.Tensor] = None, packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None, rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None, rope_sin: torch.Tensor = None,
control=None, **kwargs control=None, transformer_options={}, **kwargs
): ):
patches_replace = transformer_options.get("patches_replace", {})
y_feat = context y_feat = context
y_mask = attention_mask y_mask = attention_mask
sigma = timestep sigma = timestep
@ -515,15 +514,32 @@ class AsymmDiTJoint(nn.Module):
) )
del y_mask del y_mask
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
x, y_feat = block( if ("double_block", i) in blocks_replace:
x, def block_wrap(args):
c, out = {}
y_feat, out["img"], out["txt"] = block(
rope_cos=rope_cos, args["img"],
rope_sin=rope_sin, args["vec"],
crop_y=num_tokens, args["txt"],
) # (B, M, D), (B, L, D) rope_cos=args["rope_cos"],
rope_sin=args["rope_sin"],
crop_y=args["num_tokens"]
)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
y_feat = out["txt"]
x = out["img"]
else:
x, y_feat = block(
x,
c,
y_feat,
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features. del y_feat # Final layers don't use dense text features.
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels) x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)

View File

@ -1,7 +1,7 @@
#original code from https://github.com/genmoai/models under apache 2.0 license #original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI #adapted to ComfyUI
from typing import Optional, Tuple from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn

View File

@ -1,13 +1,17 @@
#original code from https://github.com/genmoai/models under apache 2.0 license #original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI #adapted to ComfyUI
from typing import Callable, List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from functools import partial
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
@ -158,8 +162,10 @@ class ResBlock(nn.Module):
*, *,
affine: bool = True, affine: bool = True,
attn_block: Optional[nn.Module] = None, attn_block: Optional[nn.Module] = None,
padding_mode: str = "replicate",
causal: bool = True, causal: bool = True,
prune_bottleneck: bool = False,
padding_mode: str,
bias: bool = True,
): ):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
@ -170,23 +176,23 @@ class ResBlock(nn.Module):
nn.SiLU(inplace=True), nn.SiLU(inplace=True),
PConv3d( PConv3d(
in_channels=channels, in_channels=channels,
out_channels=channels, out_channels=channels // 2 if prune_bottleneck else channels,
kernel_size=(3, 3, 3), kernel_size=(3, 3, 3),
stride=(1, 1, 1), stride=(1, 1, 1),
padding_mode=padding_mode, padding_mode=padding_mode,
bias=True, bias=bias,
# causal=causal, causal=causal,
), ),
norm_fn(channels, affine=affine), norm_fn(channels, affine=affine),
nn.SiLU(inplace=True), nn.SiLU(inplace=True),
PConv3d( PConv3d(
in_channels=channels, in_channels=channels // 2 if prune_bottleneck else channels,
out_channels=channels, out_channels=channels,
kernel_size=(3, 3, 3), kernel_size=(3, 3, 3),
stride=(1, 1, 1), stride=(1, 1, 1),
padding_mode=padding_mode, padding_mode=padding_mode,
bias=True, bias=bias,
# causal=causal, causal=causal,
), ),
) )
@ -206,6 +212,81 @@ class ResBlock(nn.Module):
return self.attn_block(x) return self.attn_block(x)
class Attention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
out_bias: bool = True,
qk_norm: bool = True,
) -> None:
super().__init__()
self.head_dim = head_dim
self.num_heads = dim // head_dim
self.qk_norm = qk_norm
self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
self.out = nn.Linear(dim, dim, bias=out_bias)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""Compute temporal self-attention.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
chunk_size: Chunk size for large tensors.
Returns:
x: Output tensor. Shape: [B, C, T, H, W].
"""
B, _, T, H, W = x.shape
if T == 1:
# No attention for single frame.
x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
qkv = self.qkv(x)
_, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
x = self.out(x)
return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
# 1D temporal attention.
x = rearrange(x, "B C t h w -> (B h w) t C")
qkv = self.qkv(x)
# Input: qkv with shape [B, t, 3 * num_heads * head_dim]
# Output: x with shape [B, num_heads, t, head_dim]
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)
if self.qk_norm:
q = F.normalize(q, p=2, dim=-1)
k = F.normalize(k, p=2, dim=-1)
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
assert x.size(0) == q.size(0)
x = self.out(x)
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
return x
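A minimal shape sketch of the temporal attention above, using torch's built-in scaled_dot_product_attention as a stand-in for optimized_attention; the tensor sizes are illustrative, not taken from any model config:

import torch
import torch.nn.functional as F
from einops import rearrange

B, C, T, H, W = 1, 64, 8, 4, 4
x = torch.randn(B, C, T, H, W)
seq = rearrange(x, "b c t h w -> (b h w) t c")           # fold spatial positions into the batch
q = k = v = seq.view(seq.shape[0], T, 2, 32).transpose(1, 2)   # 2 heads of dim 32: [(B h w), heads, T, head_dim]
out = F.scaled_dot_product_attention(q, k, v)            # attention runs over the T axis only
out = rearrange(out.transpose(1, 2).reshape(B * H * W, T, C), "(b h w) t c -> b c t h w", b=B, h=H, w=W)
assert out.shape == x.shape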
class AttentionBlock(nn.Module):
def __init__(
self,
dim: int,
**attn_kwargs,
) -> None:
super().__init__()
self.norm = norm_fn(dim)
self.attn = Attention(dim, **attn_kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.attn(self.norm(x))
class CausalUpsampleBlock(nn.Module): class CausalUpsampleBlock(nn.Module):
def __init__( def __init__(
self, self,
@ -244,14 +325,9 @@ class CausalUpsampleBlock(nn.Module):
return x return x
def block_fn(channels, *, has_attention: bool = False, **block_kwargs): def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
assert has_attention is False #NOTE: if this is ever true add back the attention code. attn_block = AttentionBlock(channels) if has_attention else None
return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
attn_block = None #AttentionBlock(channels) if has_attention else None
return ResBlock(
channels, affine=True, attn_block=attn_block, **block_kwargs
)
class DownsampleBlock(nn.Module): class DownsampleBlock(nn.Module):
@ -288,8 +364,9 @@ class DownsampleBlock(nn.Module):
out_channels=out_channels, out_channels=out_channels,
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction), kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
stride=(temporal_reduction, spatial_reduction, spatial_reduction), stride=(temporal_reduction, spatial_reduction, spatial_reduction),
# First layer in each block always uses replicate padding
padding_mode="replicate", padding_mode="replicate",
bias=True, bias=block_kwargs["bias"],
) )
) )
@ -382,7 +459,7 @@ class Decoder(nn.Module):
blocks = [] blocks = []
first_block = [ first_block = [
nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1)) ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
] # Input layer. ] # Input layer.
# First set of blocks preserve channel count. # First set of blocks preserve channel count.
for _ in range(num_res_blocks[-1]): for _ in range(num_res_blocks[-1]):
@ -452,11 +529,165 @@ class Decoder(nn.Module):
return self.output_proj(x).contiguous() return self.output_proj(x).contiguous()
class LatentDistribution:
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
"""Initialize latent distribution.
Args:
mean: Mean of the distribution. Shape: [B, C, T, H, W].
logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
"""
assert mean.shape == logvar.shape
self.mean = mean
self.logvar = logvar
def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
if temperature == 0.0:
return self.mean
if noise is None:
noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
else:
assert noise.device == self.mean.device
noise = noise.to(self.mean.dtype)
if temperature != 1.0:
raise NotImplementedError(f"Temperature {temperature} is not supported.")
# Just Gaussian sample with no scaling of variance.
return noise * torch.exp(self.logvar * 0.5) + self.mean
def mode(self):
return self.mean
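A small usage sketch of LatentDistribution: sample() is the standard reparameterized Gaussian draw mean + exp(0.5 * logvar) * noise, and mode() simply returns the mean (shapes here are illustrative):

import torch

mean = torch.zeros(1, 12, 4, 8, 8)
logvar = torch.zeros(1, 12, 4, 8, 8)                     # unit variance
dist = LatentDistribution(mean, logvar)
z = dist.sample(generator=torch.Generator().manual_seed(0))
assert z.shape == mean.shape
assert torch.equal(dist.mode(), mean)                    # deterministic "sample"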
class Encoder(nn.Module):
def __init__(
self,
*,
in_channels: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
latent_dim: int,
temporal_reductions: List[int],
spatial_reductions: List[int],
prune_bottlenecks: List[bool],
has_attentions: List[bool],
affine: bool = True,
bias: bool = True,
input_is_conv_1x1: bool = False,
padding_mode: str,
):
super().__init__()
self.temporal_reductions = temporal_reductions
self.spatial_reductions = spatial_reductions
self.base_channels = base_channels
self.channel_multipliers = channel_multipliers
self.num_res_blocks = num_res_blocks
self.latent_dim = latent_dim
self.fourier_features = FourierFeatures()
ch = [mult * base_channels for mult in channel_multipliers]
num_down_blocks = len(ch) - 1
assert len(num_res_blocks) == num_down_blocks + 2
layers = (
[ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
if not input_is_conv_1x1
else [Conv1x1(in_channels, ch[0])]
)
assert len(prune_bottlenecks) == num_down_blocks + 2
assert len(has_attentions) == num_down_blocks + 2
block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
for _ in range(num_res_blocks[0]):
layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
prune_bottlenecks = prune_bottlenecks[1:]
has_attentions = has_attentions[1:]
assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
for i in range(num_down_blocks):
layer = DownsampleBlock(
ch[i],
ch[i + 1],
num_res_blocks=num_res_blocks[i + 1],
temporal_reduction=temporal_reductions[i],
spatial_reduction=spatial_reductions[i],
prune_bottleneck=prune_bottlenecks[i],
has_attention=has_attentions[i],
affine=affine,
bias=bias,
padding_mode=padding_mode,
)
layers.append(layer)
# Additional blocks.
for _ in range(num_res_blocks[-1]):
layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
self.layers = nn.Sequential(*layers)
# Output layers.
self.output_norm = norm_fn(ch[-1])
self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
@property
def temporal_downsample(self):
return math.prod(self.temporal_reductions)
@property
def spatial_downsample(self):
return math.prod(self.spatial_reductions)
def forward(self, x) -> LatentDistribution:
"""Forward pass.
Args:
x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
Returns:
means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled to [-1, 1].
h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
logvar: Shape: [B, latent_dim, t, h, w].
"""
assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
x = self.fourier_features(x)
x = self.layers(x)
x = self.output_norm(x)
x = F.silu(x, inplace=True)
x = self.output_proj(x)
means, logvar = torch.chunk(x, 2, dim=1)
assert means.ndim == 5
assert logvar.shape == means.shape
assert means.size(1) == self.latent_dim
return LatentDistribution(means, logvar)
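To make the shape comment in the docstring concrete: with the reductions used by VideoVAE below (temporal [1, 2, 3], spatial [2, 2, 2]) the downsample factors are 6 and 8, so for an illustrative 49-frame 480x848 input:

import math

temporal_reductions, spatial_reductions = [1, 2, 3], [2, 2, 2]
T, H, W = 49, 480, 848                                   # hypothetical input size
t = (T - 1) // math.prod(temporal_reductions) + 1        # 9 latent frames
h = H // math.prod(spatial_reductions)                   # 60
w = W // math.prod(spatial_reductions)                   # 106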
class VideoVAE(nn.Module): class VideoVAE(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.encoder = None #TODO once the model releases self.encoder = Encoder(
in_channels=15,
base_channels=64,
channel_multipliers=[1, 2, 4, 6],
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
temporal_reductions=[1, 2, 3],
spatial_reductions=[2, 2, 2],
prune_bottlenecks=[False, False, False, False, False],
has_attentions=[False, True, True, True, True],
affine=True,
bias=True,
input_is_conv_1x1=True,
padding_mode="replicate"
)
self.decoder = Decoder( self.decoder = Decoder(
out_channels=3, out_channels=3,
base_channels=128, base_channels=128,
@ -474,7 +705,7 @@ class VideoVAE(nn.Module):
) )
def encode(self, x): def encode(self, x):
return self.encoder(x) return self.encoder(x).mode()
def decode(self, x): def decode(self, x):
return self.decoder(x) return self.decoder(x)

View File

@ -0,0 +1,330 @@
#Based on Flux code because of weird hunyuan video code license.
import torch
import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention
from dataclasses import dataclass
from einops import repeat
from torch import Tensor, nn
from comfy.ldm.flux.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding
)
import comfy.ldm.common_dit
@dataclass
class HunyuanVideoParams:
in_channels: int
out_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list
theta: int
patch_size: list
qkv_bias: bool
guidance_embed: bool
class SelfAttentionRef(nn.Module):
def __init__(self, dim: int, qkv_bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
class TokenRefinerBlock(nn.Module):
def __init__(
self,
hidden_size,
heads,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.heads = heads
mlp_hidden_dim = hidden_size * 4
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
)
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.self_attn = SelfAttentionRef(hidden_size, True, dtype=dtype, device=device, operations=operations)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x, c, mask):
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn.qkv(norm_x)
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
return x
class IndividualTokenRefiner(nn.Module):
def __init__(
self,
hidden_size,
heads,
num_blocks,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.blocks = nn.ModuleList(
[
TokenRefinerBlock(
hidden_size=hidden_size,
heads=heads,
dtype=dtype,
device=device,
operations=operations
)
for _ in range(num_blocks)
]
)
def forward(self, x, c, mask):
m = None
if mask is not None:
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
m = m + m.transpose(2, 3)
for block in self.blocks:
x = block(x, c, m)
return x
class TokenRefiner(nn.Module):
def __init__(
self,
text_dim,
hidden_size,
heads,
num_blocks,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.input_embedder = operations.Linear(text_dim, hidden_size, bias=True, dtype=dtype, device=device)
self.t_embedder = MLPEmbedder(256, hidden_size, dtype=dtype, device=device, operations=operations)
self.c_embedder = MLPEmbedder(text_dim, hidden_size, dtype=dtype, device=device, operations=operations)
self.individual_token_refiner = IndividualTokenRefiner(hidden_size, heads, num_blocks, dtype=dtype, device=device, operations=operations)
def forward(
self,
x,
timesteps,
mask,
):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
c = x.sum(dim=1) / x.shape[1]
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, mask)
return x
class HunyuanVideo(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = HunyuanVideoParams(**kwargs)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
self.txt_in = TokenRefiner(params.context_in_dim, self.hidden_size, self.num_heads, 2, dtype=dtype, device=device, operations=operations)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
transformer_options={},
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape)
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
txt = self.txt_in(txt, timesteps, txt_mask)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, img_len:] = txt_mask
else:
attn_mask = None
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((img, txt), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, : img_len] += add
img = img[:, : img_len]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
shape = initial_shape[-3:]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape)
return img
def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
return out
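The *_len computations above round the latent grid to the nearest multiple of the patch size; for example, with a hypothetical patch_size = [1, 2, 2] and a latent of t=33, h=67, w=120:

patch_size = [1, 2, 2]
t, h, w = 33, 67, 120
t_len = (t + patch_size[0] // 2) // patch_size[0]        # 33 (temporal patch of 1, unchanged)
h_len = (h + patch_size[1] // 2) // patch_size[1]        # 34 (67 rounds up to 2 * 34)
w_len = (w + patch_size[2] // 2) // patch_size[2]        # 60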

View File

@ -1,24 +1,17 @@
from typing import Any, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from torch.utils import checkpoint
from comfy.ldm.modules.diffusionmodules.mmdit import ( from comfy.ldm.modules.diffusionmodules.mmdit import (
Mlp,
TimestepEmbedder, TimestepEmbedder,
PatchEmbed, PatchEmbed,
RMSNorm,
) )
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from .poolers import AttentionPool from .poolers import AttentionPool
import comfy.latent_formats import comfy.latent_formats
from .models import HunYuanDiTBlock, calc_rope from .models import HunYuanDiTBlock, calc_rope
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
class HunYuanControlNet(nn.Module): class HunYuanControlNet(nn.Module):
@ -171,9 +164,6 @@ class HunYuanControlNet(nn.Module):
), ),
) )
# Image embedding
num_patches = self.x_embedder.num_patches
# HunYuanDiT Blocks # HunYuanDiT Blocks
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [

View File

@ -1,8 +1,6 @@
from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import comfy.ops import comfy.ops
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
@ -250,9 +248,6 @@ class HunYuanDiT(nn.Module):
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device), operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
) )
# Image embedding
num_patches = self.x_embedder.num_patches
# HunYuanDiT Blocks # HunYuanDiT Blocks
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
HunYuanDiTBlock(hidden_size=hidden_size, HunYuanDiTBlock(hidden_size=hidden_size,
@ -287,7 +282,7 @@ class HunYuanDiT(nn.Module):
style=None, style=None,
return_dict=False, return_dict=False,
control=None, control=None,
transformer_options=None, transformer_options={},
): ):
""" """
Forward pass of the encoder. Forward pass of the encoder.
@ -315,8 +310,7 @@ class HunYuanDiT(nn.Module):
return_dict: bool return_dict: bool
Whether to return a dictionary. Whether to return a dictionary.
""" """
#import pdb patches_replace = transformer_options.get("patches_replace", {})
#pdb.set_trace()
encoder_hidden_states = context encoder_hidden_states = context
text_states = encoder_hidden_states # 2,77,1024 text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
@ -364,6 +358,8 @@ class HunYuanDiT(nn.Module):
# Concatenate all extra vectors # Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D] c = t + self.extra_embedder(extra_vec) # [B, D]
blocks_replace = patches_replace.get("dit", {})
controls = None controls = None
if control: if control:
controls = control.get("output", None) controls = control.get("output", None)
@ -375,9 +371,20 @@ class HunYuanDiT(nn.Module):
skip = skips.pop() + controls.pop().to(dtype=x.dtype) skip = skips.pop() + controls.pop().to(dtype=x.dtype)
else: else:
skip = skips.pop() skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
else: else:
x = block(x, c, text_states, freqs_cis_img) # (N, L, D) skip = None
if ("double_block", layer) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
return out
out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
if layer < (self.depth // 2 - 1): if layer < (self.depth // 2 - 1):
skips.append(x) skips.append(x)

View File

@ -1,6 +1,5 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
import comfy.ops import comfy.ops

View File

@ -0,0 +1,527 @@
import torch
from torch import nn
import comfy.ldm.modules.attention
from comfy.ldm.genmo.joint_model.layers import RMSNorm
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.
Args:
timesteps (torch.Tensor):
a 1-D Tensor of N indices, one per batch element. These may be fractional.
embedding_dim (int):
the dimension of the output.
flip_sin_to_cos (bool):
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
downscale_freq_shift (float):
Controls the delta between frequencies between dimensions
scale (float):
Scaling factor applied to the embeddings.
max_period (int):
Controls the maximum frequency of the embeddings
Returns:
torch.Tensor: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
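A quick check of the embedding shape and layout (sin half followed by cos half with the default flip_sin_to_cos=False):

import torch

t = torch.tensor([0.0, 500.0, 999.0])
emb = get_timestep_embedding(t, embedding_dim=256)
assert emb.shape == (3, 256)                             # [sin | cos], 128 frequencies each
# frequency i is exp(-log(10000) * i / 127): geometrically spaced from 1.0 down to 1e-4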
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
dtype=None, device=None, operations=None,
):
super().__init__()
self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)
if cond_proj_dim is not None:
self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
else:
self.cond_proj = None
self.act = nn.SiLU()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
if post_act_fn is None:
self.post_act = None
# else:
# self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
return t_emb
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
"""
For PixArt-Alpha.
Reference:
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
return timesteps_emb
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
)
self.silu = nn.SiLU()
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
batch_size: Optional[int] = None,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# No modulation happening here.
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class GELU_approx(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)
def forward(self, x):
return torch.nn.functional.gelu(self.proj(x), approximate="tanh")
class FeedForward(nn.Module):
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
super().__init__()
inner_dim = int(dim * mult)
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
)
def forward(self, x):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
cos_freqs = freqs_cis[0]
sin_freqs = freqs_cis[1]
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
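A tiny sanity check of apply_rotary_emb: each (even, odd) feature pair (t1, t2) is rotated to (t1*cos - t2*sin, t2*cos + t1*sin), so a 90-degree angle maps (1, 0) to (0, 1):

import math
import torch

x = torch.tensor([[1.0, 0.0]])                           # a single feature pair
cos = torch.full_like(x, math.cos(math.pi / 2))
sin = torch.full_like(x, math.sin(math.pi / 2))
y = apply_rotary_emb(x, (cos, sin))
assert torch.allclose(y, torch.tensor([[0.0, 1.0]]), atol=1e-6)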
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
self.attn_precision = attn_precision
self.heads = heads
self.dim_head = dim_head
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None, pe=None):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
v = self.to_v(context)
q = self.q_norm(q)
k = self.k_norm(k)
if pe is not None:
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
self.attn_precision = attn_precision
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
return x
def get_fractional_positions(indices_grid, max_pos):
fractional_positions = torch.stack(
[
indices_grid[:, i] / max_pos[i]
for i in range(3)
],
dim=-1,
)
return fractional_positions
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32 #self.dtype
fractional_positions = get_fractional_positions(indices_grid, max_pos)
start = 1
end = theta
device = fractional_positions.device
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
dim // 6,
device=device,
dtype=dtype,
)
)
indices = indices.to(dtype=dtype)
indices = indices * math.pi / 2
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
if dim % 6 != 0:
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
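A shape sketch tying precompute_freqs_cis to the patchifier's grid; dim=2048 matches the default inner_dim of the LTXV model below (32 heads of dim 64), and the video size is illustrative:

import torch

patchifier = SymmetricPatchifier(1)
grid = patchifier.get_grid(orig_num_frames=4, orig_height=8, orig_width=8,
                           batch_size=1, scale_grid=None, device="cpu")
cos, sin = precompute_freqs_cis(grid, dim=2048, out_dtype=torch.float32)
assert grid.shape == (1, 3, 4 * 8 * 8)                   # (frame, row, col) index per token
assert cos.shape == sin.shape == (1, 4 * 8 * 8, 2048)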
class LTXVModel(torch.nn.Module):
def __init__(self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=64,
num_attention_heads=32,
caption_channels=4096,
num_layers=28,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.generator = None
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
)
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
# attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
for d in range(num_layers)
]
)
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
orig_height=x.shape[3],
orig_width=x.shape[4],
batch_size=x.shape[0],
scale_grid=((1 / frame_rate) * 8, 32, 32),
device=x.device,
)
if guiding_latent is not None:
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
if guiding_latent_noise_scale > 0:
if self.generator is None:
self.generator = torch.Generator(device=x.device).manual_seed(42)
elif self.generator.device != x.device:
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
scale = guiding_latent_noise_scale * (input_ts ** 2)
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
orig_shape = list(x.shape)
x = self.patchifier.patchify(x)
x = self.patchify_proj(x)
timestep = timestep * 1000.0
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=x.dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(
batch_size, -1, embedded_timestep.shape[-1]
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = x.shape[0]
context = self.caption_projection(context)
context = context.view(
batch_size, -1, x.shape[-1]
)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
x,
context=context,
attention_mask=attention_mask,
timestep=timestep,
pe=pe
)
# 3. Output
scale_shift_values = (
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = x * (1 + scale) + shift
x = self.proj_out(x)
x = self.patchifier.unpatchify(
latents=x,
output_height=orig_shape[3],
output_width=orig_shape[4],
output_num_frames=orig_shape[2],
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
)
if guiding_latent is not None:
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
# print("res", x)
return x

View File

@ -0,0 +1,105 @@
from abc import ABC, abstractmethod
from typing import Tuple
import torch
from einops import rearrange
from torch import Tensor
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
elif dims_to_append == 0:
return x
return x[(...,) + (None,) * dims_to_append]
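For example, append_dims lets a per-sample scalar broadcast against a video tensor:

import torch

scale = torch.tensor([0.5, 2.0])                         # one value per batch element
video = torch.randn(2, 3, 4, 8, 8)
scaled = video * append_dims(scale, video.ndim)          # scale becomes shape (2, 1, 1, 1, 1)
assert scaled.shape == video.shape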
class Patchifier(ABC):
def __init__(self, patch_size: int):
super().__init__()
self._patch_size = (1, patch_size, patch_size)
@abstractmethod
def patchify(
self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
) -> Tuple[Tensor, Tensor]:
pass
@abstractmethod
def unpatchify(
self,
latents: Tensor,
output_height: int,
output_width: int,
output_num_frames: int,
out_channels: int,
) -> Tuple[Tensor, Tensor]:
pass
@property
def patch_size(self):
return self._patch_size
def get_grid(
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
):
f = orig_num_frames // self._patch_size[0]
h = orig_height // self._patch_size[1]
w = orig_width // self._patch_size[2]
grid_h = torch.arange(h, dtype=torch.float32, device=device)
grid_w = torch.arange(w, dtype=torch.float32, device=device)
grid_f = torch.arange(f, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
if scale_grid is not None:
for i in range(3):
if isinstance(scale_grid[i], Tensor):
scale = append_dims(scale_grid[i], grid.ndim - 1)
else:
scale = scale_grid[i]
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
return grid
class SymmetricPatchifier(Patchifier):
def patchify(
self,
latents: Tensor,
) -> Tuple[Tensor, Tensor]:
latents = rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
p1=self._patch_size[0],
p2=self._patch_size[1],
p3=self._patch_size[2],
)
return latents
def unpatchify(
self,
latents: Tensor,
output_height: int,
output_width: int,
output_num_frames: int,
out_channels: int,
) -> Tuple[Tensor, Tensor]:
output_height = output_height // self._patch_size[1]
output_width = output_width // self._patch_size[2]
latents = rearrange(
latents,
"b (f h w) (c p q) -> b c f (h p) (w q) ",
f=output_num_frames,
h=output_height,
w=output_width,
p=self._patch_size[1],
q=self._patch_size[2],
)
return latents
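A round-trip sketch with the patch size the LTXV model uses (SymmetricPatchifier(1), i.e. every latent position becomes one token):

import torch

patchifier = SymmetricPatchifier(1)
x = torch.randn(2, 128, 4, 8, 8)                         # [B, C, F, H, W]
tokens = patchifier.patchify(x)
assert tokens.shape == (2, 4 * 8 * 8, 128)               # [B, F*H*W, C]
y = patchifier.unpatchify(tokens, output_height=8, output_width=8,
                          output_num_frames=4, out_channels=128)
assert torch.equal(y, x)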

View File

@ -0,0 +1,64 @@
from typing import Tuple, Union
import torch
import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size: int = 3,
stride: Union[int, Tuple[int]] = 1,
dilation: int = 1,
groups: int = 1,
**kwargs,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
kernel_size = (kernel_size, kernel_size, kernel_size)
self.time_kernel_size = kernel_size[0]
dilation = (dilation, 1, 1)
height_pad = kernel_size[1] // 2
width_pad = kernel_size[2] // 2
padding = (0, height_pad, width_pad)
self.conv = ops.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
padding_mode="zeros",
groups=groups,
)
def forward(self, x, causal: bool = True):
if causal:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.time_kernel_size - 1, 1, 1)
)
x = torch.concatenate((first_frame_pad, x), dim=2)
else:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
last_frame_pad = x[:, :, -1:, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
x = self.conv(x)
return x
@property
def weight(self):
return self.conv.weight
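A minimal shape check: with causal=True the first frame is repeated (time_kernel_size - 1) times as artificial "past", so the temporal length is preserved without looking ahead. Note the conv comes from comfy's disable_weight_init ops, so the weights are uninitialized and only the shapes are meaningful here:

import torch

conv = CausalConv3d(in_channels=3, out_channels=8, kernel_size=3, stride=1)
x = torch.randn(1, 3, 5, 16, 16)                         # 5 frames
y = conv(x, causal=True)
assert y.shape == (1, 8, 5, 16, 16)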

View File

@ -0,0 +1,907 @@
import torch
from torch import nn
from functools import partial
import math
from einops import rearrange
from typing import Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
ops = comfy.ops.disable_weight_init
class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
The number of dimensions to use in convolutions.
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
The blocks to use. Each block is a tuple of the block name and the number of layers.
base_channels (`int`, *optional*, defaults to 128):
The number of output channels for the first convolutional layer.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
latent_log_var (`str`, *optional*, defaults to `per_channel`):
How the log variance is parameterized. Can be either `per_channel`, `uniform`, or `none`.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]] = 3,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
base_channels: int = 128,
norm_num_groups: int = 32,
patch_size: Union[int, Tuple[int]] = 1,
norm_layer: str = "group_norm", # group_norm, pixel_norm
latent_log_var: str = "per_channel",
):
super().__init__()
self.patch_size = patch_size
self.norm_layer = norm_layer
self.latent_channels = out_channels
self.latent_log_var = latent_log_var
self.blocks_desc = blocks
in_channels = in_channels * patch_size**2
output_channel = base_channels
self.conv_in = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=output_channel,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
self.down_blocks = nn.ModuleList([])
for block_name, block_params in blocks:
input_channel = output_channel
if isinstance(block_params, int):
block_params = {"num_layers": block_params}
if block_name == "res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "res_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
block = ResnetBlock3D(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "compress_time":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 1, 1),
causal=True,
)
elif block_name == "compress_space":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(1, 2, 2),
causal=True,
)
elif block_name == "compress_all":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 2, 2),
causal=True,
)
elif block_name == "compress_all_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 2, 2),
causal=True,
)
else:
raise ValueError(f"unknown block: {block_name}")
self.down_blocks.append(block)
# out
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
elif norm_layer == "layer_norm":
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = out_channels
if latent_log_var == "per_channel":
conv_out_channels *= 2
elif latent_log_var == "uniform":
conv_out_channels += 1
elif latent_log_var != "none":
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
self.conv_out = make_conv_nd(
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
)
self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
sample = self.conv_in(sample)
checkpoint_fn = (
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
if self.gradient_checkpointing and self.training
else lambda x: x
)
for down_block in self.down_blocks:
sample = checkpoint_fn(down_block)(sample)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if self.latent_log_var == "uniform":
last_channel = sample[:, -1:, ...]
num_dims = sample.dim()
if num_dims == 4:
# For shape (B, C, H, W)
repeated_last_channel = last_channel.repeat(
1, sample.shape[1] - 2, 1, 1
)
sample = torch.cat([sample, repeated_last_channel], dim=1)
elif num_dims == 5:
# For shape (B, C, F, H, W)
repeated_last_channel = last_channel.repeat(
1, sample.shape[1] - 2, 1, 1, 1
)
sample = torch.cat([sample, repeated_last_channel], dim=1)
else:
raise ValueError(f"Invalid input shape: {sample.shape}")
return sample
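The blocks argument documented above is a list of (name, params) entries; a hypothetical layout (not a shipped configuration) might look like:

blocks = [
    ("res_x", {"num_layers": 4}),        # residual stack at constant width
    ("compress_all", 1),                 # stride-(2, 2, 2) causal downsample
    ("res_x_y", {"multiplier": 2}),      # residual block that doubles the channel count
    ("res_x", {"num_layers": 3}),
]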
class Decoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
The number of dimensions to use in convolutions.
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
The blocks to use. Each block is a tuple of the block name and the number of layers.
base_channels (`int`, *optional*, defaults to 128):
The number of output channels for the first convolutional layer.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
causal (`bool`, *optional*, defaults to `True`):
Whether to use causal convolutions or not.
"""
def __init__(
self,
dims,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
base_channels: int = 128,
layers_per_block: int = 2,
norm_num_groups: int = 32,
patch_size: int = 1,
norm_layer: str = "group_norm",
causal: bool = True,
timestep_conditioning: bool = False,
):
super().__init__()
self.patch_size = patch_size
self.layers_per_block = layers_per_block
out_channels = out_channels * patch_size**2
self.causal = causal
self.blocks_desc = blocks
# Compute output channel to be product of all channel-multiplier blocks
output_channel = base_channels
for block_name, block_params in list(reversed(blocks)):
block_params = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
output_channel = output_channel * block_params.get("multiplier", 2)
if block_name == "compress_all":
output_channel = output_channel * block_params.get("multiplier", 1)
self.conv_in = make_conv_nd(
dims,
in_channels,
output_channel,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
self.up_blocks = nn.ModuleList([])
for block_name, block_params in list(reversed(blocks)):
input_channel = output_channel
if isinstance(block_params, int):
block_params = {"num_layers": block_params}
if block_name == "res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
block = ResnetBlock3D(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=False,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
)
elif block_name == "compress_all":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 2, 2),
residual=block_params.get("residual", False),
out_channels_reduction_factor=block_params.get("multiplier", 1),
)
else:
raise ValueError(f"unknown layer: {block_name}")
self.up_blocks.append(block)
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
elif norm_layer == "layer_norm":
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims, output_channel, out_channels, 3, padding=1, causal=True
)
self.gradient_checkpointing = False
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(
torch.tensor(1000.0, dtype=torch.float32)
)
self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
output_channel * 2, 0, operations=ops,
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
sample = self.conv_in(sample, causal=self.causal)
checkpoint_fn = (
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
if self.gradient_checkpointing and self.training
else lambda x: x
)
scaled_timestep = None
if self.timestep_conditioning:
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
for up_block in self.up_blocks:
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
sample = self.conv_norm_out(sample)
if self.timestep_conditioning:
embedded_timestep = self.last_time_embedder(
timestep=scaled_timestep.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=sample.shape[0],
hidden_dtype=sample.dtype,
)
embedded_timestep = embedded_timestep.view(
batch_size, embedded_timestep.shape[-1], 1, 1, 1
)
ada_values = self.last_scale_shift_table[
None, ..., None, None, None
].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(
batch_size,
2,
-1,
embedded_timestep.shape[-3],
embedded_timestep.shape[-2],
embedded_timestep.shape[-1],
)
shift, scale = ada_values.unbind(dim=1)
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
class UNetMidBlock3D(nn.Module):
"""
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
Args:
in_channels (`int`): The number of input channels.
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
in_channels, frames, height, width)`.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]],
in_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_groups: int = 32,
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
):
super().__init__()
resnet_groups = (
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
)
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
in_channels * 4, 0, operations=ops,
)
self.res_blocks = nn.ModuleList(
[
ResnetBlock3D(
dims=dims,
in_channels=in_channels,
out_channels=in_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
norm_layer=norm_layer,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
)
for _ in range(num_layers)
]
)
def forward(
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
) -> torch.FloatTensor:
timestep_embed = None
if self.timestep_conditioning:
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
batch_size = hidden_states.shape[0]
timestep_embed = self.time_embedder(
timestep=timestep.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
timestep_embed = timestep_embed.view(
batch_size, timestep_embed.shape[-1], 1, 1, 1
)
for resnet in self.res_blocks:
hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed)
return hidden_states
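# Minimal usage sketch (illustrative only; the channel/frame sizes below are hypothetical, while
# UNetMidBlock3D and torch come from this module's own scope). With out_channels == in_channels
# and no timestep conditioning, the block preserves the (B, C, F, H, W) shape.
def _example_mid_block_usage():
    block = UNetMidBlock3D(dims=3, in_channels=64, num_layers=1)
    video = torch.randn(1, 64, 4, 8, 8)  # (batch, channels, frames, height, width)
    out = block(video, causal=True)      # timestep is only required when timestep_conditioning=True
    assert out.shape == video.shape
    return out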
class DepthToSpaceUpsample(nn.Module):
def __init__(
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
):
super().__init__()
self.stride = stride
self.out_channels = (
math.prod(stride) * in_channels // out_channels_reduction_factor
)
self.conv = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
causal=True,
)
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
if self.residual:
# Reshape and duplicate the input to match the output shape
x_in = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
if self.stride[0] == 2:
x_in = x_in[:, :, 1:, :, :]
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2:
x = x[:, :, 1:, :, :]
if self.residual:
x = x + x_in
return x
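# Shape sketch (illustrative; sizes are made up, and it assumes the causal conv preserves the
# input depth/height/width at stride 1, as it does elsewhere in this file): the conv multiplies
# channels by prod(stride), the rearrange folds that factor back into (d, h, w), and a temporal
# stride of 2 drops the duplicated first frame.
def _example_depth_to_space_shapes():
    up = DepthToSpaceUpsample(dims=3, in_channels=8, stride=(2, 2, 2))
    x = torch.randn(1, 8, 4, 8, 8)       # (B, C, D, H, W)
    y = up(x, causal=True)
    assert y.shape == (1, 8, 7, 16, 16)  # D: 2*4 - 1, H/W: doubled, C: unchanged
    return y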
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
super().__init__()
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, x):
x = rearrange(x, "b c d h w -> b d h w c")
x = self.norm(x)
x = rearrange(x, "b d h w c -> b c d h w")
return x
class ResnetBlock3D(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv layer. If None, same as `in_channels`.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]],
in_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
groups: int = 32,
eps: float = 1e-6,
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.inject_noise = inject_noise
if norm_layer == "group_norm":
self.norm1 = nn.GroupNorm(
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
)
elif norm_layer == "pixel_norm":
self.norm1 = PixelNorm()
elif norm_layer == "layer_norm":
self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
self.non_linearity = nn.SiLU()
self.conv1 = make_conv_nd(
dims,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
if inject_noise:
self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
if norm_layer == "group_norm":
self.norm2 = nn.GroupNorm(
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
)
elif norm_layer == "pixel_norm":
self.norm2 = PixelNorm()
elif norm_layer == "layer_norm":
self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = make_conv_nd(
dims,
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
if inject_noise:
self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
self.conv_shortcut = (
make_linear_nd(
dims=dims, in_channels=in_channels, out_channels=out_channels
)
if in_channels != out_channels
else nn.Identity()
)
self.norm3 = (
LayerNorm(in_channels, eps=eps, elementwise_affine=True)
if in_channels != out_channels
else nn.Identity()
)
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
self.scale_shift_table = nn.Parameter(
torch.randn(4, in_channels) / in_channels**0.5
)
def _feed_spatial_noise(
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
) -> torch.FloatTensor:
spatial_shape = hidden_states.shape[-2:]
device = hidden_states.device
dtype = hidden_states.dtype
# similar to the "explicit noise inputs" method in StyleGAN: add spatial noise scaled by a learned per-channel factor
spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
hidden_states = hidden_states + scaled_noise
return hidden_states
def forward(
self,
input_tensor: torch.FloatTensor,
causal: bool = True,
timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
hidden_states = input_tensor
batch_size = hidden_states.shape[0]
hidden_states = self.norm1(hidden_states)
if self.timestep_conditioning:
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
ada_values = self.scale_shift_table[
None, ..., None, None, None
].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(
batch_size,
4,
-1,
timestep.shape[-3],
timestep.shape[-2],
timestep.shape[-1],
)
shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
hidden_states = hidden_states * (1 + scale1) + shift1
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.conv1(hidden_states, causal=causal)
if self.inject_noise:
hidden_states = self._feed_spatial_noise(
hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype)
)
hidden_states = self.norm2(hidden_states)
if self.timestep_conditioning:
hidden_states = hidden_states * (1 + scale2) + shift2
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, causal=causal)
if self.inject_noise:
hidden_states = self._feed_spatial_noise(
hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype)
)
input_tensor = self.norm3(input_tensor)
batch_size = input_tensor.shape[0]
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states
return output_tensor
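# Broadcasting sketch for the AdaLN-style conditioning above (hypothetical sizes): the learned
# (4, C) scale_shift_table is added to the timestep embedding reshaped to (B, 4, C, 1, 1, 1),
# then split into (shift1, scale1, shift2, scale2).
def _example_scale_shift_split():
    batch, channels = 2, 8
    table = torch.randn(4, channels) / channels ** 0.5
    emb = torch.randn(batch, 4 * channels, 1, 1, 1)  # shaped like the embedded timestep
    ada = table[None, ..., None, None, None] + emb.reshape(batch, 4, channels, 1, 1, 1)
    shift1, scale1, shift2, scale2 = ada.unbind(dim=1)
    assert shift1.shape == (batch, channels, 1, 1, 1)
    return shift1, scale1, shift2, scale2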
def patchify(x, patch_size_hw, patch_size_t=1):
if patch_size_hw == 1 and patch_size_t == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
)
elif x.dim() == 5:
x = rearrange(
x,
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
p=patch_size_t,
q=patch_size_hw,
r=patch_size_hw,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size_hw, patch_size_t=1):
if patch_size_hw == 1 and patch_size_t == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
)
elif x.dim() == 5:
x = rearrange(
x,
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
p=patch_size_t,
q=patch_size_hw,
r=patch_size_hw,
)
return x
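# Round-trip sketch (illustrative tensor sizes): patchify folds spatial patches into channels and
# unpatchify inverts it exactly, which is how the encoder/decoder trade resolution for channels.
def _example_patchify_roundtrip():
    x = torch.randn(1, 3, 5, 16, 16)                  # (B, C, F, H, W)
    p = patchify(x, patch_size_hw=4, patch_size_t=1)  # -> (1, 3*4*4, 5, 4, 4)
    assert p.shape == (1, 48, 5, 4, 4)
    assert torch.equal(unpatchify(p, patch_size_hw=4, patch_size_t=1), x)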
class processor(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
self.register_buffer("mean-of-stds", torch.empty(128))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
self.register_buffer("channel", torch.empty(128))
def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
def normalize(self, x):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
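# Sketch of the per-channel latent (de)normalization (the statistics below are hypothetical; in
# real use the hyphenated buffers are loaded from the VAE checkpoint): normalize and un_normalize
# are exact inverses once the buffers are populated.
def _example_latent_normalization():
    stats = processor()
    stats.get_buffer("std-of-means").fill_(2.0)
    stats.get_buffer("mean-of-means").fill_(0.5)
    latent = torch.randn(1, 128, 2, 4, 4)
    assert torch.allclose(stats.un_normalize(stats.normalize(latent)), latent, atol=1e-6)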
class VideoVAE(nn.Module):
def __init__(self, version=0):
super().__init__()
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"blocks": [
["res_x", 4],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x", 3],
["res_x", 4],
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
}
else:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"decoder_blocks": [
["res_x", {"num_layers": 5, "inject_noise": True}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 6, "inject_noise": True}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 7, "inject_noise": True}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 8, "inject_noise": False}]
],
"encoder_blocks": [
["res_x", {"num_layers": 4}],
["compress_all", {}],
["res_x_y", 1],
["res_x", {"num_layers": 3}],
["compress_all", {}],
["res_x_y", 1],
["res_x", {"num_layers": 3}],
["compress_all", {}],
["res_x", {"num_layers": 3}],
["res_x", {"num_layers": 4}]
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
"timestep_conditioning": True,
}
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=config.get("timestep_conditioning", False),
)
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.per_channel_statistics = processor()
def encode(self, x):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x, timestep=0.05, noise_scale=0.025):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
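# End-to-end sketch (illustrative only; with random weights and unpopulated statistics this just
# exercises the plumbing and shapes, real use loads a checkpoint first). A causal VAE of this
# kind expects 8*k + 1 frames and spatial sizes divisible by 32, so a (1, 3, 9, 64, 64) clip
# should map to a (1, 128, 2, 2, 2) latent.
def _example_video_vae_roundtrip():
    vae = VideoVAE(version=0)
    clip = torch.randn(1, 3, 9, 64, 64)  # (B, RGB, frames, height, width)
    latent = vae.encode(clip)
    recon = vae.decode(latent)           # decode() adds noise only when timestep_conditioning is on
    return latent.shape, recon.shape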

View File

@ -0,0 +1,82 @@
from typing import Tuple, Union
from .dual_conv3d import DualConv3d
from .causal_conv3d import CausalConv3d
import comfy.ops
ops = comfy.ops.disable_weight_init
def make_conv_nd(
dims: Union[int, Tuple[int, int]],
in_channels: int,
out_channels: int,
kernel_size: int,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
causal=False,
):
if dims == 2:
return ops.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
elif dims == 3:
if causal:
return CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
return ops.Conv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
elif dims == (2, 1):
return DualConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
else:
raise ValueError(f"unsupported dimensions: {dims}")
def make_linear_nd(
dims: int,
in_channels: int,
out_channels: int,
bias=True,
):
if dims == 2:
return ops.Conv2d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
)
elif dims == 3 or dims == (2, 1):
return ops.Conv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
)
else:
raise ValueError(f"unsupported dimensions: {dims}")
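# Dispatch sketch (illustrative; the channel counts are arbitrary): dims selects between a plain
# 2D conv, a full or causal 3D conv, and the factorized (2+1)D DualConv3d.
def _example_make_conv_nd_dispatch():
    conv2d = make_conv_nd(dims=2, in_channels=4, out_channels=8, kernel_size=3, padding=1)
    causal3d = make_conv_nd(dims=3, in_channels=4, out_channels=8, kernel_size=3, causal=True)
    dual = make_conv_nd(dims=(2, 1), in_channels=4, out_channels=8, kernel_size=3, padding=1)
    return conv2d, causal3d, dual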

View File

@ -0,0 +1,195 @@
import math
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class DualConv3d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups=1,
bias=True,
):
super(DualConv3d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if kernel_size == (1, 1, 1):
raise ValueError(
"kernel_size must be greater than 1. Use make_linear_nd instead."
)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(padding, int):
padding = (padding, padding, padding)
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)
# Set parameters for convolutions
self.groups = groups
self.bias = bias
# Define the size of the channels after the first convolution
intermediate_channels = (
out_channels if in_channels < out_channels else in_channels
)
# Define parameters for the first convolution
self.weight1 = nn.Parameter(
torch.Tensor(
intermediate_channels,
in_channels // groups,
1,
kernel_size[1],
kernel_size[2],
)
)
self.stride1 = (1, stride[1], stride[2])
self.padding1 = (0, padding[1], padding[2])
self.dilation1 = (1, dilation[1], dilation[2])
if bias:
self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
else:
self.register_parameter("bias1", None)
# Define parameters for the second convolution
self.weight2 = nn.Parameter(
torch.Tensor(
out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
)
)
self.stride2 = (stride[0], 1, 1)
self.padding2 = (padding[0], 0, 0)
self.dilation2 = (dilation[0], 1, 1)
if bias:
self.bias2 = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter("bias2", None)
# Initialize weights and biases
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
if self.bias:
fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
bound1 = 1 / math.sqrt(fan_in1)
nn.init.uniform_(self.bias1, -bound1, bound1)
fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
bound2 = 1 / math.sqrt(fan_in2)
nn.init.uniform_(self.bias2, -bound2, bound2)
def forward(self, x, use_conv3d=False, skip_time_conv=False):
if use_conv3d:
return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
else:
return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
def forward_with_3d(self, x, skip_time_conv):
# First convolution
x = F.conv3d(
x,
self.weight1,
self.bias1,
self.stride1,
self.padding1,
self.dilation1,
self.groups,
)
if skip_time_conv:
return x
# Second convolution
x = F.conv3d(
x,
self.weight2,
self.bias2,
self.stride2,
self.padding2,
self.dilation2,
self.groups,
)
return x
def forward_with_2d(self, x, skip_time_conv):
b, c, d, h, w = x.shape
# First 2D convolution
x = rearrange(x, "b c d h w -> (b d) c h w")
# Squeeze the depth dimension out of weight1 since it's 1
weight1 = self.weight1.squeeze(2)
# Select stride, padding, and dilation for the 2D convolution
stride1 = (self.stride1[1], self.stride1[2])
padding1 = (self.padding1[1], self.padding1[2])
dilation1 = (self.dilation1[1], self.dilation1[2])
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
_, _, h, w = x.shape
if skip_time_conv:
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
return x
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
# Reshape weight2 to match the expected dimensions for conv1d
weight2 = self.weight2.squeeze(-1).squeeze(-1)
# Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
stride2 = self.stride2[0]
padding2 = self.padding2[0]
dilation2 = self.dilation2[0]
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
return x
@property
def weight(self):
return self.weight2
def test_dual_conv3d_consistency():
# Initialize parameters
in_channels = 3
out_channels = 5
kernel_size = (3, 3, 3)
stride = (2, 2, 2)
padding = (1, 1, 1)
# Create an instance of the DualConv3d class
dual_conv3d = DualConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=True,
)
# Example input tensor
test_input = torch.randn(1, 3, 10, 10, 10)
# Perform forward passes with both 3D and 2D settings
output_conv3d = dual_conv3d(test_input, use_conv3d=True)
output_2d = dual_conv3d(test_input, use_conv3d=False)
# Assert that the outputs from both methods are sufficiently close
assert torch.allclose(
output_conv3d, output_2d, atol=1e-6
), "Outputs are not consistent between 3D and 2D convolutions."

View File

@ -0,0 +1,12 @@
import torch
from torch import nn
class PixelNorm(nn.Module):
def __init__(self, dim=1, eps=1e-8):
super(PixelNorm, self).__init__()
self.dim = dim
self.eps = eps
def forward(self, x):
return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
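# Quick sanity sketch (illustrative sizes): PixelNorm rescales each position so the RMS over the
# channel dimension is approximately 1.
def _example_pixel_norm():
    x = torch.randn(2, 16, 4, 4) * 5.0
    y = PixelNorm(dim=1)(x)
    rms = torch.sqrt(torch.mean(y ** 2, dim=1))
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-4)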

View File

@ -1,10 +1,12 @@
import logging
import math
import torch import torch
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, Tuple, Union
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from comfy.ldm.util import instantiate_from_config from comfy.ldm.util import get_obj_from_str, instantiate_from_config
from comfy.ldm.modules.ema import LitEma from comfy.ldm.modules.ema import LitEma
import comfy.ops import comfy.ops
@ -52,7 +54,7 @@ class AbstractAutoencoder(torch.nn.Module):
if self.use_ema: if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay) self.model_ema = LitEma(self, decay=ema_decay)
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
def get_input(self, batch) -> Any: def get_input(self, batch) -> Any:
raise NotImplementedError() raise NotImplementedError()
@ -68,14 +70,14 @@ class AbstractAutoencoder(torch.nn.Module):
self.model_ema.store(self.parameters()) self.model_ema.store(self.parameters())
self.model_ema.copy_to(self) self.model_ema.copy_to(self)
if context is not None: if context is not None:
logpy.info(f"{context}: Switched to EMA weights") logging.info(f"{context}: Switched to EMA weights")
try: try:
yield None yield None
finally: finally:
if self.use_ema: if self.use_ema:
self.model_ema.restore(self.parameters()) self.model_ema.restore(self.parameters())
if context is not None: if context is not None:
logpy.info(f"{context}: Restored training weights") logging.info(f"{context}: Restored training weights")
def encode(self, *args, **kwargs) -> torch.Tensor: def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called") raise NotImplementedError("encode()-method of abstract base class called")
@ -84,7 +86,7 @@ class AbstractAutoencoder(torch.nn.Module):
raise NotImplementedError("decode()-method of abstract base class called") raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg): def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])( return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict()) params, lr=lr, **cfg.get("params", dict())
) )
@ -112,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config) self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config) self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.regularization: AbstractRegularizer = instantiate_from_config( self.regularization = instantiate_from_config(
regularizer_config regularizer_config
) )
@ -160,12 +162,19 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
}, },
**kwargs, **kwargs,
) )
self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
if ddconfig.get("conv3d", False):
conv_op = comfy.ops.disable_weight_init.Conv3d
else:
conv_op = comfy.ops.disable_weight_init.Conv2d
self.quant_conv = conv_op(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"], (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim, (1 + ddconfig["double_z"]) * embed_dim,
1, 1,
) )
self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list: def get_autoencoder_params(self) -> list:

View File

@ -15,6 +15,9 @@ if model_management.xformers_enabled():
import xformers import xformers
import xformers.ops import xformers.ops
if model_management.sage_attention_enabled():
from sageattention import sageattn
from comfy.cli_args import args from comfy.cli_args import args
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
@ -157,8 +160,6 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape b, _, dim_head = query.shape
dim_head //= heads dim_head //= heads
scale = dim_head ** -0.5
if skip_reshape: if skip_reshape:
query = query.reshape(b * heads, -1, dim_head) query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head) value = value.reshape(b * heads, -1, dim_head)
@ -177,9 +178,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
bytes_per_token = torch.finfo(query.dtype).bits//8 bytes_per_token = torch.finfo(query.dtype).bits//8
batch_x_heads, q_tokens, _ = query.shape batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key.shape _, _, k_tokens = key.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True) mem_free_total, _ = model_management.get_free_memory(query.device, True)
kv_chunk_size_min = None kv_chunk_size_min = None
kv_chunk_size = None kv_chunk_size = None
@ -230,7 +230,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
scale = dim_head ** -0.5 scale = dim_head ** -0.5
h = heads
if skip_reshape: if skip_reshape:
q, k, v = map( q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head), lambda t: t.reshape(b * heads, -1, dim_head),
@ -299,7 +298,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
if len(mask.shape) == 2: if len(mask.shape) == 2:
s1 += mask[i:end] s1 += mask[i:end]
else: else:
s1 += mask[:, i:end] if mask.shape[1] == 1:
s1 += mask
else:
s1 += mask[:, i:end]
s2 = s1.softmax(dim=-1).to(v.dtype) s2 = s1.softmax(dim=-1).to(v.dtype)
del s1 del s1
@ -341,12 +343,9 @@ except:
pass pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape: b = q.shape[0]
b, _, _, dim_head = q.shape dim_head = q.shape[-1]
else: # check to make sure xformers isn't broken
b, _, dim_head = q.shape
dim_head //= heads
disabled_xformers = False disabled_xformers = False
if BROKEN_XFORMERS: if BROKEN_XFORMERS:
@ -361,38 +360,54 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape) return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
if skip_reshape: if skip_reshape:
q, k, v = map( # b h k d -> b k h d
lambda t: t.reshape(b * heads, -1, dim_head), q, k, v = map(
lambda t: t.permute(0, 2, 1, 3),
(q, k, v), (q, k, v),
) )
# actually do the reshaping
else: else:
dim_head //= heads
q, k, v = map( q, k, v = map(
lambda t: t.reshape(b, -1, heads, dim_head), lambda t: t.reshape(b, -1, heads, dim_head),
(q, k, v), (q, k, v),
) )
if mask is not None: if mask is not None:
pad = 8 - q.shape[1] % 8 # add a singleton batch dimension
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device) if mask.ndim == 2:
mask_out[:, :, :mask.shape[-1]] = mask mask = mask.unsqueeze(0)
mask = mask_out[:, :, :mask.shape[-1]] # add a singleton heads dimension
if mask.ndim == 3:
mask = mask.unsqueeze(1)
# pad to a multiple of 8
pad = 8 - mask.shape[-1] % 8
# the xformers docs say that a mask of shape (1, Nq, Nk) is allowed
# but when using separated heads, the shape has to be (B, H, Nq, Nk)
# in flux, this matrix ends up being over 1GB
# here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
mask_out = torch.empty([mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
mask_out[..., :mask.shape[-1]] = mask
# the slice restores the original length; the padded allocation is kept as the underlying storage
mask = mask_out[..., :mask.shape[-1]]
mask = mask.expand(b, heads, -1, -1)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
if skip_reshape: out = (
out = ( out.reshape(b, -1, heads * dim_head)
out.unsqueeze(0) )
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
else:
out = (
out.reshape(b, -1, heads * dim_head)
)
return out return out
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
SDP_BATCH_LIMIT = 2**15
else:
#TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**31
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape: if skip_reshape:
b, _, _, dim_head = q.shape b, _, _, dim_head = q.shape
@ -404,27 +419,85 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
(q, k, v), (q, k, v),
) )
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if mask is not None:
out = ( # add a batch dimension if there isn't already one
out.transpose(1, 2).reshape(b, -1, heads * dim_head) if mask.ndim == 2:
) mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
if SDP_BATCH_LIMIT >= b:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
for i in range(0, b, SDP_BATCH_LIMIT):
m = mask
if mask is not None:
if mask.shape[0] > 1:
m = mask[i : i + SDP_BATCH_LIMIT]
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
q[i : i + SDP_BATCH_LIMIT],
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
attn_mask=m,
dropout_p=0.0, is_causal=False
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out
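# Equivalence sketch for the batch-chunking above (illustrative sizes, standalone): attention is
# independent per batch element, so splitting SDPA over the batch dimension matches one big call.
def _example_sdpa_batch_chunking(chunk=2):
    q, k, v = (torch.randn(4, 2, 8, 16) for _ in range(3))
    full = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    parts = torch.cat(
        [torch.nn.functional.scaled_dot_product_attention(q[i:i + chunk], k[i:i + chunk], v[i:i + chunk])
         for i in range(0, q.shape[0], chunk)],
        dim=0,
    )
    assert torch.allclose(full, parts, atol=1e-5)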
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head),
(q, k, v),
)
tensor_layout="NHD"
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
if tensor_layout == "HND":
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = out.reshape(b, -1, heads * dim_head)
return out return out
optimized_attention = attention_basic optimized_attention = attention_basic
if model_management.xformers_enabled(): if model_management.sage_attention_enabled():
logging.info("Using xformers cross attention") logging.info("Using sage attention")
optimized_attention = attention_sage
elif model_management.xformers_enabled():
logging.info("Using xformers attention")
optimized_attention = attention_xformers optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled(): elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch cross attention") logging.info("Using pytorch attention")
optimized_attention = attention_pytorch optimized_attention = attention_pytorch
else: else:
if args.use_split_cross_attention: if args.use_split_cross_attention:
logging.info("Using split optimization for cross attention") logging.info("Using split optimization for attention")
optimized_attention = attention_split optimized_attention = attention_split
else: else:
logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad optimized_attention = attention_sub_quad
optimized_attention_masked = optimized_attention optimized_attention_masked = optimized_attention

View File

@ -1,5 +1,4 @@
import logging from functools import partial
import math
from typing import Dict, Optional, List from typing import Dict, Optional, List
import numpy as np import numpy as np
@ -72,45 +71,33 @@ class PatchEmbed(nn.Module):
strict_img_size: bool = True, strict_img_size: bool = True,
dynamic_img_pad: bool = True, dynamic_img_pad: bool = True,
padding_mode='circular', padding_mode='circular',
conv3d=False,
dtype=None, dtype=None,
device=None, device=None,
operations=None, operations=None,
): ):
super().__init__() super().__init__()
self.patch_size = (patch_size, patch_size) try:
len(patch_size)
self.patch_size = patch_size
except:
if conv3d:
self.patch_size = (patch_size, patch_size, patch_size)
else:
self.patch_size = (patch_size, patch_size)
self.padding_mode = padding_mode self.padding_mode = padding_mode
if img_size is not None:
self.img_size = (img_size, img_size)
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
else:
self.img_size = None
self.grid_size = None
self.num_patches = None
# flatten spatial dim and transpose to channels last, kept for bwd compat # flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten self.flatten = flatten
self.strict_img_size = strict_img_size self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad self.dynamic_img_pad = dynamic_img_pad
if conv3d:
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device) self.proj = operations.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
else:
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x): def forward(self, x):
# B, C, H, W = x.shape
# if self.img_size is not None:
# if self.strict_img_size:
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
# _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
# elif not self.dynamic_img_pad:
# _assert(
# H % self.patch_size[0] == 0,
# f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
# )
# _assert(
# W % self.patch_size[1] == 0,
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
# )
if self.dynamic_img_pad: if self.dynamic_img_pad:
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode) x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
x = self.proj(x) x = self.proj(x)

View File

@ -3,7 +3,6 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
from typing import Optional, Any
import logging import logging
from comfy import model_management from comfy import model_management
@ -44,51 +43,100 @@ def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class VideoConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
self.padding_mode = padding_mode
if padding != 0:
padding = (padding, padding, padding, padding, kernel_size - 1, 0)
else:
kwargs["padding"] = padding
self.padding = padding
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
if self.padding != 0:
x = torch.nn.functional.pad(x, self.padding, mode=self.padding_mode)
return self.conv(x)
def interpolate_up(x, scale_factor):
try:
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
except: #operation not implemented for bf16
orig_shape = list(x.shape)
out_shape = orig_shape[:2]
for i in range(len(orig_shape) - 2):
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
return out
class Upsample(nn.Module): class Upsample(nn.Module):
def __init__(self, in_channels, with_conv): def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
super().__init__() super().__init__()
self.with_conv = with_conv self.with_conv = with_conv
self.scale_factor = scale_factor
if self.with_conv: if self.with_conv:
self.conv = ops.Conv2d(in_channels, self.conv = conv_op(in_channels,
in_channels, in_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
def forward(self, x): def forward(self, x):
try: scale_factor = self.scale_factor
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") if isinstance(scale_factor, (int, float)):
except: #operation not implemented for bf16 scale_factor = (scale_factor,) * (x.ndim - 2)
b, c, h, w = x.shape
out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
del x
x = out
if x.ndim == 5 and scale_factor[0] > 1.0:
t = x.shape[2]
if t > 1:
a, b = x.split((1, t - 1), dim=2)
del x
b = interpolate_up(b, scale_factor)
else:
a = x
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
if t > 1:
x = torch.cat((a, b), dim=2)
else:
x = a
else:
x = interpolate_up(x, scale_factor)
if self.with_conv: if self.with_conv:
x = self.conv(x) x = self.conv(x)
return x return x
class Downsample(nn.Module): class Downsample(nn.Module):
def __init__(self, in_channels, with_conv): def __init__(self, in_channels, with_conv, stride=2, conv_op=ops.Conv2d):
super().__init__() super().__init__()
self.with_conv = with_conv self.with_conv = with_conv
if self.with_conv: if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves # no asymmetric padding in torch conv, must do it ourselves
self.conv = ops.Conv2d(in_channels, self.conv = conv_op(in_channels,
in_channels, in_channels,
kernel_size=3, kernel_size=3,
stride=2, stride=stride,
padding=0) padding=0)
def forward(self, x): def forward(self, x):
if self.with_conv: if self.with_conv:
pad = (0,1,0,1) if x.ndim == 4:
x = torch.nn.functional.pad(x, pad, mode="constant", value=0) pad = (0, 1, 0, 1)
mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0)
mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode)
x = self.conv(x) x = self.conv(x)
else: else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
@ -97,7 +145,7 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module): class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512): dropout, temb_channels=512, conv_op=ops.Conv2d):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels out_channels = in_channels if out_channels is None else out_channels
@ -106,7 +154,7 @@ class ResnetBlock(nn.Module):
self.swish = torch.nn.SiLU(inplace=True) self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels) self.norm1 = Normalize(in_channels)
self.conv1 = ops.Conv2d(in_channels, self.conv1 = conv_op(in_channels,
out_channels, out_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -116,20 +164,20 @@ class ResnetBlock(nn.Module):
out_channels) out_channels)
self.norm2 = Normalize(out_channels) self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True) self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = ops.Conv2d(out_channels, self.conv2 = conv_op(out_channels,
out_channels, out_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
if self.in_channels != self.out_channels: if self.in_channels != self.out_channels:
if self.use_conv_shortcut: if self.use_conv_shortcut:
self.conv_shortcut = ops.Conv2d(in_channels, self.conv_shortcut = conv_op(in_channels,
out_channels, out_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
else: else:
self.nin_shortcut = ops.Conv2d(in_channels, self.nin_shortcut = conv_op(in_channels,
out_channels, out_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
@ -163,7 +211,6 @@ def slice_attention(q, k, v):
mem_free_total = model_management.get_free_memory(q.device) mem_free_total = model_management.get_free_memory(q.device)
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5 modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier mem_required = tensor_size * modifier
@ -196,21 +243,25 @@ def slice_attention(q, k, v):
def normal_attention(q, k, v): def normal_attention(q, k, v):
# compute attention # compute attention
b,c,h,w = q.shape orig_shape = q.shape
b = orig_shape[0]
c = orig_shape[1]
q = q.reshape(b,c,h*w) q = q.reshape(b, c, -1)
q = q.permute(0,2,1) # b,hw,c q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw k = k.reshape(b, c, -1) # b,c,hw
v = v.reshape(b,c,h*w) v = v.reshape(b, c, -1)
r1 = slice_attention(q, k, v) r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w) h_ = r1.reshape(orig_shape)
del r1 del r1
return h_ return h_
def xformers_attention(q, k, v): def xformers_attention(q, k, v):
# compute attention # compute attention
B, C, H, W = q.shape orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
q, k, v = map( q, k, v = map(
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(), lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v), (q, k, v),
@ -218,14 +269,16 @@ def xformers_attention(q, k, v):
try: try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = out.transpose(1, 2).reshape(B, C, H, W) out = out.transpose(1, 2).reshape(orig_shape)
except NotImplementedError as e: except NotImplementedError:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out return out
def pytorch_attention(q, k, v): def pytorch_attention(q, k, v):
# compute attention # compute attention
B, C, H, W = q.shape orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
q, k, v = map( q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v), (q, k, v),
@ -233,35 +286,35 @@ def pytorch_attention(q, k, v):
try: try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W) out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION as e: except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out return out
class AttnBlock(nn.Module): class AttnBlock(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels, conv_op=ops.Conv2d):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.q = ops.Conv2d(in_channels, self.q = conv_op(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.k = ops.Conv2d(in_channels, self.k = conv_op(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.v = ops.Conv2d(in_channels, self.v = conv_op(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.proj_out = ops.Conv2d(in_channels, self.proj_out = conv_op(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
@ -291,8 +344,8 @@ class AttnBlock(nn.Module):
return x+h_ return x+h_
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None, conv_op=ops.Conv2d):
return AttnBlock(in_channels) return AttnBlock(in_channels, conv_op=conv_op)
class Model(nn.Module): class Model(nn.Module):
@ -451,6 +504,7 @@ class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
conv3d=False, time_compress=None,
**ignore_kwargs): **ignore_kwargs):
super().__init__() super().__init__()
if use_linear_attn: attn_type = "linear" if use_linear_attn: attn_type = "linear"
@ -461,8 +515,15 @@ class Encoder(nn.Module):
self.resolution = resolution self.resolution = resolution
self.in_channels = in_channels self.in_channels = in_channels
if conv3d:
conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
mid_attn_conv_op = ops.Conv2d
# downsampling # downsampling
self.conv_in = ops.Conv2d(in_channels, self.conv_in = conv_op(in_channels,
self.ch, self.ch,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -481,15 +542,20 @@ class Encoder(nn.Module):
block.append(ResnetBlock(in_channels=block_in, block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out, out_channels=block_out,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout)) dropout=dropout,
conv_op=conv_op))
block_in = block_out block_in = block_out
if curr_res in attn_resolutions: if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type)) attn.append(make_attn(block_in, attn_type=attn_type, conv_op=conv_op))
down = nn.Module() down = nn.Module()
down.block = block down.block = block
down.attn = attn down.attn = attn
if i_level != self.num_resolutions-1: if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv) stride = 2
if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2)
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2 curr_res = curr_res // 2
self.down.append(down) self.down.append(down)
@ -498,16 +564,18 @@ class Encoder(nn.Module):
self.mid.block_1 = ResnetBlock(in_channels=block_in, self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout,
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) conv_op=conv_op)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type, conv_op=mid_attn_conv_op)
self.mid.block_2 = ResnetBlock(in_channels=block_in, self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout,
conv_op=conv_op)
# end # end
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = ops.Conv2d(block_in, self.conv_out = conv_op(block_in,
2*z_channels if double_z else z_channels, 2*z_channels if double_z else z_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -545,9 +613,10 @@ class Decoder(nn.Module):
conv_out_op=ops.Conv2d, conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock, resnet_op=ResnetBlock,
attn_op=AttnBlock, attn_op=AttnBlock,
conv3d=False,
time_compress=None,
**ignorekwargs): **ignorekwargs):
super().__init__() super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch self.ch = ch
self.temb_ch = 0 self.temb_ch = 0
self.num_resolutions = len(ch_mult) self.num_resolutions = len(ch_mult)
@ -557,8 +626,15 @@ class Decoder(nn.Module):
self.give_pre_end = give_pre_end self.give_pre_end = give_pre_end
self.tanh_out = tanh_out self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res if conv3d:
in_ch_mult = (1,)+tuple(ch_mult) conv_op = VideoConv3d
conv_out_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
mid_attn_conv_op = ops.Conv2d
# compute block_in and curr_res at lowest res
block_in = ch*ch_mult[self.num_resolutions-1] block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1) curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res) self.z_shape = (1,z_channels,curr_res,curr_res)
@ -566,7 +642,7 @@ class Decoder(nn.Module):
self.z_shape, np.prod(self.z_shape))) self.z_shape, np.prod(self.z_shape)))
# z to block_in # z to block_in
self.conv_in = ops.Conv2d(z_channels, self.conv_in = conv_op(z_channels,
block_in, block_in,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -577,12 +653,14 @@ class Decoder(nn.Module):
self.mid.block_1 = resnet_op(in_channels=block_in, self.mid.block_1 = resnet_op(in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout,
self.mid.attn_1 = attn_op(block_in) conv_op=conv_op)
self.mid.attn_1 = attn_op(block_in, conv_op=mid_attn_conv_op)
self.mid.block_2 = resnet_op(in_channels=block_in, self.mid.block_2 = resnet_op(in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout,
conv_op=conv_op)
# upsampling # upsampling
self.up = nn.ModuleList() self.up = nn.ModuleList()
@ -594,15 +672,21 @@ class Decoder(nn.Module):
block.append(resnet_op(in_channels=block_in, block.append(resnet_op(in_channels=block_in,
out_channels=block_out, out_channels=block_out,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout)) dropout=dropout,
conv_op=conv_op))
block_in = block_out block_in = block_out
if curr_res in attn_resolutions: if curr_res in attn_resolutions:
attn.append(attn_op(block_in)) attn.append(attn_op(block_in, conv_op=conv_op))
up = nn.Module() up = nn.Module()
up.block = block up.block = block
up.attn = attn up.attn = attn
if i_level != 0: if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv) scale_factor = 2.0
if time_compress is not None:
if i_level > math.log2(time_compress):
scale_factor = (1.0, 2.0, 2.0)
up.upsample = Upsample(block_in, resamp_with_conv, conv_op=conv_op, scale_factor=scale_factor)
curr_res = curr_res * 2 curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order self.up.insert(0, up) # prepend to get consistent order

View File

@ -9,12 +9,12 @@ import logging
from .util import ( from .util import (
checkpoint, checkpoint,
avg_pool_nd, avg_pool_nd,
zero_module,
timestep_embedding, timestep_embedding,
AlphaBlender, AlphaBlender,
) )
from ..attention import SpatialTransformer, SpatialVideoTransformer, default from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists from comfy.ldm.util import exists
import comfy.patcher_extension
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
@ -47,6 +47,15 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
elif isinstance(layer, Upsample): elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape) x = layer(x, output_shape=output_shape)
else: else:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue
x = layer(x) x = layer(x)
return x return x
@ -819,6 +828,13 @@ class UNetModel(nn.Module):
) )
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timesteps, context, y, control, transformer_options, **kwargs)
def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
""" """
Apply the model to an input batch. Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs. :param x: an [N x C x ...] Tensor of inputs.

View File

@ -4,7 +4,6 @@ import numpy as np
from functools import partial from functools import partial
from .util import extract_into_tensor, make_beta_schedule from .util import extract_into_tensor, make_beta_schedule
from comfy.ldm.util import default
class AbstractLowScaleModel(nn.Module): class AbstractLowScaleModel(nn.Module):

View File

@ -8,8 +8,8 @@
# thanks! # thanks!
import os
import math import math
import logging
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
@ -131,7 +131,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
# add one to get the final alpha values right (the ones from first scale to data during sampling) # add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1 steps_out = ddim_timesteps + 1
if verbose: if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}') logging.info(f'Selected timesteps for ddim sampler: {steps_out}')
return steps_out return steps_out
@ -143,8 +143,8 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
# according the the formula provided in https://arxiv.org/abs/2010.02502 # according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose: if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
print(f'For the chosen value of eta, which is {eta}, ' logging.info(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}') f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
return sigmas, alphas, alphas_prev return sigmas, alphas, alphas_prev

View File

@ -30,10 +30,10 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar) self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar) self.var = torch.exp(self.logvar)
if self.deterministic: if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
def sample(self): def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
return x return x
def kl(self, other=None): def kl(self, other=None):

View File

@ -17,12 +17,11 @@ import math
import logging import logging
try: try:
from typing import Optional, NamedTuple, List, Protocol from typing import Optional, NamedTuple, List, Protocol
except ImportError: except ImportError:
from typing import Optional, NamedTuple, List from typing import Optional, NamedTuple, List
from typing_extensions import Protocol from typing_extensions import Protocol
from torch import Tensor
from typing import List from typing import List
from comfy import model_management from comfy import model_management
@ -172,7 +171,7 @@ def _get_attention_scores_no_kv_chunking(
del attn_scores del attn_scores
except model_management.OOM_EXCEPTION: except model_management.OOM_EXCEPTION:
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores) torch.exp(attn_scores, out=attn_scores)
summed = torch.sum(attn_scores, dim=-1, keepdim=True) summed = torch.sum(attn_scores, dim=-1, keepdim=True)
attn_scores /= summed attn_scores /= summed
@ -234,6 +233,8 @@ def efficient_dot_product_attention(
def get_mask_chunk(chunk_idx: int) -> Tensor: def get_mask_chunk(chunk_idx: int) -> Tensor:
if mask is None: if mask is None:
return None return None
if mask.shape[1] == 1:
return mask
chunk = min(query_chunk_size, q_tokens) chunk = min(query_chunk_size, q_tokens)
return mask[:,chunk_idx:chunk_idx + chunk] return mask[:,chunk_idx:chunk_idx + chunk]
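The new shape[1] == 1 early return above covers masks whose query dimension is broadcastable: slicing such a mask per query chunk would yield an empty tensor for every chunk after the first. A minimal sketch of the shape behaviour, with hypothetical sizes:
import torch
broadcast_mask = torch.ones(2, 1, 77, dtype=torch.bool)      # (batch, 1, k_tokens), broadcasts over all queries
full_mask = torch.ones(2, 4096, 77, dtype=torch.bool)        # (batch, q_tokens, k_tokens)
chunk_idx, chunk = 1024, 1024
print(broadcast_mask[:, chunk_idx:chunk_idx + chunk].shape)  # torch.Size([2, 0, 77]) -- per-chunk slicing would silently drop the mask
print(full_mask[:, chunk_idx:chunk_idx + chunk].shape)       # torch.Size([2, 1024, 77]) -- per-chunk slicing is fine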

View File

@ -1,5 +1,5 @@
import functools import functools
from typing import Callable, Iterable, Union from typing import Iterable, Union
import torch import torch
from einops import rearrange, repeat from einops import rearrange, repeat
@ -194,6 +194,7 @@ def make_time_attn(
attn_kwargs=None, attn_kwargs=None,
alpha: float = 0, alpha: float = 0,
merge_strategy: str = "learned", merge_strategy: str = "learned",
conv_op=ops.Conv2d,
): ):
return partialclass( return partialclass(
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy

380
comfy/ldm/pixart/blocks.py Normal file
View File

@ -0,0 +1,380 @@
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding
from comfy.ldm.modules.attention import optimized_attention
# if model_management.xformers_enabled():
# import xformers.ops
# if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28:
# block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
# else:
# block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
class MultiHeadCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., dtype=None, device=None, operations=None, **kwargs):
super(MultiHeadCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.kv_linear = operations.Linear(d_model, d_model*2, dtype=dtype, device=device)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
# query/value: img tokens; key: condition; mask: if padding tokens
B, N, C = x.shape
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
assert mask is None # TODO?
# # TODO: xformers needs separate mask logic here
# if model_management.xformers_enabled():
# attn_bias = None
# if mask is not None:
# attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
# x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
# else:
# q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
# attn_mask = None
# mask = torch.ones(())
# if mask is not None and len(mask) > 1:
# # Create equivalent of xformer diagonal block mask, still only correct for square masks
# # But depth doesn't matter as tensors can expand in that dimension
# attn_mask_template = torch.ones(
# [q.shape[2] // B, mask[0]],
# dtype=torch.bool,
# device=q.device
# )
# attn_mask = torch.block_diag(attn_mask_template)
#
# # create a mask on the diagonal for each mask in the batch
# for _ in range(B - 1):
# attn_mask = torch.block_diag(attn_mask, attn_mask_template)
# x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentionKVCompress(nn.Module):
"""Multi-head Attention block with KV token compression and qk norm."""
def __init__(self, dim, num_heads=8, qkv_bias=True, sampling='conv', sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **kwargs):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
"""
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
self.sampling=sampling # ['conv', 'ave', 'uniform', 'uniform_every']
self.sr_ratio = sr_ratio
if sr_ratio > 1 and sampling == 'conv':
# Avg Conv Init.
self.sr = operations.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio, dtype=dtype, device=device)
# self.sr.weight.data.fill_(1/sr_ratio**2)
# self.sr.bias.data.zero_()
self.norm = operations.LayerNorm(dim, dtype=dtype, device=device)
if qk_norm:
self.q_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
self.k_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
if sampling is None or scale_factor == 1:
return tensor, N  # keep the (tensor, token_count) return shape consistent with callers and the branches below
B, N, C = tensor.shape
if sampling == 'uniform_every':
return tensor[:, ::scale_factor], int(N // scale_factor)
tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
new_H, new_W = int(H / scale_factor), int(W / scale_factor)
new_N = new_H * new_W
if sampling == 'ave':
tensor = F.interpolate(
tensor, scale_factor=1 / scale_factor, mode='nearest'
).permute(0, 2, 3, 1)
elif sampling == 'uniform':
tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
elif sampling == 'conv':
tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
tensor = self.norm(tensor)
else:
raise ValueError
return tensor.reshape(B, new_N, C).contiguous(), new_N
def forward(self, x, mask=None, HW=None, block_id=None):
B, N, C = x.shape # 2 4096 1152
new_N = N
if HW is None:
H = W = int(N ** 0.5)
else:
H, W = HW
qkv = self.qkv(x).reshape(B, N, 3, C)
q, k, v = qkv.unbind(2)
q = self.q_norm(q)
k = self.k_norm(k)
# KV compression
if self.sr_ratio > 1:
k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
q = q.reshape(B, N, self.num_heads, C // self.num_heads)
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
if mask is not None:
raise NotImplementedError("Attn mask logic not added for self attention")
# This is never called at the moment
# attn_bias = None
# if mask is not None:
# attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
# attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
# attention 2
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True)
x = x.view(B, N, C)
x = self.proj(x)
return x
class FinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class T2IFinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
self.out_channels = out_channels
def forward(self, x, t):
shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class MaskFinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class DecoderLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, decoder_hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.norm_decoder = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, decoder_hidden_size, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_decoder(x), shift, scale)
x = self.linear(x)
return x
class SizeEmbedder(TimestepEmbedder):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size, operations=operations)
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.outdim = hidden_size
def forward(self, s, bs):
if s.ndim == 1:
s = s[:, None]
assert s.ndim == 2
if s.shape[0] != bs:
s = s.repeat(bs//s.shape[0], 1)
assert s.shape[0] == bs
b, dims = s.shape[0], s.shape[1]
s = rearrange(s, "b d -> (b d)")
s_freq = timestep_embedding(s, self.frequency_embedding_size)
s_emb = self.mlp(s_freq.to(s.dtype))
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
return s_emb
class LabelEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, num_classes, hidden_size, dropout_prob, dtype=None, device=None, operations=None):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = operations.Embedding(num_classes + use_cfg_embedding, hidden_size, dtype=dtype, device=device)
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
else:
drop_ids = force_drop_ids == 1
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
class CaptionEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
super().__init__()
self.y_proj = Mlp(
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
dtype=dtype, device=device, operations=operations,
)
self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
self.uncond_prob = uncond_prob
def token_drop(self, caption, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return caption
def forward(self, caption, train, force_drop_ids=None):
if train:
assert caption.shape[2:] == self.y_embedding.shape
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
caption = self.token_drop(caption, force_drop_ids)
caption = self.y_proj(caption)
return caption
class CaptionEmbedderDoubleBr(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
super().__init__()
self.proj = Mlp(
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
dtype=dtype, device=device, operations=operations,
)
self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
self.uncond_prob = uncond_prob
def token_drop(self, global_caption, caption, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return global_caption, caption
def forward(self, caption, train, force_drop_ids=None):
assert caption.shape[2: ] == self.y_embedding.shape
global_caption = caption.mean(dim=2).squeeze()
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
y_embed = self.proj(global_caption)
return y_embed, caption

View File

@ -0,0 +1,256 @@
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
from .blocks import (
t2i_modulate,
CaptionEmbedder,
AttentionKVCompress,
MultiHeadCrossAttention,
T2IFinalLayer,
SizeEmbedder,
)
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
grid_h, grid_w = torch.meshgrid(
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
indexing='ij'
)
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
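A quick shape check for the helper above, using the imports of this module and hypothetical sizes (a 64x64 token grid at hidden size 1152, i.e. a PixArt 1024px latent with patch size 2):
pos = get_2d_sincos_pos_embed_torch(1152, w=64, h=64, pe_interpolation=2.0, base_size=64)
print(pos.shape)  # torch.Size([4096, 1152]) -- one row per token, matching the (H*W, D) comment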
class PixArtMSBlock(nn.Module):
"""
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
x = x + self.cross_attn(x, y, mask)
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
class PixArtMS(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
learn_sigma=True,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=None,
pe_precision=None,
config=None,
model_max_length=120,
micro_condition=True,
qk_norm=False,
kv_compress_config=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
nn.Module.__init__(self)
self.dtype = dtype
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.pe_interpolation = pe_interpolation
self.pe_precision = pe_precision
self.hidden_size = hidden_size
self.depth = depth
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size,
bias=True,
dtype=dtype,
device=device,
operations=operations
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations,
)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
act_layer=approx_gelu, token_num=model_max_length,
dtype=dtype, device=device, operations=operations,
)
self.micro_conditioning = micro_condition
if self.micro_conditioning:
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
# For fixed sin-cos embedding:
# num_patches = (input_size // patch_size) * (input_size // patch_size)
# self.base_size = input_size // self.patch_size
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
if kv_compress_config is None:
kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtMSBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
sampling=kv_compress_config['sampling'],
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
)
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
"""
Original forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) conditioning
ar: (N, 1): aspect ratio
cs: (N ,2) size conditioning for height/width
"""
B, C, H, W = x.shape
c_res = (H + W) // 2
pe_interpolation = self.pe_interpolation
if pe_interpolation is None or self.pe_precision is not None:
# calculate pe_interpolation on-the-fly
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
pos_embed = get_2d_sincos_pos_embed_torch(
self.hidden_size,
h=(H // self.patch_size),
w=(W // self.patch_size),
pe_interpolation=pe_interpolation,
base_size=((round(c_res / 64) * 64) // self.patch_size),
device=x.device,
dtype=x.dtype,
).unsqueeze(0)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep, x.dtype) # (N, D)
if self.micro_conditioning and (c_size is not None and c_ar is not None):
bs = x.shape[0]
c_size = self.csize_embedder(c_size, bs) # (N, D)
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
t = t + torch.cat([c_size, c_ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = None
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
B, C, H, W = x.shape
# Fallback for missing micro-conditioning inputs
if self.micro_conditioning:
if c_size is None:
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
if c_ar is None:
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
## Still accepts the input w/o that dim but returns garbage
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
## only return EPS
if self.pred_sigma:
return out[:, :self.in_channels]
return out
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = h // self.patch_size
w = w // self.patch_size
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
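A standalone sketch of the unpatchify reshape above with hypothetical sizes (one sample, a 64x64 token grid, patch p=2, c=8 output channels):
import torch
N, h, w, p, c = 1, 64, 64, 2, 8
x = torch.randn(N, h * w, p * p * c)                    # (N, T, p*p*C) as produced by the final layer
x = torch.einsum('nhwpqc->nchpwq', x.reshape(N, h, w, p, p, c))
print(x.reshape(N, c, h * p, w * p).shape)              # torch.Size([1, 8, 128, 128])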

View File

@ -1,4 +1,5 @@
import importlib import importlib
import logging
import torch import torch
from torch import optim from torch import optim
@ -23,7 +24,7 @@ def log_txt_as_img(wh, xc, size=10):
try: try:
draw.text((0, 0), lines, fill="black", font=font) draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError: except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.") logging.warning("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt) txts.append(txt)
@ -65,7 +66,7 @@ def mean_flat(tensor):
def count_params(model, verbose=False): def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters()) total_params = sum(p.numel() for p in model.parameters())
if verbose: if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") logging.info(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
return total_params return total_params
@ -133,7 +134,6 @@ class AdamWwithEMAandWings(optim.Optimizer):
exp_avgs = [] exp_avgs = []
exp_avg_sqs = [] exp_avg_sqs = []
ema_params_with_grad = [] ema_params_with_grad = []
state_sums = []
max_exp_avg_sqs = [] max_exp_avg_sqs = []
state_steps = [] state_steps = []
amsgrad = group['amsgrad'] amsgrad = group['amsgrad']

View File

@ -33,7 +33,7 @@ LORA_CLIP_MAP = {
} }
def load_lora(lora, to_load): def load_lora(lora, to_load, log_missing=True):
patch_dict = {} patch_dict = {}
loaded_keys = set() loaded_keys = set()
for x in to_load: for x in to_load:
@ -49,10 +49,20 @@ def load_lora(lora, to_load):
dora_scale = lora[dora_scale_name] dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name) loaded_keys.add(dora_scale_name)
reshape_name = "{}.reshape_weight".format(x)
reshape = None
if reshape_name in lora.keys():
try:
reshape = lora[reshape_name].tolist()
loaded_keys.add(reshape_name)
except:
pass
regular_lora = "{}.lora_up.weight".format(x) regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x) diffusers_lora = "{}_lora.up.weight".format(x)
diffusers2_lora = "{}.lora_B.weight".format(x) diffusers2_lora = "{}.lora_B.weight".format(x)
diffusers3_lora = "{}.lora.up.weight".format(x) diffusers3_lora = "{}.lora.up.weight".format(x)
mochi_lora = "{}.lora_B".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x) transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None A_name = None
@ -72,6 +82,10 @@ def load_lora(lora, to_load):
A_name = diffusers3_lora A_name = diffusers3_lora
B_name = "{}.lora.down.weight".format(x) B_name = "{}.lora.down.weight".format(x)
mid_name = None mid_name = None
elif mochi_lora in lora.keys():
A_name = mochi_lora
B_name = "{}.lora_A".format(x)
mid_name = None
elif transformers_lora in lora.keys(): elif transformers_lora in lora.keys():
A_name = transformers_lora A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x) B_name ="{}.lora_linear_layer.down.weight".format(x)
@ -82,7 +96,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys(): if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name] mid = lora[mid_name]
loaded_keys.add(mid_name) loaded_keys.add(mid_name)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale)) patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
loaded_keys.add(A_name) loaded_keys.add(A_name)
loaded_keys.add(B_name) loaded_keys.add(B_name)
@ -193,9 +207,16 @@ def load_lora(lora, to_load):
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,)) patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name) loaded_keys.add(diff_bias_name)
for x in lora.keys(): set_weight_name = "{}.set_weight".format(x)
if x not in loaded_keys: set_weight = lora.get(set_weight_name, None)
logging.warning("lora key not loaded: {}".format(x)) if set_weight is not None:
patch_dict[to_load[x]] = ("set", (set_weight,))
loaded_keys.add(set_weight_name)
if log_missing:
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
return patch_dict return patch_dict
@ -282,11 +303,14 @@ def model_lora_keys_unet(model, key_map={}):
sdk = sd.keys() sdk = sd.keys()
for k in sdk: for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): if k.startswith("diffusion_model."):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") if k.endswith(".weight"):
key_map["lora_unet_{}".format(key_lora)] = k key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config key_map["lora_unet_{}".format(key_lora)] = k
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
else:
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config) diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys: for k in diffusers_keys:
@ -320,7 +344,6 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
key_map[key_lora] = to key_map[key_lora] = to
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys: for k in diffusers_keys:
@ -329,6 +352,20 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
key_map[key_lora] = to key_map[key_lora] = to
if isinstance(model, comfy.model_base.PixArt):
diffusers_keys = comfy.utils.pixart_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #default format
key_map[key_lora] = to
key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #diffusers training script
key_map[key_lora] = to
key_lora = "unet.base_model.model.{}".format(k[:-len(".weight")]) #old reference peft script
key_map[key_lora] = to
if isinstance(model, comfy.model_base.HunyuanDiT): if isinstance(model, comfy.model_base.HunyuanDiT):
for k in sdk: for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): if k.startswith("diffusion_model.") and k.endswith(".weight"):
@ -344,6 +381,24 @@ def model_lora_keys_unet(model, key_map={}):
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
if isinstance(model, comfy.model_base.GenmoMochi):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
if isinstance(model, comfy.model_base.HunyuanVideo):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
# diffusion-pipe lora format
key_lora = k
key_lora = key_lora.replace("_mod.lin.", "_mod.linear.").replace("_attn.qkv.", "_attn_qkv.").replace("_attn.proj.", "_attn_proj.")
key_lora = key_lora.replace("mlp.0.", "mlp.fc1.").replace("mlp.2.", "mlp.fc2.")
key_lora = key_lora.replace(".modulation.lin.", ".modulation.linear.")
key_lora = key_lora[len("diffusion_model."):-len(".weight")]
key_map["transformer.{}".format(key_lora)] = k
key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
return key_map return key_map
@ -400,7 +455,7 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
return padded_tensor return padded_tensor
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches: for p in patches:
strength = p[0] strength = p[0]
v = p[1] v = p[1]
@ -440,10 +495,22 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape)) logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
else: else:
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype)) weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
elif patch_type == "set":
weight.copy_(v[0])
elif patch_type == "model_as_lora":
target_weight: torch.Tensor = v[0]
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype) mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
dora_scale = v[4] dora_scale = v[4]
reshape = v[5]
if reshape is not None:
weight = pad_tensor_to_shape(weight, reshape)
if v[2] is not None: if v[2] is not None:
alpha = v[2] / mat2.shape[0] alpha = v[2] / mat2.shape[0]
else: else:

17
comfy/lora_convert.py Normal file
View File

@ -0,0 +1,17 @@
import torch
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
return sd_out
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
return sd
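A minimal usage sketch (the file name is hypothetical, and the assumption is that conversion runs on the raw state dict before load_lora sees it):
import comfy.utils
import comfy.lora_convert
lora_sd = comfy.utils.load_torch_file("flux_control_lora.safetensors", safe_load=True)
lora_sd = comfy.lora_convert.convert_lora(lora_sd)  # remaps BFL control-lora keys; other loras pass through unchanged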

View File

@ -26,18 +26,25 @@ from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAug
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
import comfy.ldm.genmo.joint_model.asymm_models_joint import comfy.ldm.genmo.joint_model.asymm_models_joint
import comfy.ldm.aura.mmdit import comfy.ldm.aura.mmdit
import comfy.ldm.pixart.pixartms
import comfy.ldm.hydit.models import comfy.ldm.hydit.models
import comfy.ldm.audio.dit import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders import comfy.ldm.audio.embedders
import comfy.ldm.flux.model import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.model_management import comfy.model_management
import comfy.patcher_extension
import comfy.conds import comfy.conds
import comfy.ops import comfy.ops
from enum import Enum from enum import Enum
from . import utils from . import utils
import comfy.latent_formats import comfy.latent_formats
import math import math
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class ModelType(Enum): class ModelType(Enum):
EPS = 1 EPS = 1
@ -94,6 +101,7 @@ class BaseModel(torch.nn.Module):
self.model_config = model_config self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype self.manual_cast_dtype = model_config.manual_cast_dtype
self.device = device self.device = device
self.current_patcher: 'ModelPatcher' = None
if not unet_config.get("disable_unet_model_creation", False): if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None: if model_config.custom_operations is None:
@ -119,6 +127,13 @@ class BaseModel(torch.nn.Module):
self.memory_usage_factor = model_config.memory_usage_factor self.memory_usage_factor = model_config.memory_usage_factor
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._apply_model,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options)
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t sigma = t
xc = self.model_sampling.calculate_input(sigma, x) xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None: if c_concat is not None:
@ -153,8 +168,7 @@ class BaseModel(torch.nn.Module):
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
return None return None
def extra_conds(self, **kwargs): def concat_cond(self, **kwargs):
out = {}
if len(self.concat_keys) > 0: if len(self.concat_keys) > 0:
cond_concat = [] cond_concat = []
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
@ -193,7 +207,14 @@ class BaseModel(torch.nn.Module):
elif ck == "masked_image": elif ck == "masked_image":
cond_concat.append(self.blank_inpaint_image_like(noise)) cond_concat.append(self.blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1) data = torch.cat(cond_concat, dim=1)
out['c_concat'] = comfy.conds.CONDNoiseShape(data) return data
return None
def extra_conds(self, **kwargs):
out = {}
concat_cond = self.concat_cond(**kwargs)
if concat_cond is not None:
out['c_concat'] = comfy.conds.CONDNoiseShape(concat_cond)
adm = self.encode_adm(**kwargs) adm = self.encode_adm(**kwargs)
if adm is not None: if adm is not None:
@ -408,7 +429,6 @@ class SVD_img2vid(BaseModel):
latent_image = kwargs.get("concat_latent_image", None) latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None) noise = kwargs.get("noise", None)
device = kwargs["device"]
if latent_image is None: if latent_image is None:
latent_image = torch.zeros_like(noise) latent_image = torch.zeros_like(noise)
@ -523,9 +543,7 @@ class SD_X4Upscaler(BaseModel):
return out return out
class IP2P: class IP2P:
def extra_conds(self, **kwargs): def concat_cond(self, **kwargs):
out = {}
image = kwargs.get("concat_latent_image", None) image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None) noise = kwargs.get("noise", None)
device = kwargs["device"] device = kwargs["device"]
@ -537,18 +555,15 @@ class IP2P:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0]) image = utils.resize_to_batch_size(image, noise.shape[0])
return self.process_ip2p_image_in(image)
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)
return out
class SD15_instructpix2pix(IP2P, BaseModel): class SD15_instructpix2pix(IP2P, BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
self.process_ip2p_image_in = lambda image: image self.process_ip2p_image_in = lambda image: image
class SDXL_instructpix2pix(IP2P, SDXL): class SDXL_instructpix2pix(IP2P, SDXL):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
@ -673,6 +688,7 @@ class StableAudio1(BaseModel):
sd["{}{}".format(k, l)] = s[l] sd["{}{}".format(k, l)] = s[l]
return sd return sd
class HunyuanDiT(BaseModel): class HunyuanDiT(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None): def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
@ -697,18 +713,72 @@ class HunyuanDiT(BaseModel):
width = kwargs.get("width", 768) width = kwargs.get("width", 768)
height = kwargs.get("height", 768) height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0)
crop_h = kwargs.get("crop_h", 0)
target_width = kwargs.get("target_width", width) target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height) target_height = kwargs.get("target_height", height)
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out return out
class PixArt(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
width = kwargs.get("width", None)
height = kwargs.get("height", None)
if width is not None and height is not None:
out["c_size"] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width]]))
out["c_ar"] = comfy.conds.CONDRegular(torch.FloatTensor([[kwargs.get("aspect_ratio", height/width)]]))
return out
class Flux(BaseModel): class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None): def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
def concat_cond(self, **kwargs):
try:
#Handle Flux control loras dynamically changing the img_in weight.
num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
except:
#Some cases like tensorrt might not have the weights accessible
num_channels = self.model_config.unet_config["in_channels"]
out_channels = self.model_config.unet_config["out_channels"]
if num_channels <= out_channels:
return None
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if image is None:
image = torch.zeros_like(noise)
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0])
image = self.process_latent_in(image)
if num_channels <= out_channels * 2:
return image
#inpaint model
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.ones_like(noise)[:, :1]
mask = torch.mean(mask, dim=1, keepdim=True)
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((image, mask), dim=1)
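The view/permute/reshape line above packs an 8x-resolution single-channel mask into 64 latent-aligned channels for the Flux inpaint path; a standalone sketch with hypothetical sizes (batch 1, 128x128 latent):
import torch
mask = torch.ones(1, 1, 128 * 8, 128 * 8)
B, _, H8, W8 = mask.shape
packed = mask.view(B, H8 // 8, 8, W8 // 8, 8).permute(0, 2, 4, 1, 3).reshape(B, -1, H8 // 8, W8 // 8)
print(packed.shape)  # torch.Size([1, 64, 128, 128]) -- concatenated with the image latent along the channel dim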
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
return kwargs["pooled_output"] return kwargs["pooled_output"]
@ -717,6 +787,16 @@ class Flux(BaseModel):
cross_attn = kwargs.get("cross_attn", None) cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None: if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
# upscale the attention mask to match the image token grid
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
shape = kwargs["noise"].shape
mask_ref_size = kwargs["attention_mask_img_shape"]
# the model will pad to the patch size, and then divide
# essentially dividing and rounding up
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)])) out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
return out return out
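For the token-grid computation above, h_tok and w_tok are just the latent height/width divided by patch_size and rounded up; e.g. with a hypothetical latent of shape (1, 16, 90, 160) and patch_size 2:
import math
h_tok, w_tok = math.ceil(90 / 2), math.ceil(160 / 2)
print(h_tok, w_tok)  # 45 80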
@ -734,3 +814,45 @@ class GenmoMochi(BaseModel):
if cross_attn is not None: if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out return out
class LTXV(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guiding_latent = kwargs.get("guiding_latent", None)
if guiding_latent is not None:
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
if guiding_latent_noise_scale is not None:
out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
return out
class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
return out

View File

@ -133,10 +133,36 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["image_model"] = "hydit1" unet_config["image_model"] = "hydit1"
return unet_config return unet_config
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
dit_config["image_model"] = "hunyuan_video"
dit_config["in_channels"] = 16
dit_config["patch_size"] = [1, 2, 2]
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 256
dit_config["qkv_bias"] = True
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
dit_config = {} dit_config = {}
dit_config["image_model"] = "flux" dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16 dit_config["in_channels"] = 16
patch_size = 2
dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys:
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768 dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096 dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072 dit_config["hidden_size"] = 3072
@ -177,6 +203,41 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["rope_theta"] = 10000.0 dit_config["rope_theta"] = 10000.0
return dit_config return dit_config
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys and '{}pos_embed.proj.bias'.format(key_prefix) in state_dict_keys:
# PixArt diffusers
return None
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
return dit_config
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
patch_size = 2
dit_config = {}
dit_config["num_heads"] = 16
dit_config["patch_size"] = patch_size
dit_config["hidden_size"] = 1152
dit_config["in_channels"] = 4
dit_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
y_key = "{}y_embedder.y_embedding".format(key_prefix)
if y_key in state_dict_keys:
dit_config["model_max_length"] = state_dict[y_key].shape[0]
pe_key = "{}pos_embed".format(key_prefix)
if pe_key in state_dict_keys:
dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
if ar_key in state_dict_keys:
dit_config["image_model"] = "pixart_alpha"
dit_config["micro_condition"] = True
else:
dit_config["image_model"] = "pixart_sigma"
dit_config["micro_condition"] = False
return dit_config
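A worked example of the PixArt size heuristics above, with hypothetical checkpoint values:
import math
num_pos = 4096                               # state_dict["pos_embed"].shape[1]
input_size = int(math.sqrt(num_pos)) * 2     # 128 latent pixels, i.e. 1024px with the usual 8x VAE
pe_interpolation = input_size // (512 // 8)  # 2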
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None return None
@ -206,7 +267,6 @@ def detect_unet_config(state_dict, key_prefix):
num_res_blocks = [] num_res_blocks = []
channel_mult = [] channel_mult = []
attention_resolutions = []
transformer_depth = [] transformer_depth = []
transformer_depth_output = [] transformer_depth_output = []
context_dim = None context_dim = None
@ -321,8 +381,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
if model_config is None and use_base_if_no_match: if model_config is None and use_base_if_no_match:
model_config = comfy.supported_models_base.BASE(unet_config) model_config = comfy.supported_models_base.BASE(unet_config)
scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None) scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
if scaled_fp8_weight is not None: if scaled_fp8_key in state_dict:
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
model_config.scaled_fp8 = scaled_fp8_weight.dtype model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32: if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn model_config.scaled_fp8 = torch.float8_e4m3fn
@ -377,7 +438,6 @@ def convert_config(unet_config):
t_out += [d] * (res + 1) t_out += [d] * (res + 1)
s *= 2 s *= 2
transformer_depth = t_in transformer_depth = t_in
transformer_depth_output = t_out
new_config["transformer_depth"] = t_in new_config["transformer_depth"] = t_in
new_config["transformer_depth_output"] = t_out new_config["transformer_depth_output"] = t_out
new_config["transformer_depth_middle"] = transformer_depth_middle new_config["transformer_depth_middle"] = transformer_depth_middle
@ -540,7 +600,14 @@ def model_config_from_diffusers_unet(state_dict):
def convert_diffusers_mmdit(state_dict, output_prefix=""): def convert_diffusers_mmdit(state_dict, output_prefix=""):
out_sd = {} out_sd = {}
if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux if 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
elif 'x_embedder.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
hidden_size = state_dict["x_embedder.bias"].shape[0] hidden_size = state_dict["x_embedder.bias"].shape[0]
@ -549,10 +616,6 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64 depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix) sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
else: else:
return None return None

View File

@ -23,6 +23,8 @@ from comfy.cli_args import args
import torch import torch
import sys import sys
import platform import platform
import weakref
import gc
class VRAMState(Enum): class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram DISABLED = 0 #No vram present: no need to move models to vram
@ -73,7 +75,7 @@ if args.directml is not None:
try: try:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
_ = torch.xpu.device_count() _ = torch.xpu.device_count()
xpu_available = torch.xpu.is_available() xpu_available = xpu_available or torch.xpu.is_available()
except: except:
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available()) xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
@ -84,6 +86,13 @@ try:
except: except:
pass pass
try:
import torch_npu # noqa: F401
_ = torch.npu.device_count()
npu_available = torch.npu.is_available()
except:
npu_available = False
if args.cpu: if args.cpu:
cpu_state = CPUState.CPU cpu_state = CPUState.CPU
@ -95,6 +104,12 @@ def is_intel_xpu():
return True return True
return False return False
def is_ascend_npu():
global npu_available
if npu_available:
return True
return False
def get_torch_device(): def get_torch_device():
global directml_enabled global directml_enabled
global cpu_state global cpu_state
@ -108,6 +123,8 @@ def get_torch_device():
else: else:
if is_intel_xpu(): if is_intel_xpu():
return torch.device("xpu", torch.xpu.current_device()) return torch.device("xpu", torch.xpu.current_device())
elif is_ascend_npu():
return torch.device("npu", torch.npu.current_device())
else: else:
return torch.device(torch.cuda.current_device()) return torch.device(torch.cuda.current_device())
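The device-selection changes above follow a probe-then-dispatch pattern: each backend is detected behind a guard because its extension module may not be installed at all. A rough standalone sketch of that pattern (simplified; the real function also handles DirectML, MPS and the cpu flag):

import torch

def pick_device():
    # Prefer CUDA, then Intel XPU, then Ascend NPU, falling back to CPU.
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu", torch.xpu.current_device())
    try:
        import torch_npu  # noqa: F401  (importing registers the "npu" backend)
        if torch.npu.is_available():
            return torch.device("npu", torch.npu.current_device())
    except ImportError:
        pass
    return torch.device("cpu")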
@ -128,6 +145,12 @@ def get_total_memory(dev=None, torch_total_too=False):
mem_reserved = stats['reserved_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current']
mem_total_torch = mem_reserved mem_total_torch = mem_reserved
mem_total = torch.xpu.get_device_properties(dev).total_memory mem_total = torch.xpu.get_device_properties(dev).total_memory
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
_, mem_total_npu = torch.npu.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_npu
else: else:
stats = torch.cuda.memory_stats(dev) stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current']
@ -186,38 +209,44 @@ def is_nvidia():
return True return True
return False return False
def is_amd():
global cpu_state
if cpu_state == CPUState.GPU:
if torch.version.hip:
return True
return False
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
MIN_WEIGHT_MEMORY_RATIO = 0.2
 ENABLE_PYTORCH_ATTENTION = False
 if args.use_pytorch_cross_attention:
     ENABLE_PYTORCH_ATTENTION = True
     XFORMERS_IS_AVAILABLE = False

-VAE_DTYPES = [torch.float32]
-
 try:
     if is_nvidia():
         if int(torch_version[0]) >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
-            if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
-                VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
-    if is_intel_xpu():
+    if is_intel_xpu() or is_ascend_npu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:
     pass

-if is_intel_xpu():
-    VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
-
-if args.cpu_vae:
-    VAE_DTYPES = [torch.float32]
if ENABLE_PYTORCH_ATTENTION: if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True)
try:
if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except:
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
if args.lowvram: if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM set_vram_to = VRAMState.LOW_VRAM
lowvram_available = True lowvram_available = True
@ -266,6 +295,8 @@ def get_torch_device_name(device):
return "{}".format(device.type) return "{}".format(device.type)
elif is_intel_xpu(): elif is_intel_xpu():
return "{} {}".format(device, torch.xpu.get_device_name(device)) return "{} {}".format(device, torch.xpu.get_device_name(device))
elif is_ascend_npu():
return "{} {}".format(device, torch.npu.get_device_name(device))
else: else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -287,15 +318,34 @@ def module_size(module):
 class LoadedModel:
     def __init__(self, model):
-        self.model = model
+        self._set_model(model)
         self.device = model.load_device
-        self.weights_loaded = False
         self.real_model = None
         self.currently_used = True
+        self.model_finalizer = None
+        self._patcher_finalizer = None
def _set_model(self, model):
self._model = weakref.ref(model)
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
def _switch_parent(self):
model = self._parent_model()
if model is not None:
self._set_model(model)
@property
def model(self):
return self._model()
     def model_memory(self):
         return self.model.model_size()

+    def model_loaded_memory(self):
+        return self.model.loaded_size()
+
     def model_offloaded_memory(self):
         return self.model.model_size() - self.model.loaded_size()
@@ -306,32 +356,23 @@ class LoadedModel:
         return self.model_memory()

     def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
-        patch_model_to = self.device
-
         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())

-        load_weights = not self.weights_loaded
-
-        if self.model.loaded_size() > 0:
-            use_more_vram = lowvram_model_memory
-            if use_more_vram == 0:
-                use_more_vram = 1e32
-            self.model_use_more_vram(use_more_vram)
-        else:
-            try:
-                self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
-            except Exception as e:
-                self.model.unpatch_model(self.model.offload_device)
-                self.model_unload()
-                raise e
-
-        if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
+        # if self.model.loaded_size() > 0:
+        use_more_vram = lowvram_model_memory
+        if use_more_vram == 0:
+            use_more_vram = 1e32
+        self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
+        real_model = self.model.model
+
+        if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
             with torch.no_grad():
-                self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
+                real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)

-        self.weights_loaded = True
-        return self.real_model
+        self.real_model = weakref.ref(real_model)
+        self.model_finalizer = weakref.finalize(real_model, cleanup_models)
+        return real_model
     def should_reload_model(self, force_patch_weights=False):
         if force_patch_weights and self.model.lowvram_patch_counter() > 0:
@@ -344,18 +385,26 @@ class LoadedModel:
                 freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
                 if freed >= memory_to_free:
                     return False
-        self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
-        self.model.model_patches_to(self.model.offload_device)
-        self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.model.detach(unpatch_weights)
+        self.model_finalizer.detach()
+        self.model_finalizer = None
         self.real_model = None
         return True

-    def model_use_more_vram(self, extra_memory):
-        return self.model.partially_load(self.device, extra_memory)
+    def model_use_more_vram(self, extra_memory, force_patch_weights=False):
+        return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)

     def __eq__(self, other):
         return self.model is other.model
def __del__(self):
if self._patcher_finalizer is not None:
self._patcher_finalizer.detach()
def is_dead(self):
return self.real_model() is not None and self.model is None
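The class now holds the patcher and the real torch module only through weak references, with weakref.finalize hooks to react once they are garbage collected (which is what is_dead checks). A tiny self-contained illustration of that mechanism, separate from the actual ModelPatcher machinery:

import weakref

class Holder:
    def __init__(self, obj, on_collect):
        self._ref = weakref.ref(obj)                   # does not keep obj alive
        self._fin = weakref.finalize(obj, on_collect)  # fires once obj is collected

    @property
    def obj(self):
        return self._ref()  # None after the target has been collected

class Payload:
    pass

p = Payload()
h = Holder(p, lambda: print("payload collected"))
del p           # drop the last strong reference; CPython collects it immediately
print(h.obj)    # -> None (and "payload collected" was printed by the finalizer)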
def use_more_memory(extra_memory, loaded_models, device): def use_more_memory(extra_memory, loaded_models, device):
for m in loaded_models: for m in loaded_models:
if m.device == device: if m.device == device:
@ -386,38 +435,8 @@ def extra_reserved_memory():
def minimum_inference_memory(): def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
to_unload = []
for i in range(len(current_loaded_models)):
if model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
if len(to_unload) == 0:
return True
same_weights = 0
for i in to_unload:
if model.clone_has_same_weights(current_loaded_models[i].model):
same_weights += 1
if same_weights == len(to_unload):
unload_weight = False
else:
unload_weight = True
if not force_unload:
if unload_weights_only and unload_weight == False:
return None
else:
unload_weight = True
for i in to_unload:
logging.debug("unload clone {} {}".format(i, unload_weight))
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
return unload_weight
def free_memory(memory_required, device, keep_loaded=[]): def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
unloaded_model = [] unloaded_model = []
can_unload = [] can_unload = []
unloaded_models = [] unloaded_models = []
@ -425,7 +444,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
     for i in range(len(current_loaded_models) -1, -1, -1):
         shift_model = current_loaded_models[i]
         if shift_model.device == device:
-            if shift_model not in keep_loaded:
+            if shift_model not in keep_loaded and not shift_model.is_dead():
                 can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
                 shift_model.currently_used = False
@ -454,6 +473,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
return unloaded_models return unloaded_models
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state global vram_state
inference_memory = minimum_inference_memory() inference_memory = minimum_inference_memory()
@ -466,11 +486,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
models = set(models) models = set(models)
models_to_load = [] models_to_load = []
models_already_loaded = []
for x in models: for x in models:
loaded_model = LoadedModel(x) loaded_model = LoadedModel(x)
loaded = None
try: try:
loaded_model_index = current_loaded_models.index(loaded_model) loaded_model_index = current_loaded_models.index(loaded_model)
except: except:
@ -478,51 +496,35 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         if loaded_model_index is not None:
             loaded = current_loaded_models[loaded_model_index]
-            if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
-                current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
-                loaded = None
-            else:
-                loaded.currently_used = True
-                models_already_loaded.append(loaded)
-
-        if loaded is None:
+            loaded.currently_used = True
+            models_to_load.append(loaded)
+        else:
             if hasattr(x, "model"):
                 logging.info(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)

-    if len(models_to_load) == 0:
-        devs = set(map(lambda a: a.device, models_already_loaded))
-        for d in devs:
-            if d != torch.device("cpu"):
-                free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
-                free_mem = get_free_memory(d)
-                if free_mem < minimum_memory_required:
-                    logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
-                    models_to_load = free_memory(minimum_memory_required, d)
-                    logging.info("{} models unloaded.".format(len(models_to_load)))
-                else:
-                    use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
-        if len(models_to_load) == 0:
-            return
+    for loaded_model in models_to_load:
+        to_unload = []
+        for i in range(len(current_loaded_models)):
+            if loaded_model.model.is_clone(current_loaded_models[i].model):
+                to_unload = [i] + to_unload
+        for i in to_unload:
+            current_loaded_models.pop(i).model.detach(unpatch_all=False)
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

-    for loaded_model in models_already_loaded:
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
-
-    for loaded_model in models_to_load:
-        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
-        if weights_unloaded is not None:
-            loaded_model.weights_loaded = not weights_unloaded
-
     for device in total_memory_required:
         if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
+            free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
+
+    for device in total_memory_required:
+        if device != torch.device("cpu"):
+            free_mem = get_free_memory(device)
+            if free_mem < minimum_memory_required:
+                models_l = free_memory(minimum_memory_required, device)
+                logging.info("{} models unloaded.".format(len(models_l)))

     for loaded_model in models_to_load:
model = loaded_model.model model = loaded_model.model
@@ -534,27 +536,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         lowvram_model_memory = 0
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
             model_size = loaded_model.model_memory_required(torch_dev)
-            current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
+            loaded_memory = loaded_model.model_loaded_memory()
+            current_free_mem = get_free_memory(torch_dev) + loaded_memory
+
+            lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
+            lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
             if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
                 lowvram_model_memory = 0

         if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 64 * 1024 * 1024
+            lowvram_model_memory = 0.1

-        cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
+        loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
         current_loaded_models.insert(0, loaded_model)

-    devs = set(map(lambda a: a.device, models_already_loaded))
-    for d in devs:
-        if d != torch.device("cpu"):
-            free_mem = get_free_memory(d)
-            if free_mem > minimum_memory_required:
-                use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
     return
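As a concrete reading of the budget above: with roughly 8 GiB free, a 1 GiB minimum_memory_required, a 1.2 GiB inference reserve and MIN_WEIGHT_MEMORY_RATIO = 0.4, the weight budget is max(64 MiB, 7 GiB, min(3.2 GiB, 6.8 GiB)) = 7 GiB, so any model smaller than that loads fully and lowvram mode is skipped. A throwaway check of the arithmetic (numbers are illustrative only):

GiB = 1024 ** 3

def weight_budget(free_mem, minimum_required, inference_reserve, ratio=0.4):
    # Mirrors the max()/min() expression used for lowvram_model_memory above.
    return max(64 * 1024 * 1024,
               free_mem - minimum_required,
               min(free_mem * ratio, free_mem - inference_reserve))

print(weight_budget(8 * GiB, 1 * GiB, 1.2 * GiB) / GiB)  # -> 7.0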
def load_model_gpu(model): def load_model_gpu(model):
return load_models_gpu([model]) return load_models_gpu([model])
@ -568,21 +564,35 @@ def loaded_models(only_currently_used=False):
output.append(m.model) output.append(m.model)
return output return output
def cleanup_models(keep_clone_weights_loaded=False):
def cleanup_models_gc():
do_gc = False
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.is_dead():
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
do_gc = True
break
if do_gc:
gc.collect()
soft_empty_cache()
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.is_dead():
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
def cleanup_models():
to_delete = [] to_delete = []
for i in range(len(current_loaded_models)): for i in range(len(current_loaded_models)):
#TODO: very fragile function needs improvement if current_loaded_models[i].real_model() is None:
num_refs = sys.getrefcount(current_loaded_models[i].model) to_delete = [i] + to_delete
if num_refs <= 2:
if not keep_clone_weights_loaded:
to_delete = [i] + to_delete
#TODO: find a less fragile way to do this.
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
to_delete = [i] + to_delete
for i in to_delete: for i in to_delete:
x = current_loaded_models.pop(i) x = current_loaded_models.pop(i)
x.model_unload()
del x del x
def dtype_size(dtype): def dtype_size(dtype):
@ -606,7 +616,7 @@ def unet_offload_device():
def unet_inital_load_device(parameters, dtype): def unet_inital_load_device(parameters, dtype):
torch_dev = get_torch_device() torch_dev = get_torch_device()
-    if vram_state == VRAMState.HIGH_VRAM:
+    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
return torch_dev return torch_dev
cpu_dev = torch.device("cpu") cpu_dev = torch.device("cpu")
@ -628,6 +638,10 @@ def maximum_vram_for_weights(device=None):
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if model_params < 0: if model_params < 0:
model_params = 1000000000000000000000 model_params = 1000000000000000000000
if args.fp32_unet:
return torch.float32
if args.fp64_unet:
return torch.float64
if args.bf16_unet: if args.bf16_unet:
return torch.bfloat16 return torch.bfloat16
if args.fp16_unet: if args.fp16_unet:
@ -674,7 +688,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
# None means no manual cast # None means no manual cast
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
-    if weight_dtype == torch.float32:
+    if weight_dtype == torch.float32 or weight_dtype == torch.float64:
return None return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False) fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
@ -716,7 +730,7 @@ def text_encoder_initial_device(load_device, offload_device, model_size=0):
return offload_device return offload_device
     if is_device_mps(load_device):
-        return offload_device
+        return load_device
mem_l = get_free_memory(load_device) mem_l = get_free_memory(load_device)
mem_o = get_free_memory(offload_device) mem_o = get_free_memory(offload_device)
@ -759,7 +773,6 @@ def vae_offload_device():
return torch.device("cpu") return torch.device("cpu")
 def vae_dtype(device=None, allowed_dtypes=[]):
-    global VAE_DTYPES
     if args.fp16_vae:
         return torch.float16
     elif args.bf16_vae:
@@ -768,12 +781,14 @@ def vae_dtype(device=None, allowed_dtypes=[]):
         return torch.float32

     for d in allowed_dtypes:
-        if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
+        if d == torch.float16 and should_use_fp16(device):
             return d
-        if d in VAE_DTYPES:
+
+        # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
+        if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
             return d

-    return VAE_DTYPES[0]
+    return torch.float32
def get_autocast_device(dev): def get_autocast_device(dev):
if hasattr(dev, 'type'): if hasattr(dev, 'type'):
@ -858,6 +873,8 @@ def cast_to_device(tensor, device, dtype, copy=False):
non_blocking = device_supports_non_blocking(device) non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
def sage_attention_enabled():
return args.use_sage_attention
def xformers_enabled(): def xformers_enabled():
global directml_enabled global directml_enabled
@ -866,6 +883,8 @@ def xformers_enabled():
return False return False
if is_intel_xpu(): if is_intel_xpu():
return False return False
if is_ascend_npu():
return False
if directml_enabled: if directml_enabled:
return False return False
return XFORMERS_IS_AVAILABLE return XFORMERS_IS_AVAILABLE
@ -890,16 +909,23 @@ def pytorch_attention_flash_attention():
return True return True
if is_intel_xpu(): if is_intel_xpu():
return True return True
if is_ascend_npu():
return True
return False return False
def mac_version():
try:
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
except:
return None
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
-    try:
-        macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
-        if (14, 5) <= macos_version <= (15, 0, 1): # black image bug on recent versions of macOS
-            upcast = True
-    except:
-        pass
+
+    macos_version = mac_version()
+    if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
+        upcast = True
+
if upcast: if upcast:
return torch.float32 return torch.float32
else: else:
@ -924,6 +950,13 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_torch = mem_reserved - mem_active mem_free_torch = mem_reserved - mem_active
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
mem_free_total = mem_free_xpu + mem_free_torch mem_free_total = mem_free_xpu + mem_free_torch
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_npu, _ = torch.npu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_npu + mem_free_torch
else: else:
stats = torch.cuda.memory_stats(dev) stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current'] mem_active = stats['active_bytes.all.current']
@@ -970,17 +1003,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if FORCE_FP16:
         return True

-    if device is not None:
-        if is_device_mps(device):
-            return True
-
     if FORCE_FP32:
         return False

     if directml_enabled:
         return False

-    if mps_mode():
+    if (device is not None and is_device_mps(device)) or mps_mode():
         return True
if cpu_mode(): if cpu_mode():
@ -989,6 +1018,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_intel_xpu(): if is_intel_xpu():
return True return True
if is_ascend_npu():
return True
if torch.version.hip: if torch.version.hip:
return True return True
@ -1029,17 +1061,15 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
         return False

-    if device is not None:
-        if is_device_mps(device):
-            return True
-
     if FORCE_FP32:
         return False

     if directml_enabled:
         return False

-    if mps_mode():
+    if (device is not None and is_device_mps(device)) or mps_mode():
+        if mac_version() < (14,):
+            return False
         return True
if cpu_mode(): if cpu_mode():
@ -1088,19 +1118,16 @@ def soft_empty_cache(force=False):
         torch.mps.empty_cache()
     elif is_intel_xpu():
         torch.xpu.empty_cache()
+    elif is_ascend_npu():
+        torch.npu.empty_cache()
     elif torch.cuda.is_available():
-        if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()

 def unload_all_models():
     free_memory(1e30, get_torch_device())

-def resolve_lowvram_weight(weight, model, key): #TODO: remove
-    print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
-    return weight
#TODO: might be cleaner to put this somewhere else #TODO: might be cleaner to put this somewhere else
import threading import threading

File diff suppressed because it is too large


@ -2,6 +2,25 @@ import torch
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math import math
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
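For reference, this is the zero-terminal-SNR rescale from Lin et al., "Common Diffusion Noise Schedules and Sample Steps are Flawed": work in terms of sqrt(alpha_bar), shift so the final step reaches zero, rescale so the first step is unchanged, then map back to sigmas (with the final alpha_bar clamped to a tiny positive value so the last sigma stays finite). In the sigma parameterization used here:

\bar\alpha_t = \frac{1}{\sigma_t^2 + 1}, \qquad
\sqrt{\bar\alpha_t} \;\leftarrow\; \bigl(\sqrt{\bar\alpha_t} - \sqrt{\bar\alpha_T}\bigr)\cdot\frac{\sqrt{\bar\alpha_0}}{\sqrt{\bar\alpha_0} - \sqrt{\bar\alpha_T}}, \qquad
\sigma_t = \sqrt{\frac{1 - \bar\alpha_t}{\bar\alpha_t}}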
class EPS: class EPS:
def calculate_input(self, sigma, noise): def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
@ -48,7 +67,7 @@ class CONST:
return latent / (1.0 - sigma) return latent / (1.0 - sigma)
 class ModelSamplingDiscrete(torch.nn.Module):
-    def __init__(self, model_config=None):
+    def __init__(self, model_config=None, zsnr=None):
         super().__init__()

         if model_config is not None:
@@ -61,11 +80,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
             linear_end = sampling_settings.get("linear_end", 0.012)
             timesteps = sampling_settings.get("timesteps", 1000)

-        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
+        if zsnr is None:
+            zsnr = sampling_settings.get("zsnr", False)
+
+        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr)
         self.sigma_data = 1.0

     def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
-                            linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+                            linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False):
if given_betas is not None: if given_betas is not None:
betas = given_betas betas = given_betas
else: else:
@ -83,6 +105,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
         sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+        if zsnr:
+            sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
+
         self.set_sigmas(sigmas)
def set_sigmas(self, sigmas): def set_sigmas(self, sigmas):
@ -218,7 +243,7 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
return 1.0 return 1.0
if percent >= 1.0: if percent >= 1.0:
return 0.0 return 0.0
-        return 1.0 - percent
+        return time_snr_shift(self.shift, 1.0 - percent)
class StableCascadeSampling(ModelSamplingDiscrete): class StableCascadeSampling(ModelSamplingDiscrete):
def __init__(self, model_config=None): def __init__(self, model_config=None):
@ -311,4 +336,4 @@ class ModelSamplingFlux(torch.nn.Module):
return 1.0 return 1.0
if percent >= 1.0: if percent >= 1.0:
return 0.0 return 0.0
-        return 1.0 - percent
+        return flux_time_shift(self.shift, 1.0, 1.0 - percent)


@@ -255,9 +255,10 @@ def fp8_linear(self, input):
         tensor_2d = True
         input = input.unsqueeze(1)

+    input_shape = input.shape
+    input_dtype = input.dtype
     if len(input.shape) == 3:
-        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
+        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
         w = w.t()

         scale_weight = self.scale_weight
@@ -269,23 +270,24 @@ def fp8_linear(self, input):
             if scale_input is None:
                 scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-                inn = input.reshape(-1, input.shape[2]).to(dtype)
+                input = torch.clamp(input, min=-448, max=448, out=input)
+                input = input.reshape(-1, input_shape[2]).to(dtype)
             else:
                 scale_input = scale_input.to(input.device)
-                inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
+                input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)

         if bias is not None:
-            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
+            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
         else:
-            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
+            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)

         if isinstance(o, tuple):
             o = o[0]

         if tensor_2d:
-            return o.reshape(input.shape[0], -1)
-        return o.reshape((-1, input.shape[1], self.weight.shape[0]))
+            return o.reshape(input_shape[0], -1)
+        return o.reshape((-1, input_shape[1], self.weight.shape[0]))

     return None
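The core of the path above is torch._scaled_mm, which multiplies fp8 operands with explicit per-tensor scales. A hedged sketch of that call shape, following the same calls used in the hunk (this is a private PyTorch API whose behavior varies between versions, it needs fp8-capable hardware, and the helper name here is made up):

import torch

def scaled_mm_sketch(x, weight, dtype=torch.float8_e4m3fn):
    # x: (tokens, in_features) in fp16/bf16; weight: (out_features, in_features)
    scale_x = torch.ones((), device=x.device, dtype=torch.float32)
    scale_w = torch.ones((), device=x.device, dtype=torch.float32)
    x_fp8 = torch.clamp(x, min=-448, max=448).to(dtype)  # 448 is the e4m3 max magnitude
    w_fp8 = weight.to(dtype).t()                          # second operand is transposed, as above
    out = torch._scaled_mm(x_fp8, w_fp8, out_dtype=x.dtype, scale_a=scale_x, scale_b=scale_w)
    return out[0] if isinstance(out, tuple) else out      # some versions return (out, amax)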

comfy/patcher_extension.py (new file, 156 lines)

@ -0,0 +1,156 @@
from __future__ import annotations
from typing import Callable
class CallbacksMP:
ON_CLONE = "on_clone"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"
ON_PRE_RUN = "on_pre_run"
ON_PREPARE_STATE = "on_prepare_state"
ON_APPLY_HOOKS = "on_apply_hooks"
ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
ON_INJECT_MODEL = "on_inject_model"
ON_EJECT_MODEL = "on_eject_model"
# callbacks dict is in the format:
# {"call_type": {"key": [Callable1, Callable2, ...]} }
@classmethod
def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
return {}
def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
add_callback_with_key(call_type, None, callback, transformer_options, is_model_options)
def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.setdefault("transformer_options", {})
callbacks: dict[str, dict[str, list]] = transformer_options.setdefault("callbacks", {})
c = callbacks.setdefault(call_type, {}).setdefault(key, [])
c.append(callback)
def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
c_list = []
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
c_list.extend(callbacks.get(call_type, {}).get(key, []))
return c_list
def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
c_list = []
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
for c in callbacks.get(call_type, {}).values():
c_list.extend(c)
return c_list
class WrappersMP:
OUTER_SAMPLE = "outer_sample"
SAMPLER_SAMPLE = "sampler_sample"
CALC_COND_BATCH = "calc_cond_batch"
APPLY_MODEL = "apply_model"
DIFFUSION_MODEL = "diffusion_model"
# wrappers dict is in the format:
# {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
@classmethod
def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
return {}
def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options)
def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.setdefault("transformer_options", {})
wrappers: dict[str, dict[str, list]] = transformer_options.setdefault("wrappers", {})
w = wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
w.append(wrapper)
def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
w_list = []
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
w_list.extend(wrappers.get(wrapper_type, {}).get(key, []))
return w_list
def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
w_list = []
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
for w in wrappers.get(wrapper_type, {}).values():
w_list.extend(w)
return w_list
class WrapperExecutor:
"""Handles call stack of wrappers around a function in an ordered manner."""
def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
# NOTE: class_obj exists so that wrappers surrounding a class method can access
# the class instance at runtime via executor.class_obj
self.original = original
self.class_obj = class_obj
self.wrappers = wrappers.copy()
self.idx = idx
self.is_last = idx == len(wrappers)
def __call__(self, *args, **kwargs):
"""Calls the next wrapper or original function, whichever is appropriate."""
new_executor = self._create_next_executor()
return new_executor.execute(*args, **kwargs)
def execute(self, *args, **kwargs):
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
args = list(args)
kwargs = dict(kwargs)
if self.is_last:
return self.original(*args, **kwargs)
return self.wrappers[self.idx](self, *args, **kwargs)
def _create_next_executor(self) -> 'WrapperExecutor':
new_idx = self.idx + 1
if new_idx > len(self.wrappers):
raise Exception("Wrapper idx exceeded available wrappers; something went very wrong.")
if self.class_obj is None:
return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)
@classmethod
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
@classmethod
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
return cls(original, class_obj, wrappers, idx=idx)
class PatcherInjection:
def __init__(self, inject: Callable, eject: Callable):
self.inject = inject
self.eject = eject
def copy_nested_dicts(input_dict: dict):
new_dict = input_dict.copy()
for key, value in input_dict.items():
if isinstance(value, dict):
new_dict[key] = copy_nested_dicts(value)
elif isinstance(value, list):
new_dict[key] = value.copy()
return new_dict
def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
if copy_dict1:
merged_dict = copy_nested_dicts(dict1)
else:
merged_dict = dict1
for key, value in dict2.items():
if isinstance(value, dict):
curr_value = merged_dict.setdefault(key, {})
merged_dict[key] = merge_nested_dicts(value, curr_value)
elif isinstance(value, list):
merged_dict.setdefault(key, []).extend(value)
else:
merged_dict[key] = value
return merged_dict
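A quick orientation for the new module: wrappers registered under a wrapper type are chained by WrapperExecutor, and each wrapper receives the executor for the next layer as its first argument. A small usage sketch against the functions above (the "my_extension" key and the sample stand-in are made up):

import comfy.patcher_extension as pe

def timing_wrapper(executor, *args, **kwargs):
    # Runs around the next wrapper / the original callable.
    print("before")
    out = executor(*args, **kwargs)
    print("after")
    return out

model_options = {}
pe.add_wrapper_with_key(pe.WrappersMP.OUTER_SAMPLE, "my_extension",
                        timing_wrapper, model_options, is_model_options=True)

def sample(x):   # stand-in for the callable being wrapped
    return x * 2

wrappers = pe.get_all_wrappers(pe.WrappersMP.OUTER_SAMPLE, model_options, is_model_options=True)
print(pe.WrapperExecutor.new_executor(sample, wrappers).execute(3))  # prints before, after, then 6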


@ -25,9 +25,11 @@ def prepare_noise(latent_image, seed, noise_inds=None):
return noises return noises
 def fix_empty_latent_channels(model, latent_image):
-    latent_channels = model.get_model_object("latent_format").latent_channels #Resize the empty latent image so it has the right number of channels
-    if latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
-        latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_channels, dim=1)
+    latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
+    if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
+        latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
+
+    if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
+        latent_image = latent_image.unsqueeze(2)
+
     return latent_image
def prepare_sampling(model, noise_shape, positive, negative, noise_mask): def prepare_sampling(model, noise_shape, positive, negative, noise_mask):


@ -1,22 +1,60 @@
import torch from __future__ import annotations
import uuid
import comfy.model_management import comfy.model_management
import comfy.conds import comfy.conds
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
 def prepare_mask(noise_mask, shape, device):
-    """ensures noise mask is of proper dimensions"""
-    noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
-    noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
-    noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
-    noise_mask = noise_mask.to(device)
-    return noise_mask
+    return comfy.utils.reshape_mask(noise_mask, shape).to(device)
 def get_models_from_cond(cond, model_type):
     models = []
     for c in cond:
         if model_type in c:
-            models += [c[model_type]]
+            if isinstance(c[model_type], list):
+                models += c[model_type]
+            else:
+                models += [c[model_type]]
     return models
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
cnets: list[ControlBase] = []
for c in cond:
if 'hooks' in c:
for hook in c['hooks'].hooks:
hook: comfy.hooks.Hook
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
if 'control' in c:
cnets.append(c['control'])
def get_extra_hooks_from_cnet(cnet: ControlBase, _list: list):
if cnet.extra_hooks is not None:
_list.append(cnet.extra_hooks)
if cnet.previous_controlnet is None:
return _list
return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list)
hooks_list = []
cnets = set(cnets)
for base_cnet in cnets:
get_extra_hooks_from_cnet(base_cnet, hooks_list)
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
if extra_hooks is not None:
for hook in extra_hooks.hooks:
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
return hooks_dict
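A side note on the hooks_dict shape used above: the values are always None; the dict is used purely as an insertion-ordered, de-duplicated set keyed by hook (dicts preserve insertion order since Python 3.7, plain sets do not guarantee it). For example:

ordered = {}
for item in ["a", "b", "a", "c"]:
    ordered[item] = None        # duplicates collapse, first-seen order is kept
print(list(ordered))            # -> ['a', 'b', 'c']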
def convert_cond(cond): def convert_cond(cond):
out = [] out = []
for c in cond: for c in cond:
@ -26,17 +64,22 @@ def convert_cond(cond):
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0] temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds temp["model_conds"] = model_conds
temp["uuid"] = uuid.uuid4()
out.append(temp) out.append(temp)
return out return out
def get_additional_models(conds, dtype): def get_additional_models(conds, dtype):
"""loads additional models in conditioning""" """loads additional models in conditioning"""
-    cnets = []
+    cnets: list[ControlBase] = []
gligen = [] gligen = []
add_models = []
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
for k in conds: for k in conds:
cnets += get_models_from_cond(conds[k], "control") cnets += get_models_from_cond(conds[k], "control")
gligen += get_models_from_cond(conds[k], "gligen") gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models")
get_hooks_from_cond(conds[k], hooks)
control_nets = set(cnets) control_nets = set(cnets)
@ -47,7 +90,9 @@ def get_additional_models(conds, dtype):
inference_memory += m.inference_memory_requirements(dtype) inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen] gligen = [x[1] for x in gligen]
-    models = control_models + gligen
+    hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
+    models = control_models + gligen + add_models + hook_models
return models, inference_memory return models, inference_memory
def cleanup_additional_models(models): def cleanup_additional_models(models):
@ -57,10 +102,10 @@ def cleanup_additional_models(models):
m.cleanup() m.cleanup()
-def prepare_sampling(model, noise_shape, conds):
-    device = model.load_device
-    real_model = None
+def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
+    real_model: 'BaseModel' = None
models, inference_memory = get_additional_models(conds, model.model_dtype()) models, inference_memory = get_additional_models(conds, model.model_dtype())
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required) comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
@ -76,3 +121,14 @@ def cleanup_models(conds, models):
control_cleanup += get_models_from_cond(conds[k], "control") control_cleanup += get_models_from_cond(conds[k], "control")
cleanup_additional_models(set(control_cleanup)) cleanup_additional_models(set(control_cleanup))
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
# check for hooks in conds - if not registered, see if can be applied
hooks = {}
for k in conds:
get_hooks_from_cond(conds[k], hooks)
# add wrappers and callbacks from ModelPatcher to transformer_options
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
# register hooks on model/model_options
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)


@ -1,11 +1,22 @@
from __future__ import annotations
from .k_diffusion import sampling as k_diffusion_sampling from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc from .extra_samplers import uni_pc
from typing import TYPE_CHECKING, Callable, NamedTuple
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
import torch import torch
from functools import partial
import collections import collections
from comfy import model_management from comfy import model_management
import math import math
import logging import logging
import comfy.samplers
import comfy.sampler_helpers import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import scipy.stats import scipy.stats
import numpy import numpy
@ -70,6 +81,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
for c in model_conds: for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
hooks = conds.get('hooks', None)
control = conds.get('control', None) control = conds.get('control', None)
patches = None patches = None
@ -85,8 +97,8 @@ def get_area_and_mult(conds, x_in, timestep_in):
patches['middle_patch'] = [gligen_patch] patches['middle_patch'] = [gligen_patch]
-    cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
-    return cond_obj(input_x, mult, conditioning, area, control, patches)
+    cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks'])
+    return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks)
def cond_equal_size(c1, c2): def cond_equal_size(c1, c2):
if c1 is c2: if c1 is c2:
@ -119,11 +131,6 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning) return cond_equal_size(c1.conditioning, c2.conditioning)
 def cond_cat(c_list):
-    c_crossattn = []
-    c_concat = []
-    c_adm = []
-    crossattn_max_len = 0
     temp = {}
for x in c_list: for x in c_list:
for k in x: for k in x:
@ -138,110 +145,184 @@ def cond_cat(c_list):
return out return out
-def calc_cond_batch(model, conds, x_in, timestep, model_options):
+def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options):
# need to figure out remaining unmasked area for conds
default_mults = []
for _ in default_conds:
default_mults.append(torch.ones_like(x_in))
# look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond
for lora_hooks, to_run in hooked_to_run.items():
for cond_obj, i in to_run:
# if no default_cond for cond_type, do nothing
if len(default_conds[i]) == 0:
continue
area: list[int] = cond_obj.area
if area is not None:
curr_default_mult: torch.Tensor = default_mults[i]
dims = len(area) // 2
for i in range(dims):
curr_default_mult = curr_default_mult.narrow(i + 2, area[i + dims], area[i])
curr_default_mult -= cond_obj.mult
else:
default_mults[i] -= cond_obj.mult
# for each default_mult, ReLU to make negatives=0, and then check for any nonzeros
for i, mult in enumerate(default_mults):
# if no default_cond for cond type, do nothing
if len(default_conds[i]) == 0:
continue
torch.nn.functional.relu(mult, inplace=True)
# if mult is all zeros, then don't add default_cond
if torch.max(mult) == 0.0:
continue
cond = default_conds[i]
for x in cond:
# do get_area_and_mult to get all the expected values
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
if p is None:
continue
# replace p's mult with calculated mult
p = p._replace(mult=mult)
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_calc_cond_batch,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
)
return executor.execute(model, conds, x_in, timestep, model_options)
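Because calc_cond_batch now routes through the wrapper chain, an extension can intercept the whole batched cond evaluation by registering a CALC_COND_BATCH wrapper; it receives the executor first, then the same arguments as _calc_cond_batch below. A hedged registration sketch (the key name is made up, and model_options stands for the options dict the sampler will use):

import comfy.patcher_extension as pe

def log_calc_cond_batch(executor, model, conds, x_in, timestep, model_options):
    # Inspect or tweak inputs, then defer to the next wrapper / the real implementation.
    print("calc_cond_batch:", tuple(x_in.shape))
    return executor(model, conds, x_in, timestep, model_options)

model_options = {}  # in practice, the sampler's model_options dict
pe.add_wrapper_with_key(pe.WrappersMP.CALC_COND_BATCH, "my_logger",
                        log_calc_cond_batch, model_options, is_model_options=True)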
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
     out_conds = []
     out_counts = []
-    to_run = []
+    # separate conds by matching hooks
+    hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
+    default_conds = []
+    has_default_conds = False
for i in range(len(conds)): for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in)) out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37) out_counts.append(torch.ones_like(x_in) * 1e-37)
cond = conds[i] cond = conds[i]
+        default_c = []
         if cond is not None:
             for x in cond:
-                p = get_area_and_mult(x, x_in, timestep)
+                if 'default' in x:
+                    default_c.append(x)
+                    has_default_conds = True
+                    continue
+                p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
                 if p is None:
                     continue
+                if p.hooks is not None:
+                    model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
+                hooked_to_run.setdefault(p.hooks, list())
+                hooked_to_run[p.hooks] += [(p, i)]
+        default_conds.append(default_c)

-                to_run += [(p, i)]
+    if has_default_conds:
+        finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
while len(to_run) > 0: model.current_patcher.prepare_state(timestep)
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse() # run every hooked_to_run separately
to_batch = to_batch_temp[:1] for hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
free_memory = model_management.get_free_memory(x_in.device) to_batch_temp.reverse()
for i in range(1, len(to_batch_temp) + 1): to_batch = to_batch_temp[:1]
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
input_x = [] free_memory = model_management.get_free_memory(x_in.device)
mult = [] for i in range(1, len(to_batch_temp) + 1):
c = [] batch_amount = to_batch_temp[:len(to_batch_temp)//i]
cond_or_uncond = [] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
area = [] if model.memory_required(input_shape) * 1.5 < free_memory:
control = None to_batch = batch_amount
patches = None break
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond) input_x = []
input_x = torch.cat(input_x) mult = []
        c = []
        cond_or_uncond = []
        uuids = []
        area = []
        control = None
        patches = None
        for x in to_batch:
            o = to_run.pop(x)
            p = o[0]
            input_x.append(p.input_x)
            mult.append(p.mult)
            c.append(p.conditioning)
            area.append(p.area)
            cond_or_uncond.append(o[1])
            uuids.append(p.uuid)
            control = p.control
            patches = p.patches

        batch_chunks = len(cond_or_uncond)
        input_x = torch.cat(input_x)
        c = cond_cat(c)
        timestep_ = torch.cat([timestep] * batch_chunks)

        transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
        if 'transformer_options' in model_options:
            transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
                                                                             model_options['transformer_options'],
                                                                             copy_dict1=False)

        if patches is not None:
            # TODO: replace with merge_nested_dicts function
            if "patches" in transformer_options:
                cur_patches = transformer_options["patches"].copy()
                for p in patches:
                    if p in cur_patches:
                        cur_patches[p] = cur_patches[p] + patches[p]
                    else:
                        cur_patches[p] = patches[p]
                transformer_options["patches"] = cur_patches
            else:
                transformer_options["patches"] = patches

        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
        transformer_options["uuids"] = uuids[:]
        transformer_options["sigmas"] = timestep

        c['transformer_options'] = transformer_options

        if control is not None:
            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)

        if 'model_function_wrapper' in model_options:
            output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
        else:
            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)

        for o in range(batch_chunks):
            cond_index = cond_or_uncond[o]
            a = area[o]
            if a is None:
                out_conds[cond_index] += output[o] * mult[o]
                out_counts[cond_index] += mult[o]
            else:
                out_c = out_conds[cond_index]
                out_cts = out_counts[cond_index]
                dims = len(a) // 2
                for i in range(dims):
                    out_c = out_c.narrow(i + 2, a[i + dims], a[i])
                    out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
                out_c += output[o] * mult[o]
                out_cts += mult[o]

    for i in range(len(out_conds)):
        out_conds[i] /= out_counts[i]
@@ -261,7 +342,7 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale

    for fn in model_options.get("sampler_post_cfg_function", []):
        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "cond_scale": cond_scale, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
                "sigma": timestep, "model_options": model_options, "input": x}
        cfg_result = fn(args)

@@ -387,6 +468,13 @@ def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, line
    sigma_schedule = [1.0 - x for x in sigma_schedule]
    return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
# Referenced from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Tensor:
adj_idxs = torch.arange(n, dtype=torch.float).div_(n - 1)
sigmas = adj_idxs.new_zeros(n + 1)
sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_()
return sigmas
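For illustration only, a minimal sketch of calling this scheduler directly; the sigma bounds below are made-up values, not taken from any particular model:

    # kl_optimal_scheduler as defined above; the bounds are hypothetical.
    sigmas = kl_optimal_scheduler(n=10, sigma_min=0.03, sigma_max=14.6)
    # returns 11 values, interpolated in atan-space from sigma_max down to sigma_min,
    # with a trailing 0.0 appended for the final step.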
def get_mask_aabb(masks):
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@@ -500,10 +588,15 @@ def calculate_start_end_timesteps(model, conds):
        timestep_start = None
        timestep_end = None
        # handle clip hook schedule, if needed
        if 'clip_start_percent' in x:
            timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0)))
            timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0)))
        else:
            if 'start_percent' in x:
                timestep_start = s.percent_to_sigma(x['start_percent'])
            if 'end_percent' in x:
                timestep_end = s.percent_to_sigma(x['end_percent'])

        if (timestep_start is not None) or (timestep_end is not None):
            n = x.copy()
@@ -518,8 +611,6 @@ def pre_run_control(model, conds):
    for t in range(len(conds)):
        x = conds[t]

        percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
        if 'control' in x:
            x['control'].pre_run(model, percent_to_timestep_function)
@@ -673,6 +764,12 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
            if k != kk:
                create_cond_with_same_area_if_none(conds[kk], c)

    for k in conds:
        for c in conds[k]:
            if 'hooks' in c:
                for hook in c['hooks'].hooks:
                    hook.initialize_timesteps(model)

    for k in conds:
        pre_run_control(model, conds[k])
@@ -685,9 +782,46 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
    return conds
def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
# determine which ControlNets have extra_hooks that should be combined with normal hooks
hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {}
for k in conds:
for kk in conds[k]:
if 'control' in kk:
control: 'ControlBase' = kk['control']
extra_hooks = control.get_extra_hooks()
if len(extra_hooks) > 0:
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
to_replace = hook_replacement.setdefault((control, hooks), [])
to_replace.append(kk)
# if nothing to replace, do nothing
if len(hook_replacement) == 0:
return
# for optimal sampling performance, common ControlNets + hook combos should have identical hooks
# on the cond dicts
for key, conds_to_modify in hook_replacement.items():
control = key[0]
hooks = key[1]
hooks = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks])
# if combined hooks are not None, set as new hooks for all relevant conds
if hooks is not None:
for cond in conds_to_modify:
cond['hooks'] = hooks
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
hooks_set = set()
for k in conds:
for kk in conds[k]:
hooks_set.add(kk.get('hooks', None))
return len(hooks_set)
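As an editorial illustration of how this count is used (the HookGroup construction below is assumed for the sketch and is not part of this diff):

    # Two conds sharing one HookGroup plus one cond without hooks -> 2 distinct groups
    # ({group, None}), so CFGGuider.sample() further down will not force MinVram hook mode.
    group = comfy.hooks.HookGroup()
    conds = {"positive": [{"hooks": group}, {"hooks": group}], "negative": [{}]}
    assert get_total_hook_groups_in_conds(conds) == 2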
class CFGGuider:
    def __init__(self, model_patcher):
        self.model_patcher: 'ModelPatcher' = model_patcher
        self.model_options = model_patcher.model_options
        self.original_conds = {}
        self.cfg = 1.0

@@ -714,19 +848,19 @@ class CFGGuider:
        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)

        extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
        extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
        extra_args = {"model_options": extra_model_options, "seed": seed}

        executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
            sampler.sample,
            sampler,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
        )
        samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
        return self.inner_model.process_latent_out(samples.to(torch.float32))

    def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
        self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
        device = self.model_patcher.load_device
@@ -737,14 +871,48 @@ class CFGGuider:
        latent_image = latent_image.to(device)
        sigmas = sigmas.to(device)

        try:
            self.model_patcher.pre_run()
            output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
        finally:
            self.model_patcher.cleanup()

        comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
        del self.inner_model
        del self.loaded_models
        return output
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
if sigmas.shape[-1] == 0:
return latent_image
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
preprocess_conds_hooks(self.conds)
try:
orig_model_options = self.model_options
self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
# if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step)
orig_hook_mode = self.model_patcher.hook_mode
if get_total_hook_groups_in_conds(self.conds) <= 1:
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
self.outer_sample,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
self.model_options = orig_model_options
self.model_patcher.hook_mode = orig_hook_mode
self.model_patcher.restore_hook_patches()
del self.conds
return output
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    cfg_guider = CFGGuider(model)
@@ -753,29 +921,37 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)

SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

class SchedulerHandler(NamedTuple):
    handler: Callable[..., torch.Tensor]
    # Boolean indicates whether to call the handler like:
    #  scheduler_function(model_sampling, steps) or
    #  scheduler_function(n, sigma_min: float, sigma_max: float)
    use_ms: bool = True

SCHEDULER_HANDLERS = {
    "normal": SchedulerHandler(normal_scheduler),
    "karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
    "exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
    "sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
    "simple": SchedulerHandler(simple_scheduler),
    "ddim_uniform": SchedulerHandler(ddim_scheduler),
    "beta": SchedulerHandler(beta_scheduler),
    "linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
    "kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
}
SCHEDULER_NAMES = list(SCHEDULER_HANDLERS)
def calculate_sigmas(model_sampling: object, scheduler_name: str, steps: int) -> torch.Tensor:
handler = SCHEDULER_HANDLERS.get(scheduler_name)
if handler is None:
err = f"error invalid scheduler {scheduler_name}"
logging.error(err)
raise ValueError(err)
if handler.use_ms:
return handler.handler(model_sampling, steps)
return handler.handler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
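A brief usage sketch of the registry dispatch above (model_sampling stands in for a real model sampling object and is assumed here):

    # use_ms=True handlers are called as handler(model_sampling, steps); use_ms=False
    # handlers instead receive n/sigma_min/sigma_max derived from model_sampling.
    sigmas_a = calculate_sigmas(model_sampling, "normal", 20)
    sigmas_b = calculate_sigmas(model_sampling, "kl_optimal", 20)
    calculate_sigmas(model_sampling, "unknown_scheduler", 20)  # logs the error and raises ValueError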
def sampler_object(name):
    if name == "uni_pc":

View File

@@ -1,14 +1,18 @@
from __future__ import annotations
import torch
from enum import Enum
import logging

from comfy import model_management
from comfy.utils import ProgressBar
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import yaml
import math

import comfy.utils
@@ -23,16 +27,23 @@ import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.pixart_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video

import comfy.model_patcher
import comfy.lora
import comfy.lora_convert
import comfy.hooks
import comfy.t2i_adapter.adapter
import comfy.taesd.taesd
import comfy.ldm.flux.redux


def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
    key_map = {}
    if model is not None:
@@ -40,6 +51,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
    if clip is not None:
        key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)

    lora = comfy.lora_convert.convert_lora(lora)
    loaded = comfy.lora.load_lora(lora, key_map)
    if model is not None:
        new_modelpatcher = model.clone()
@@ -92,10 +104,14 @@ class CLIP:
        self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
        self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
        self.patcher.is_clip = True
        self.apply_hooks_to_conds = None
        if params['device'] == load_device:
            model_management.load_models_gpu([self.patcher], force_full_load=True)
        self.layer_idx = None
        self.use_clip_schedule = False
        logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))

    def clone(self):
        n = CLIP(no_init=True)
@@ -103,6 +119,8 @@ class CLIP:
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
        n.layer_idx = self.layer_idx
        n.use_clip_schedule = self.use_clip_schedule
        n.apply_hooks_to_conds = self.apply_hooks_to_conds
        return n

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
@@ -114,6 +132,69 @@ class CLIP:
    def tokenize(self, text, return_word_ids=False):
        return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def add_hooks_to_dict(self, pooled_dict: dict[str]):
if self.apply_hooks_to_conds:
pooled_dict["hooks"] = self.apply_hooks_to_conds
return pooled_dict
def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]={}, show_pbar=True):
all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
all_hooks = self.patcher.forced_hooks
if all_hooks is None or not self.use_clip_schedule:
# if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict
return_pooled = "unprojected" if unprojected else True
pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
cond = pooled_dict.pop("cond")
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
all_cond_pooled.append([cond, pooled_dict])
else:
scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()
self.cond_stage_model.reset_clip_options()
if self.layer_idx is not None:
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
if unprojected:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
pbar = ProgressBar(len(scheduled_keyframes))
for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0]
# don't bother encoding any conds outside of start_percent and end_percent bounds
if "start_percent" in add_dict:
if t_range[1] < add_dict["start_percent"]:
continue
if "end_percent" in add_dict:
if t_range[0] > add_dict["end_percent"]:
continue
hooks_keyframes = scheduled_opts[1]
for hook, keyframe in hooks_keyframes:
hook.hook_keyframe._current_keyframe = keyframe
# apply appropriate hooks with values that match new hook_keyframe
self.patcher.patch_hooks(all_hooks)
# perform encoding as normal
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
pooled_dict = {"pooled_output": pooled}
# add clip_start_percent and clip_end_percent in pooled
pooled_dict["clip_start_percent"] = t_range[0]
pooled_dict["clip_end_percent"] = t_range[1]
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
# add hooks stored on clip
self.add_hooks_to_dict(pooled_dict)
all_cond_pooled.append([cond, pooled_dict])
if show_pbar:
pbar.update(1)
model_management.throw_exception_if_processing_interrupted()
all_hooks.reset()
return all_cond_pooled
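For clarity, a hedged sketch of consuming the scheduled encode (assumes a CLIP instance named clip with forced_hooks set on its patcher and use_clip_schedule enabled):

    tokens = clip.tokenize("a photo of a cat")
    for cond, opts in clip.encode_from_tokens_scheduled(tokens):
        # one entry per hook-keyframe window; the percent bounds end up on the cond dict
        print(opts.get("clip_start_percent"), opts.get("clip_end_percent"))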
    def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
        self.cond_stage_model.reset_clip_options()
@@ -131,6 +212,7 @@ class CLIP:
            if len(o) > 2:
                for k in o[2]:
                    out[k] = o[2][k]
            self.add_hooks_to_dict(out)
            return out

        if return_pooled:
@@ -171,11 +253,15 @@ class VAE:
        self.downscale_ratio = 8
        self.upscale_ratio = 8
        self.latent_channels = 4
        self.latent_dim = 2
        self.output_channels = 3
        self.process_input = lambda image: image * 2.0 - 1.0
        self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
        self.working_dtypes = [torch.bfloat16, torch.float32]

        self.downscale_index_formula = None
        self.upscale_index_formula = None

        if config is None:
            if "decoder.mid.block_1.mix_factor" in sd:
                encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -226,8 +312,8 @@ class VAE:
                self.upscale_ratio = 4

            self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
            if 'post_quant_conv.weight' in sd:
                self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
            else:
                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
                                                            encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
@@ -240,16 +326,56 @@ class VAE:
            self.output_channels = 2
            self.upscale_ratio = 2048
            self.downscale_ratio = 2048
            self.latent_dim = 1
            self.process_output = lambda audio: audio
            self.process_input = lambda audio: audio
            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
            if "blocks.2.blocks.3.stack.5.weight" in sd:
                sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
            if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
                sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
            self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
            self.latent_channels = 12
            self.latent_dim = 3
            self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
            self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
            self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
            self.upscale_index_formula = (6, 8, 8)
            self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
            self.downscale_index_formula = (6, 8, 8)
            self.working_dtypes = [torch.float16, torch.float32]
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
version = 0
if tensor_conv1.shape[0] == 512:
version = 0
elif tensor_conv1.shape[0] == 1024:
version = 1
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version)
self.latent_channels = 128
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
self.upscale_index_formula = (8, 32, 32)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
ddconfig["time_compress"] = 4
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
        else:
            logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
            self.first_stage_model = None
@@ -276,13 +402,15 @@ class VAE:
        self.output_device = model_management.intermediate_device()

        self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
        logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

    def vae_encode_crop_pixels(self, pixels):
        downscale_ratio = self.spacial_compression_encode()

        dims = pixels.shape[1:-1]
        for d in range(len(dims)):
            x = (dims[d] // downscale_ratio) * downscale_ratio
            x_offset = (dims[d] % downscale_ratio) // 2
            if x != dims[d]:
                pixels = pixels.narrow(d + 1, x_offset, x)
        return pixels
@@ -303,11 +431,11 @@ class VAE:

    def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
        return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))

    def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
        return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))

    def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
        steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
@@ -326,6 +454,10 @@ class VAE:
        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
    def decode(self, samples_in):
        pixel_samples = None
        try:
@@ -341,7 +473,7 @@ class VAE:
                if pixel_samples is None:
                    pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                pixel_samples[x:x+batch_number] = out
        except model_management.OOM_EXCEPTION:
            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
            dims = samples_in.ndim - 2
            if dims == 1:
@@ -349,49 +481,135 @@ class VAE:
            elif dims == 2:
                pixel_samples = self.decode_tiled_(samples_in)
            elif dims == 3:
                tile = 256 // self.spacial_compression_decode()
                overlap = tile // 4
                pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))

        pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
        return pixel_samples

    def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
        memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
        dims = samples.ndim - 2
args = {}
if tile_x is not None:
args["tile_x"] = tile_x
if tile_y is not None:
args["tile_y"] = tile_y
if overlap is not None:
args["overlap"] = overlap
if dims == 1:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
output = self.decode_tiled_(samples, **args)
elif dims == 3:
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (max(1, overlap_t), overlap, overlap)
if tile_t is not None:
args["tile_t"] = max(2, tile_t)
output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1)
    def encode(self, pixel_samples):
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
        pixel_samples = pixel_samples.movedim(-1, 1)
        if self.latent_dim == 3:
            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
        try:
            memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = int(free_memory / max(1, memory_used))
            batch_number = max(1, batch_number)
            samples = None
            for x in range(0, pixel_samples.shape[0], batch_number):
                pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
                out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
                if samples is None:
                    samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                samples[x:x + batch_number] = out

        except model_management.OOM_EXCEPTION:
            logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
            if self.latent_dim == 3:
                tile = 256
                overlap = tile // 4
                samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
            elif self.latent_dim == 1:
                samples = self.encode_tiled_1d(pixel_samples)
            else:
                samples = self.encode_tiled_(pixel_samples)

        return samples

    def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
        dims = self.latent_dim
        pixel_samples = pixel_samples.movedim(-1, 1)
        if dims == 3:
            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)

        memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)  # TODO: calculate mem required for tile
        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
args = {}
if tile_x is not None:
args["tile_x"] = tile_x
if tile_y is not None:
args["tile_y"] = tile_y
if overlap is not None:
args["overlap"] = overlap
if dims == 1:
args.pop("tile_y")
samples = self.encode_tiled_1d(pixel_samples, **args)
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
        return samples

    def get_sd(self):
        return self.first_stage_model.state_dict()
def spacial_compression_decode(self):
try:
return self.upscale_ratio[-1]
except:
return self.upscale_ratio
def spacial_compression_encode(self):
try:
return self.downscale_ratio[-1]
except:
return self.downscale_ratio
def temporal_compression_decode(self):
try:
return round(self.upscale_ratio[0](8192) / 8192)
except:
return None
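Illustrative note (the vae object below is hypothetical, configured like the mochi branch above):

    vae.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
    vae.spacial_compression_decode()   # -> 8, the last tuple element
    vae.temporal_compression_decode()  # -> 6, i.e. round(upscale_ratio[0](8192) / 8192)
    # 2D image VAEs keep plain integer ratios, so the helpers fall back to them directly.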
class StyleModel:
    def __init__(self, model, device="cpu"):
        self.model = model
@@ -405,6 +623,8 @@ def load_style_model(ckpt_path):
    keys = model_data.keys()
    if "style_embedding" in keys:
        model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
    elif "redux_down.weight" in keys:
        model = comfy.ldm.flux.redux.ReduxImageEncoder()
    else:
        raise Exception("invalid style model {}".format(ckpt_path))
    model.load_state_dict(model_data)
@@ -418,6 +638,10 @@ class CLIPType(Enum):
    HUNYUAN_DIT = 5
    FLUX = 6
    MOCHI = 7
    LTXV = 8
    HUNYUAN_VIDEO = 9
    PIXART = 10

def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
    clip_data = []
@@ -433,6 +657,7 @@ class TEModel(Enum):
    T5_XXL = 4
    T5_XL = 5
    T5_BASE = 6
    LLAMA3_8 = 7

def detect_te_model(sd):
    if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -449,6 +674,8 @@ def detect_te_model(sd):
        return TEModel.T5_XL
    if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
        return TEModel.T5_BASE
    if "model.layers.0.post_attention_layernorm.weight" in sd:
        return TEModel.LLAMA3_8
    return None

@@ -461,6 +688,14 @@ def t5xxl_detect(clip_data):
    return {}
def llama_detect(clip_data):
weight_name = "model.layers.0.self_attn.k_proj.weight"
for sd in clip_data:
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
return {}
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
    clip_data = state_dicts
@@ -496,6 +731,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
            if clip_type == CLIPType.SD3:
                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
elif clip_type == CLIPType.PIXART:
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
            else: #CLIPType.MOCHI
                clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
@@ -523,6 +764,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
        elif clip_type == CLIPType.FLUX:
            clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
elif clip_type == CLIPType.HUNYUAN_VIDEO:
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
        else:
            clip_target.clip = sdxl_clip.SDXLClipModel
            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -562,7 +806,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
        config = yaml.safe_load(stream)
    model_config_params = config['model']['params']
    clip_config = model_config_params['cond_stage_config']

    if "parameterization" in model_config_params:
        if model_config_params["parameterization"] == "v":
@@ -732,11 +975,11 @@ def load_diffusion_model(unet_path, model_options={}):
    return model

def load_unet(unet_path, dtype=None):
    logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
    return load_diffusion_model(unet_path, model_options={"dtype": dtype})

def load_unet_state_dict(sd, dtype=None):
    logging.warning("The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
    return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})

def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):

View File

@@ -10,6 +10,7 @@ import comfy.clip_model
import json
import logging
import numbers
import re

def gen_empty_tokens(special_tokens, length):
    start_token = special_tokens.get("start", None)
@@ -36,7 +37,10 @@ class ClipTokenWeightEncoder:
        sections = len(to_encode)
        if has_weights or sections == 0:
            if hasattr(self, "gen_empty_tokens"):
                to_encode.append(self.gen_empty_tokens(self.special_tokens, max_token_len))
            else:
                to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))

        o = self.encode(to_encode)
        out, pooled = o[:2]
@@ -90,8 +94,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
        if textmodel_json_config is None:
            textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")

        if isinstance(textmodel_json_config, dict):
            config = textmodel_json_config
        else:
            with open(textmodel_json_config) as f:
                config = json.load(f)

        operations = model_options.get("custom_operations", None)
        scaled_fp8 = None
@@ -196,11 +203,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
        attention_mask = None
        if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
            attention_mask = torch.zeros_like(tokens)
            end_token = self.special_tokens.get("end", None)
            if end_token is None:
                cmp_token = self.special_tokens.get("pad", -1)
            else:
                cmp_token = end_token

            for x in range(attention_mask.shape[0]):
                for y in range(attention_mask.shape[1]):
                    attention_mask[x, y] = 1
                    if tokens[x, y] == cmp_token:
                        if end_token is None:
                            attention_mask[x, y] = 0
                        break

        attention_mask_model = None
@@ -326,7 +340,6 @@ def expand_directory_list(directories):
    return list(dirs)

def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format
    out_list = []
    for k in embed:
        if k.startswith(prefix) and k.endswith(suffix):
@@ -382,7 +395,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
            embed_out = safe_load_embed_zip(embed_path)
        else:
            embed = torch.load(embed_path, map_location="cpu")
    except Exception:
        logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
        return None
@@ -411,22 +424,31 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
    return embed_out

class SDTokenizer:
    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}):
        if tokenizer_path is None:
            tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
        self.max_length = max_length
        self.min_length = min_length
        self.end_token = None

        empty = self.tokenizer('')["input_ids"]
        self.tokenizer_adds_end_token = has_end_token
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            if end_token is not None:
                self.end_token = end_token
            else:
                if has_end_token:
                    self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            if end_token is not None:
                self.end_token = end_token
            else:
                self.end_token = empty[0]

        if pad_token is not None:
            self.pad_token = pad_token
@@ -451,13 +473,16 @@ class SDTokenizer:
        Takes a potential embedding name and tries to retrieve it.
        Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
        '''
        split_embed = embedding_name.split()
        embedding_name = split_embed[0]
        leftover = ' '.join(split_embed[1:])
        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
        if embed is None:
            stripped = embedding_name.strip(',')
            if len(stripped) < len(embedding_name):
                embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
                return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
        return (embed, leftover)
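A hypothetical walkthrough of the new splitting behaviour (tokenizer is an SDTokenizer instance; the embedding name is made up):

    embed, leftover = tokenizer._try_get_embedding("myembed, extra words")
    # split on whitespace: embedding_name -> "myembed,", leftover -> "extra words";
    # the trailing ',' is stripped before the second lookup, and the comma plus the
    # remaining words come back as leftover so the rest of the prompt is still tokenized.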
    def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -471,13 +496,18 @@ class SDTokenizer:
        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        # tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = unescape_important(weighted_segment)
            split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
            to_tokenize = [split[0]]
            for i in range(1, len(split)):
                to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))

            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                # if we find an embedding, deal with the embedding
                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
                    embed, leftover = self._try_get_embedding(embedding_name)
@@ -493,8 +523,11 @@ class SDTokenizer:
                        word = leftover
                    else:
                        continue
                end = 999999999999
                if self.tokenizer_adds_end_token:
                    end = -1
                #parse word
                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])
        #reshape token array to CLIP input size
        batched_tokens = []
@@ -505,18 +538,24 @@ class SDTokenizer:
        for i, t_group in enumerate(tokens):
            #determine if we're going to try and keep the tokens in a single batch
            is_large = len(t_group) >= self.max_word_length
            if self.end_token is not None:
                has_end_token = 1
            else:
                has_end_token = 0

            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - has_end_token:
                    remaining_length = self.max_length - len(batch) - has_end_token
                    #break word in two and add end token
                    if is_large:
                        batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
                        if self.end_token is not None:
                            batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    #add end token and pad
                    else:
                        if self.end_token is not None:
                            batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
                    #start new batch
@@ -529,7 +568,8 @@ class SDTokenizer:
                t_group = []

        #fill last batch
        if self.end_token is not None:
            batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:

View File

@@ -8,9 +8,12 @@ import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.pixart_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video

from . import supported_models_base
from . import latent_formats
@@ -197,6 +200,8 @@ class SDXL(supported_models_base.BASE):
            self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
            return model_base.ModelType.V_PREDICTION_EDM
        elif "v_pred" in state_dict:
            if "ztsnr" in state_dict: #Some zsnr anime checkpoints
                self.sampling_settings["zsnr"] = True
            return model_base.ModelType.V_PREDICTION
        else:
            return model_base.ModelType.EPS
@@ -221,7 +226,6 @@ class SDXL(supported_models_base.BASE):
    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        for k in state_dict:
            if k.startswith("clip_l"):
@@ -524,7 +528,6 @@ class SD3(supported_models_base.BASE):
        clip_l = False
        clip_g = False
        t5 = False
        pref = self.text_encoder_key_prefix[0]
        if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
            clip_l = True
@@ -590,6 +593,39 @@ class AuraFlow(supported_models_base.BASE):
    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
class PixArtAlpha(supported_models_base.BASE):
unet_config = {
"image_model": "pixart_alpha",
}
sampling_settings = {
"beta_schedule" : "sqrt_linear",
"linear_start" : 0.0001,
"linear_end" : 0.02,
"timesteps" : 1000,
}
unet_extra_config = {}
latent_format = latent_formats.SD15
memory_usage_factor = 0.5
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.PixArt(self, device=device)
return out.eval()
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.PixArtT5XXL)
class PixArtSigma(PixArtAlpha):
unet_config = {
"image_model": "pixart_sigma",
}
latent_format = latent_formats.SDXL
class HunyuanDiT(supported_models_base.BASE): class HunyuanDiT(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "hydit", "image_model": "hydit",
@ -606,6 +642,8 @@ class HunyuanDiT(supported_models_base.BASE):
latent_format = latent_formats.SDXL latent_format = latent_formats.SDXL
memory_usage_factor = 1.3
vae_key_prefix = ["vae."] vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
@ -656,6 +694,15 @@ class Flux(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
class FluxInpaint(Flux):
unet_config = {
"image_model": "flux",
"guidance_embed": True,
"in_channels": 96,
}
supported_inference_dtypes = [torch.bfloat16, torch.float32]
class FluxSchnell(Flux): class FluxSchnell(Flux):
unet_config = { unet_config = {
"image_model": "flux", "image_model": "flux",
@ -700,7 +747,82 @@ class GenmoMochi(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
class LTXV(supported_models_base.BASE):
unet_config = {
"image_model": "ltxv",
}
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi] sampling_settings = {
"shift": 2.37,
}
unet_extra_config = {}
latent_format = latent_formats.LTXV
memory_usage_factor = 2.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXV(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
class HunyuanVideo(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan_video",
}
sampling_settings = {
"shift": 7.0,
}
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 2.0 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideo(self, device=device)
return out
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
key_out = key_out.replace("txt_in.t_embedder.mlp.0.", "txt_in.t_embedder.in_layer.").replace("txt_in.t_embedder.mlp.2.", "txt_in.t_embedder.out_layer.")
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
out_sd[key_out] = state_dict[k]
return out_sd
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model.model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
models += [SVD_img2vid] models += [SVD_img2vid]
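Editor's note: a hedged, standalone sketch of the assumed detection behaviour (the real logic lives in comfy.model_detection and is more involved) — detection effectively walks this list and picks the first class whose unet_config entries all match the config derived from the checkpoint, which is presumably why more specific entries such as FluxInpaint are listed before the generic Flux. The stub classes below are illustrative only.

def match_model(detected_config, model_classes):
    # Return the first class whose unet_config is a subset of the detected config.
    for cls in model_classes:
        if all(detected_config.get(k) == v for k, v in cls.unet_config.items()):
            return cls
    return None

class _FluxStub:
    unet_config = {"image_model": "flux", "guidance_embed": True}

class _FluxInpaintStub:
    unet_config = {"image_model": "flux", "guidance_embed": True, "in_channels": 96}

detected = {"image_model": "flux", "guidance_embed": True, "in_channels": 96}
print(match_model(detected, [_FluxInpaintStub, _FluxStub]).__name__)  # _FluxInpaintStub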

View File

@ -12,7 +12,7 @@ class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
class MochiT5XXL(sd1_clip.SD1ClipModel): class MochiT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, model_options=model_options) super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer): class T5XXLTokenizer(sd1_clip.SDTokenizer):

View File

@ -0,0 +1,112 @@
from comfy import sd1_clip
import comfy.model_management
import comfy.text_encoders.llama
from transformers import LlamaTokenizerFast
import torch
import os
def llama_detect(state_dict, prefix=""):
out = {}
t5_key = "{}model.norm.weight".format(prefix)
if t5_key in state_dict:
out["dtype_llama"] = state_dict[t5_key].dtype
scaled_fp8_key = "{}scaled_fp8".format(prefix)
if scaled_fp8_key in state_dict:
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
return out
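Editor's note: a standalone restatement of the detection rule above (illustrative, not the comfy API) — the dtype of the final norm weight decides dtype_llama, and a scaled_fp8 marker tensor, if present, is forwarded separately.

import torch

def detect_llama_dtype(state_dict, prefix=""):
    out = {}
    norm_key = f"{prefix}model.norm.weight"
    if norm_key in state_dict:
        out["dtype_llama"] = state_dict[norm_key].dtype
    fp8_key = f"{prefix}scaled_fp8"
    if fp8_key in state_dict:
        out["llama_scaled_fp8"] = state_dict[fp8_key].dtype
    return out

sd = {"llama.model.norm.weight": torch.zeros(8, dtype=torch.bfloat16)}
print(detect_llama_dtype(sd, prefix="llama."))  # {'dtype_llama': torch.bfloat16}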
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
if llama_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class HunyuanVideoTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
llama_text = "{}{}".format(self.llama_template, text)
out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.clip_l.untokenize(token_weight_pair)
def state_dict(self):
return {}
class HunyuanVideoClipModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
self.dtypes = set([dtype, dtype_llama])
def set_clip_options(self, options):
self.clip_l.set_clip_options(options)
self.llama.set_clip_options(options)
def reset_clip_options(self):
self.clip_l.reset_clip_options()
self.llama.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_llama = token_weight_pairs["llama"]
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
template_end = 0
for i, v in enumerate(token_weight_pairs_llama[0]):
if v[0] == 128007: # <|end_header_id|>
template_end = i
if llama_out.shape[1] > (template_end + 2):
if token_weight_pairs_llama[0][template_end + 1][0] == 271:
template_end += 2
llama_out = llama_out[:, template_end:]
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return llama_out, l_pooled, llama_extra_out
def load_sd(self, sd):
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
else:
return self.llama.load_sd(sd)
def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
class HunyuanVideoClipModel_(HunyuanVideoClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HunyuanVideoClipModel_
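Editor's note: a standalone illustration of the template-trimming step in encode_token_weights above — hidden states for the system template are dropped by scanning for the last <|end_header_id|> token (id 128007) and an immediately following "\n\n" token (id 271). The token ids and tensor below are toy stand-ins.

import torch

token_ids = [128000, 128006, 9125, 128007, 271, 1234, 5678]  # toy sequence
hidden = torch.randn(1, len(token_ids), 16)                  # fake llama hidden states

template_end = 0
for i, tok in enumerate(token_ids):
    if tok == 128007:                                        # <|end_header_id|>
        template_end = i
if hidden.shape[1] > (template_end + 2) and token_ids[template_end + 1] == 271:
    template_end += 2
print(hidden[:, template_end:].shape)                        # torch.Size([1, 2, 16]) -> prompt tokens only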

View File

@ -0,0 +1,226 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Any
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ldm.common_dit
@dataclass
class Llama2Config:
vocab_size: int = 128320
hidden_size: int = 4096
intermediate_size: int = 14336
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 8192
rms_norm_eps: float = 1e-5
rope_theta: float = 500000.0
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, device=None, dtype=None):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x: torch.Tensor):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
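Editor's note: a reference implementation for comparison, assuming comfy.ldm.common_dit.rms_norm follows the standard RMSNorm definition y = x / sqrt(mean(x^2) + eps) * weight.

import torch

def rms_norm_ref(x, weight, eps=1e-5):
    # Normalize by the root-mean-square over the last dimension, then scale.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

x = torch.randn(2, 5, 8)
print(rms_norm_ref(x, torch.ones(8)).shape)  # torch.Size([2, 5, 8])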
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, seq_len, theta, device=None):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return (cos, sin)
def apply_rope(xq, xk, freqs_cis):
cos = freqs_cis[0].unsqueeze(1)
sin = freqs_cis[1].unsqueeze(1)
q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin)
return q_embed, k_embed
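Editor's note: a quick standalone check of the rotary-embedding property relied on above — applying the (cos, sin) tables is a per-pair rotation, so query/key norms are unchanged. The table construction below is a simplified but equivalent form of precompute_freqs_cis.

import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq_len, theta = 8, 4, 500000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)   # (seq, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                        # (seq, head_dim)
cos, sin = emb.cos(), emb.sin()

q = torch.randn(1, 1, seq_len, head_dim)                       # (batch, heads, seq, head_dim)
q_rot = q * cos + rotate_half(q) * sin
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True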
class Attention(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
super().__init__()
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.hidden_size = config.hidden_size
self.head_dim = self.hidden_size // self.num_heads
ops = ops or nn
self.q_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.o_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
):
batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states)
xk = self.k_proj(hidden_states)
xv = self.v_proj(hidden_states)
xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
return self.o_proj(output)
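Editor's note: a shape-only sketch of the grouped-query expansion above — the 8 key/value heads are repeated so they line up with the 32 query heads before attention.

import torch

batch, seq, num_heads, num_kv_heads, head_dim = 1, 6, 32, 8, 128
xk = torch.randn(batch, num_kv_heads, seq, head_dim)
xk = xk.repeat_interleave(num_heads // num_kv_heads, dim=1)  # each KV head serves 4 query heads
print(xk.shape)                                              # torch.Size([1, 32, 6, 128])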
class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
super().__init__()
ops = ops or nn
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(self, x):
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
):
# Self Attention
residual = x
x = self.input_layernorm(x)
x = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
)
x = residual + x
# MLP
residual = x
x = self.post_attention_layernorm(x)
x = self.mlp(x)
x = residual + x
return x
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = ops.Embedding(
config.vocab_size,
config.hidden_size,
device=device,
dtype=dtype
)
self.layers = nn.ModuleList([
TransformerBlock(config, device=device, dtype=dtype, ops=ops)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
x = self.embed_tokens(x, out_dtype=dtype)
freqs_cis = precompute_freqs_cis(self.config.hidden_size // self.config.num_attention_heads,
x.shape[1],
self.config.rope_theta,
device=x.device)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
intermediate = None
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
for i, layer in enumerate(self.layers):
x = layer(
x=x,
attention_mask=mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
)
if i == intermediate_output:
intermediate = x.clone()
x = self.norm(x)
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.norm(intermediate)
return x, intermediate
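Editor's note: a standalone illustration of the additive mask built in forward above — padding positions and future positions both receive -inf, everything else 0, so the mask can simply be added to the attention logits.

import torch

seq = 4
attention_mask = torch.tensor([[1, 1, 1, 0]])                       # last token is padding
pad = 1.0 - attention_mask.float().reshape(1, 1, 1, seq).expand(1, 1, seq, seq)
pad = pad.masked_fill(pad.bool(), float("-inf"))                    # -inf at padded keys
causal = torch.full((seq, seq), float("-inf")).triu_(1)             # -inf above the diagonal
print(pad + causal)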
class Llama2(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Llama2Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, embeddings):
self.model.embed_tokens = embeddings
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)

File diff suppressed because it is too large

File diff suppressed because it is too large

comfy/text_encoders/lt.py
View File

@ -0,0 +1,18 @@
from comfy import sd1_clip
import os
from transformers import T5TokenizerFast
import comfy.text_encoders.genmo
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def ltxv_te(*args, **kwargs):
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)

View File

@ -0,0 +1,42 @@
import os
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
from comfy.sd1_clip import gen_empty_tokens
from transformers import T5TokenizerFast
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
# PixArt expects the negative to be all pad tokens
special_tokens = special_tokens.copy()
special_tokens.pop("end")
return gen_empty_tokens(special_tokens, *args, **kwargs)
class PixArtT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixArtTEModel_

View File

@ -1,4 +1,3 @@
import os
import torch import torch
class SPieceTokenizer: class SPieceTokenizer:

View File

@ -172,7 +172,6 @@ class T5LayerSelfAttention(torch.nn.Module):
# self.dropout = nn.Dropout(config.dropout_rate) # self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x, mask=None, past_bias=None, optimized_attention=None): def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
normed_hidden_states = self.layer_norm(x)
output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias, optimized_attention=optimized_attention) output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias, optimized_attention=optimized_attention)
# x = x + self.dropout(attention_output) # x = x + self.dropout(attention_output)
x += output x += output
@ -209,6 +208,11 @@ class T5Stack(torch.nn.Module):
intermediate = None intermediate = None
optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True) optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
past_bias = None past_bias = None
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.block) + intermediate_output
for i, l in enumerate(self.block): for i, l in enumerate(self.block):
x, past_bias = l(x, mask, past_bias, optimized_attention) x, past_bias = l(x, mask, past_bias, optimized_attention)
if i == intermediate_output: if i == intermediate_output:

View File

@ -26,6 +26,8 @@ import numpy as np
from PIL import Image from PIL import Image
import logging import logging
import itertools import itertools
from torch.nn.functional import interpolate
from einops import rearrange
def load_torch_file(ckpt, safe_load=False, device=None): def load_torch_file(ckpt, safe_load=False, device=None):
if device is None: if device is None:
@ -46,7 +48,13 @@ def load_torch_file(ckpt, safe_load=False, device=None):
if "state_dict" in pl_sd: if "state_dict" in pl_sd:
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"]
else: else:
sd = pl_sd if len(pl_sd) == 1:
key = list(pl_sd.keys())[0]
sd = pl_sd[key]
if not isinstance(sd, dict):
sd = pl_sd
else:
sd = pl_sd
return sd return sd
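Editor's note: a minimal restatement of the single-key unwrap added above — some checkpoints nest the whole state dict under one top-level key, so that layer is peeled off when present.

def unwrap_single_key(pl_sd):
    # If there is exactly one top-level entry and it is itself a dict, use it.
    if len(pl_sd) == 1:
        value = next(iter(pl_sd.values()))
        if isinstance(value, dict):
            return value
    return pl_sd

print(unwrap_single_key({"module": {"layer.weight": 0}}))  # inner dict
print(unwrap_single_key({"layer.weight": 0}))              # unchanged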
def save_torch_file(sd, ckpt, metadata=None): def save_torch_file(sd, ckpt, metadata=None):
@ -316,10 +324,18 @@ MMDIT_MAP_BLOCK = {
("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"), ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
("context_block.mlp.fc2.bias", "ff_context.net.2.bias"), ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
("context_block.mlp.fc2.weight", "ff_context.net.2.weight"), ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
("context_block.attn.ln_q.weight", "attn.norm_added_q.weight"),
("context_block.attn.ln_k.weight", "attn.norm_added_k.weight"),
("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"), ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"), ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
("x_block.attn.proj.bias", "attn.to_out.0.bias"), ("x_block.attn.proj.bias", "attn.to_out.0.bias"),
("x_block.attn.proj.weight", "attn.to_out.0.weight"), ("x_block.attn.proj.weight", "attn.to_out.0.weight"),
("x_block.attn.ln_q.weight", "attn.norm_q.weight"),
("x_block.attn.ln_k.weight", "attn.norm_k.weight"),
("x_block.attn2.proj.bias", "attn2.to_out.0.bias"),
("x_block.attn2.proj.weight", "attn2.to_out.0.weight"),
("x_block.attn2.ln_q.weight", "attn2.norm_q.weight"),
("x_block.attn2.ln_k.weight", "attn2.norm_k.weight"),
("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"), ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"), ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
("x_block.mlp.fc2.bias", "ff.net.2.bias"), ("x_block.mlp.fc2.bias", "ff.net.2.bias"),
@ -349,6 +365,12 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset)) key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset)) key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
k = "{}.attn2.".format(block_from)
qkv = "{}.x_block.attn2.qkv.{}".format(block_to, end)
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
for k in MMDIT_MAP_BLOCK: for k in MMDIT_MAP_BLOCK:
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0]) key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
@ -364,6 +386,77 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
return key_map return key_map
PIXART_MAP_BASIC = {
("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
("x_embedder.proj.weight", "pos_embed.proj.weight"),
("x_embedder.proj.bias", "pos_embed.proj.bias"),
("y_embedder.y_embedding", "caption_projection.y_embedding"),
("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
("t_block.1.weight", "adaln_single.linear.weight"),
("t_block.1.bias", "adaln_single.linear.bias"),
("final_layer.linear.weight", "proj_out.weight"),
("final_layer.linear.bias", "proj_out.bias"),
("final_layer.scale_shift_table", "scale_shift_table"),
}
PIXART_MAP_BLOCK = {
("scale_shift_table", "scale_shift_table"),
("attn.proj.weight", "attn1.to_out.0.weight"),
("attn.proj.bias", "attn1.to_out.0.bias"),
("mlp.fc1.weight", "ff.net.0.proj.weight"),
("mlp.fc1.bias", "ff.net.0.proj.bias"),
("mlp.fc2.weight", "ff.net.2.weight"),
("mlp.fc2.bias", "ff.net.2.bias"),
("cross_attn.proj.weight" ,"attn2.to_out.0.weight"),
("cross_attn.proj.bias" ,"attn2.to_out.0.bias"),
}
def pixart_to_diffusers(mmdit_config, output_prefix=""):
key_map = {}
depth = mmdit_config.get("depth", 0)
offset = mmdit_config.get("hidden_size", 1152)
for i in range(depth):
block_from = "transformer_blocks.{}".format(i)
block_to = "{}blocks.{}".format(output_prefix, i)
for end in ("weight", "bias"):
s = "{}.attn1.".format(block_from)
qkv = "{}.attn.qkv.{}".format(block_to, end)
key_map["{}to_q.{}".format(s, end)] = (qkv, (0, 0, offset))
key_map["{}to_k.{}".format(s, end)] = (qkv, (0, offset, offset))
key_map["{}to_v.{}".format(s, end)] = (qkv, (0, offset * 2, offset))
s = "{}.attn2.".format(block_from)
q = "{}.cross_attn.q_linear.{}".format(block_to, end)
kv = "{}.cross_attn.kv_linear.{}".format(block_to, end)
key_map["{}to_q.{}".format(s, end)] = q
key_map["{}to_k.{}".format(s, end)] = (kv, (0, 0, offset))
key_map["{}to_v.{}".format(s, end)] = (kv, (0, offset, offset))
for k in PIXART_MAP_BLOCK:
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
for k in PIXART_MAP_BASIC:
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
return key_map
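Editor's note: a hedged sketch of the assumed semantics of the tuple values above — a mapping value of (fused_key, (dim, offset, size)) appears to mean "this diffusers tensor occupies the slice [offset, offset + size) of the fused tensor along dim"; the assembly itself happens in comfy's state-dict loading code, so the names below are illustrative only.

import torch

hidden = 1152
to_q = torch.randn(hidden, hidden)
to_k = torch.randn(hidden, hidden)
to_v = torch.randn(hidden, hidden)

qkv = torch.empty(hidden * 3, hidden)                      # blocks.N.attn.qkv.weight
for offset, tensor in ((0, to_q), (hidden, to_k), (hidden * 2, to_v)):
    qkv.narrow(0, offset, hidden).copy_(tensor)            # copy into its row block
print(qkv.shape)                                           # torch.Size([3456, 1152])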
def auraflow_to_diffusers(mmdit_config, output_prefix=""): def auraflow_to_diffusers(mmdit_config, output_prefix=""):
n_double_layers = mmdit_config.get("n_double_layers", 0) n_double_layers = mmdit_config.get("n_double_layers", 0)
@ -729,7 +822,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return rows * cols return rows * cols
@torch.inference_mode() @torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
dims = len(tile) dims = len(tile)
if not (isinstance(upscale_amount, (tuple, list))): if not (isinstance(upscale_amount, (tuple, list))):
@ -738,6 +831,12 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
if not (isinstance(overlap, (tuple, list))): if not (isinstance(overlap, (tuple, list))):
overlap = [overlap] * dims overlap = [overlap] * dims
if index_formulas is None:
index_formulas = upscale_amount
if not (isinstance(index_formulas, (tuple, list))):
index_formulas = [index_formulas] * dims
def get_upscale(dim, val): def get_upscale(dim, val):
up = upscale_amount[dim] up = upscale_amount[dim]
if callable(up): if callable(up):
@ -745,10 +844,38 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
else: else:
return up * val return up * val
def get_downscale(dim, val):
up = upscale_amount[dim]
if callable(up):
return up(val)
else:
return val / up
def get_upscale_pos(dim, val):
up = index_formulas[dim]
if callable(up):
return up(val)
else:
return up * val
def get_downscale_pos(dim, val):
up = index_formulas[dim]
if callable(up):
return up(val)
else:
return val / up
if downscale:
get_scale = get_downscale
get_pos = get_downscale_pos
else:
get_scale = get_upscale
get_pos = get_upscale_pos
def mult_list_upscale(a): def mult_list_upscale(a):
out = [] out = []
for i in range(len(a)): for i in range(len(a)):
out.append(round(get_upscale(i, a[i]))) out.append(round(get_scale(i, a[i])))
return out return out
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device) output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
@ -766,23 +893,25 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
for it in itertools.product(*positions): for it in itertools.product(*positions):
s_in = s s_in = s
upscaled = [] upscaled = []
for d in range(dims): for d in range(dims):
pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d])) pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
l = min(tile[d], s.shape[d + 2] - pos) l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l) s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(get_upscale(d, pos))) upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device) ps = function(s_in).to(output_device)
mask = torch.ones_like(ps) mask = torch.ones_like(ps)
for d in range(2, dims + 2): for d in range(2, dims + 2):
feather = round(get_upscale(d - 2, overlap[d - 2])) feather = round(get_scale(d - 2, overlap[d - 2]))
if feather >= mask.shape[d]:
continue
for t in range(feather): for t in range(feather):
a = (t + 1) / feather a = (t + 1) / feather
mask.narrow(d, t, 1).mul_(a) mask.narrow(d, t, 1).mul_(a)
@ -804,7 +933,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
return output return output
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar) return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
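Editor's note: a standalone sketch of the tiling schedule used above along one dimension, under the same rule as the updated positions/pos computation — tiles start every (tile - overlap) samples, the range stops before (length - overlap), and each start is clamped into [0, length - overlap].

def tile_positions(length, tile, overlap):
    if length <= tile:
        return [0]
    return [max(0, min(length - overlap, p)) for p in range(0, length - overlap, tile - overlap)]

print(tile_positions(100, 64, 8))  # [0, 56]; the second tile spans 56..100 and overlaps the first by 8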
PROGRESS_BAR_ENABLED = True PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled): def set_progress_bar_enabled(enabled):
@ -834,3 +963,65 @@ class ProgressBar:
def update(self, value): def update(self, value):
self.update_absolute(self.current + value) self.update_absolute(self.current + value)
def reshape_mask(input_mask, output_shape):
dims = len(output_shape) - 2
if dims == 1:
scale_mode = "linear"
if dims == 2:
input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
scale_mode = "bilinear"
if dims == 3:
if len(input_mask.shape) < 5:
input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
scale_mode = "trilinear"
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
if mask.shape[1] < output_shape[1]:
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
mask = repeat_to_batch_size(mask, output_shape[0])
return mask
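Editor's note: a shape-only sketch of the idea behind reshape_mask — resize a mask to the latent's spatial dimensions and broadcast it over batch and channels; the real helper also handles 1D/3D latents and channel/batch repetition.

import torch

latent = torch.zeros(2, 4, 32, 32)                    # (b, c, h, w) target
mask = torch.rand(1, 1, 512, 512)                     # source mask
mask = torch.nn.functional.interpolate(mask, size=latent.shape[2:], mode="bilinear")
mask = mask.expand(latent.shape[0], latent.shape[1], -1, -1)
print(mask.shape)                                     # torch.Size([2, 4, 32, 32])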
def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
hi, wi = img_size_in
ho, wo = img_size_out
# if it's already the correct size, no need to do anything
if (hi, wi) == (ho, wo):
return mask
if mask.ndim == 2:
mask = mask.unsqueeze(0)
if mask.ndim != 3:
raise ValueError(f"Got a mask of shape {list(mask.shape)}, expected [b, q, k] or [q, k]")
txt_tokens = mask.shape[1] - (hi * wi)
# quadrants of the mask
txt_to_txt = mask[:, :txt_tokens, :txt_tokens]
txt_to_img = mask[:, :txt_tokens, txt_tokens:]
img_to_img = mask[:, txt_tokens:, txt_tokens:]
img_to_txt = mask[:, txt_tokens:, :txt_tokens]
# convert to 1d x 2d, interpolate, then back to 1d x 1d
txt_to_img = rearrange(txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)
txt_to_img = interpolate(txt_to_img, size=img_size_out, mode="bilinear")
txt_to_img = rearrange(txt_to_img, "b t h w -> b t (h w)")
# this one is hard because we have to do it twice
# convert to 1d x 2d, interpolate, then to 2d x 1d, interpolate, then 1d x 1d
img_to_img = rearrange(img_to_img, "b hw (h w) -> b hw h w", h=hi, w=wi)
img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
img_to_img = rearrange(img_to_img, "b (hk wk) hq wq -> b (hq wq) hk wk", hk=hi, wk=wi)
img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
img_to_img = rearrange(img_to_img, "b (hq wq) hk wk -> b (hk wk) (hq wq)", hq=ho, wq=wo)
# convert to 2d x 1d, interpolate, then back to 1d x 1d
img_to_txt = rearrange(img_to_txt, "b (h w) t -> b t h w", h=hi, w=wi)
img_to_txt = interpolate(img_to_txt, size=img_size_out, mode="bilinear")
img_to_txt = rearrange(img_to_txt, "b t h w -> b (h w) t")
# reassemble the mask from blocks
out = torch.cat([
torch.cat([txt_to_txt, txt_to_img], dim=2),
torch.cat([img_to_txt, img_to_img], dim=2)],
dim=1
)
return out
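Editor's note: a mini, self-contained version of the txt->img quadrant step above — the image axis of that block is treated as an (h, w) grid, resized to the new grid, and flattened back; the other quadrants follow the same pattern.

import torch
from torch.nn.functional import interpolate
from einops import rearrange

txt, hi, wi, ho, wo = 4, 8, 8, 16, 16
txt_to_img = torch.rand(1, txt, hi * wi)
txt_to_img = rearrange(txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)   # 1d x 2d
txt_to_img = interpolate(txt_to_img, size=(ho, wo), mode="bilinear")     # resize image grid
txt_to_img = rearrange(txt_to_img, "b t h w -> b t (h w)")               # back to 1d x 1d
print(txt_to_img.shape)                                                  # torch.Size([1, 4, 256])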

Some files were not shown because too many files have changed in this diff