mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Merge branch 'master' into feat/correct_mem_in_docker
This commit is contained in:
commit
0b5953d15d
@ -1,6 +1,9 @@
|
||||
import pygit2
|
||||
from datetime import datetime
|
||||
import sys
|
||||
import os
|
||||
import shutil
|
||||
import filecmp
|
||||
|
||||
def pull(repo, remote_name='origin', branch='master'):
|
||||
for remote in repo.remotes:
|
||||
@ -42,7 +45,8 @@ def pull(repo, remote_name='origin', branch='master'):
|
||||
raise AssertionError('Unknown merge analysis result')
|
||||
|
||||
pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
|
||||
repo = pygit2.Repository(str(sys.argv[1]))
|
||||
repo_path = str(sys.argv[1])
|
||||
repo = pygit2.Repository(repo_path)
|
||||
ident = pygit2.Signature('comfyui', 'comfy@ui')
|
||||
try:
|
||||
print("stashing current changes")
|
||||
@ -51,7 +55,10 @@ except KeyError:
|
||||
print("nothing to stash")
|
||||
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
||||
print("creating backup branch: {}".format(backup_branch_name))
|
||||
repo.branches.local.create(backup_branch_name, repo.head.peel())
|
||||
try:
|
||||
repo.branches.local.create(backup_branch_name, repo.head.peel())
|
||||
except:
|
||||
pass
|
||||
|
||||
print("checking out master branch")
|
||||
branch = repo.lookup_branch('master')
|
||||
@ -63,3 +70,41 @@ pull(repo)
|
||||
|
||||
print("Done!")
|
||||
|
||||
self_update = True
|
||||
if len(sys.argv) > 2:
|
||||
self_update = '--skip_self_update' not in sys.argv
|
||||
|
||||
update_py_path = os.path.realpath(__file__)
|
||||
repo_update_py_path = os.path.join(repo_path, ".ci/update_windows/update.py")
|
||||
|
||||
cur_path = os.path.dirname(update_py_path)
|
||||
|
||||
|
||||
req_path = os.path.join(cur_path, "current_requirements.txt")
|
||||
repo_req_path = os.path.join(repo_path, "requirements.txt")
|
||||
|
||||
|
||||
def files_equal(file1, file2):
|
||||
try:
|
||||
return filecmp.cmp(file1, file2, shallow=False)
|
||||
except:
|
||||
return False
|
||||
|
||||
def file_size(f):
|
||||
try:
|
||||
return os.path.getsize(f)
|
||||
except:
|
||||
return 0
|
||||
|
||||
|
||||
if self_update and not files_equal(update_py_path, repo_update_py_path) and file_size(repo_update_py_path) > 10:
|
||||
shutil.copy(repo_update_py_path, os.path.join(cur_path, "update_new.py"))
|
||||
exit()
|
||||
|
||||
if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
|
||||
import subprocess
|
||||
try:
|
||||
subprocess.check_call([sys.executable, '-s', '-m', 'pip', 'install', '-r', repo_req_path])
|
||||
shutil.copy(repo_req_path, req_path)
|
||||
except:
|
||||
pass
|
||||
|
@ -1,2 +1,8 @@
|
||||
@echo off
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\
|
||||
pause
|
||||
if exist update_new.py (
|
||||
move /y update_new.py update.py
|
||||
echo Running updater again since it got updated.
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update
|
||||
)
|
||||
if "%~1"=="" pause
|
||||
|
@ -1,3 +0,0 @@
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 xformers -r ../ComfyUI/requirements.txt pygit2
|
||||
pause
|
@ -1,11 +0,0 @@
|
||||
@echo off
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\
|
||||
echo
|
||||
echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff
|
||||
echo You should not be running this anyways unless you really have to
|
||||
echo
|
||||
echo If you just want to update normally, close this and run update_comfyui.bat instead.
|
||||
echo
|
||||
pause
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 xformers -r ../ComfyUI/requirements.txt pygit2
|
||||
pause
|
45
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
Normal file
45
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
Normal file
@ -0,0 +1,45 @@
|
||||
name: Bug Report
|
||||
description: "Something is broken inside of ComfyUI. (Do not use this if you're just having issues and need help, or if the issue relates to a custom node)"
|
||||
labels: [ "Potential Bug" ]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Before submitting a **Bug Report**, please ensure the following:
|
||||
|
||||
**1:** You are running the latest version of ComfyUI.
|
||||
**2:** You have looked at the existing bug reports and made sure this isn't already reported.
|
||||
**3:** This is an actual bug in ComfyUI, not just a support question and not caused by an custom node. A bug is when you can specify exact steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
||||
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
description: "What you expected to happen."
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Actual Behavior
|
||||
description: "What actually happened. Please include a screenshot of the issue if possible."
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Steps to Reproduce
|
||||
description: "Describe how to reproduce the issue. Please be sure to attach a workflow JSON or PNG, ideally one that doesn't require custom nodes to test. If the bug open happens when certain custom nodes are used, most likely that custom node is what has the bug rather than ComfyUI, in which case it should be reported to the node's author."
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Debug Logs
|
||||
description: "Please copy the output from your terminal logs here."
|
||||
render: powershell
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Other
|
||||
description: "Any other additional information you think might be helpful."
|
||||
validations:
|
||||
required: false
|
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: ComfyUI Matrix Space
|
||||
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
||||
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).
|
||||
- name: Comfy Org Discord
|
||||
url: https://discord.gg/comfyorg
|
||||
about: The Comfy Org Discord is available for support and general discussion related to ComfyUI.
|
32
.github/ISSUE_TEMPLATE/feature-request.yml
vendored
Normal file
32
.github/ISSUE_TEMPLATE/feature-request.yml
vendored
Normal file
@ -0,0 +1,32 @@
|
||||
name: Feature Request
|
||||
description: "You have an idea for something new you would like to see added to ComfyUI's core."
|
||||
labels: [ "Feature" ]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Before submitting a **Feature Request**, please ensure the following:
|
||||
|
||||
**1:** You are running the latest version of ComfyUI.
|
||||
**2:** You have looked to make sure there is not already a feature that does what you need, and there is not already a Feature Request listed for the same idea.
|
||||
**3:** This is something that makes sense to add to ComfyUI Core, and wouldn't make more sense as a custom node.
|
||||
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Feature Idea
|
||||
description: "Describe the feature you want to see."
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Existing Solutions
|
||||
description: "Please search through available custom nodes / extensions to see if there are existing custom solutions for this. If so, please link the options you found here as a reference."
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Other
|
||||
description: "Any other additional information you think might be helpful."
|
||||
validations:
|
||||
required: false
|
32
.github/ISSUE_TEMPLATE/user-support.yml
vendored
Normal file
32
.github/ISSUE_TEMPLATE/user-support.yml
vendored
Normal file
@ -0,0 +1,32 @@
|
||||
name: User Support
|
||||
description: "Use this if you need help with something, or you're experiencing an issue."
|
||||
labels: [ "User Support" ]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Before submitting a **User Report** issue, please ensure the following:
|
||||
|
||||
**1:** You are running the latest version of ComfyUI.
|
||||
**2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.
|
||||
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Your question
|
||||
description: "Post your question here. Please be as detailed as possible."
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Logs
|
||||
description: "If your question relates to an issue you're experiencing, please go to `Server` -> `Logs` -> potentially set `View Type` to `Debug` as well, then copypaste all the text into here."
|
||||
render: powershell
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Other
|
||||
description: "Any other additional information you think might be helpful."
|
||||
validations:
|
||||
required: false
|
63
.github/workflows/test-browser.yml
vendored
Normal file
63
.github/workflows/test-browser.yml
vendored
Normal file
@ -0,0 +1,63 @@
|
||||
# This is a temporary action during frontend TS migration.
|
||||
# This file should be removed after TS migration is completed.
|
||||
# The browser test is here to ensure TS repo is working the same way as the
|
||||
# current JS code.
|
||||
# If you are adding UI feature, please sync your changes to the TS repo:
|
||||
# huchenlei/ComfyUI_frontend and update test expectation files accordingly.
|
||||
name: Playwright Browser Tests CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout ComfyUI
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: "comfyanonymous/ComfyUI"
|
||||
path: "ComfyUI"
|
||||
- name: Checkout ComfyUI_frontend
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: "huchenlei/ComfyUI_frontend"
|
||||
path: "ComfyUI_frontend"
|
||||
ref: "fcc54d803e5b6a9b08a462a1d94899318c96dcbb"
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: lts/*
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install wait-for-it
|
||||
working-directory: ComfyUI
|
||||
- name: Start ComfyUI server
|
||||
run: |
|
||||
python main.py --cpu &
|
||||
wait-for-it --service 127.0.0.1:8188 -t 600
|
||||
working-directory: ComfyUI
|
||||
- name: Install ComfyUI_frontend dependencies
|
||||
run: |
|
||||
npm ci
|
||||
working-directory: ComfyUI_frontend
|
||||
- name: Install Playwright Browsers
|
||||
run: npx playwright install --with-deps
|
||||
working-directory: ComfyUI_frontend
|
||||
- name: Run Playwright tests
|
||||
run: npx playwright test
|
||||
working-directory: ComfyUI_frontend
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-report
|
||||
path: ComfyUI_frontend/playwright-report/
|
||||
retention-days: 30
|
@ -1,71 +0,0 @@
|
||||
name: "Windows Release cu118 dependencies"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
jobs:
|
||||
build_dependencies:
|
||||
env:
|
||||
# you need at least cuda 5.0 for some of the stuff compiled here.
|
||||
TORCH_CUDA_ARCH_LIST: "5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6 8.9"
|
||||
FORCE_CUDA: 1
|
||||
MAX_JOBS: 1 # will crash otherwise
|
||||
DISTUTILS_USE_SDK: 1 # otherwise distutils will complain on windows about multiple versions of msvc
|
||||
XFORMERS_BUILD_TYPE: "Release"
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Cache Built Dependencies
|
||||
uses: actions/cache@v3
|
||||
id: cache-cu118_python_stuff
|
||||
with:
|
||||
path: cu118_python_deps.tar
|
||||
key: ${{ runner.os }}-build-cu118
|
||||
|
||||
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10.9'
|
||||
|
||||
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
|
||||
uses: comfyanonymous/cuda-toolkit@test
|
||||
id: cuda-toolkit
|
||||
with:
|
||||
cuda: '11.8.0'
|
||||
# copied from xformers github
|
||||
- name: Setup MSVC
|
||||
uses: ilammy/msvc-dev-cmd@v1
|
||||
- name: Configure Pagefile
|
||||
# windows runners will OOM with many CUDA architectures
|
||||
# we cheat here with a page file
|
||||
uses: al-cheb/configure-pagefile-action@v1.3
|
||||
with:
|
||||
minimum-size: 2GB
|
||||
# really unfortunate: https://github.com/ilammy/msvc-dev-cmd#name-conflicts-with-shell-bash
|
||||
- name: Remove link.exe
|
||||
shell: bash
|
||||
run: rm /usr/bin/link
|
||||
|
||||
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
git clone --recurse-submodules https://github.com/facebookresearch/xformers.git
|
||||
cd xformers
|
||||
python -m pip install --no-cache-dir wheel setuptools twine
|
||||
echo building xformers
|
||||
python setup.py bdist_wheel -d ../temp_wheel_dir/
|
||||
cd ..
|
||||
rm -rf xformers
|
||||
ls -lah temp_wheel_dir
|
||||
mv temp_wheel_dir cu118_python_deps
|
||||
tar cf cu118_python_deps.tar cu118_python_deps
|
||||
|
||||
|
@ -1,37 +0,0 @@
|
||||
name: "Windows Release cu118 dependencies 2"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
xformers:
|
||||
description: 'xformers version'
|
||||
required: true
|
||||
type: string
|
||||
default: "xformers"
|
||||
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
jobs:
|
||||
build_dependencies:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10.9'
|
||||
|
||||
- shell: bash
|
||||
run: |
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
ls -lah temp_wheel_dir
|
||||
mv temp_wheel_dir cu118_python_deps
|
||||
tar cf cu118_python_deps.tar cu118_python_deps
|
||||
|
||||
- uses: actions/cache/save@v3
|
||||
with:
|
||||
path: cu118_python_deps.tar
|
||||
key: ${{ runner.os }}-build-cu118
|
@ -1,79 +0,0 @@
|
||||
name: "Windows Release cu118 packaging"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
jobs:
|
||||
package_comfyui:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/cache/restore@v3
|
||||
id: cache
|
||||
with:
|
||||
path: cu118_python_deps.tar
|
||||
key: ${{ runner.os }}-build-cu118
|
||||
- shell: bash
|
||||
run: |
|
||||
mv cu118_python_deps.tar ../
|
||||
cd ..
|
||||
tar xf cu118_python_deps.tar
|
||||
pwd
|
||||
ls
|
||||
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- shell: bash
|
||||
run: |
|
||||
cd ..
|
||||
cp -r ComfyUI ComfyUI_copy
|
||||
curl https://www.python.org/ftp/python/3.10.9/python-3.10.9-embed-amd64.zip -o python_embeded.zip
|
||||
unzip python_embeded.zip -d python_embeded
|
||||
cd python_embeded
|
||||
echo 'import site' >> ./python310._pth
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../cu118_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python310._pth
|
||||
cd ..
|
||||
|
||||
git clone https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
||||
|
||||
mkdir ComfyUI_windows_portable
|
||||
mv python_embeded ComfyUI_windows_portable
|
||||
mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/update_windows_cu118/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
||||
|
||||
ls
|
||||
|
||||
- name: Upload binaries to release
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
file: new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z
|
||||
tag: "latest"
|
||||
overwrite: true
|
||||
|
@ -24,7 +24,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
default: "8"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@ -33,18 +33,17 @@ jobs:
|
||||
build_dependencies:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
|
||||
|
||||
- shell: bash
|
||||
run: |
|
||||
echo "@echo off
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\\
|
||||
call update_comfyui.bat nopause
|
||||
echo -
|
||||
echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff
|
||||
echo You should not be running this anyways unless you really have to
|
||||
echo This will try to update pytorch and all python dependencies.
|
||||
echo -
|
||||
echo If you just want to update normally, close this and run update_comfyui.bat instead.
|
||||
echo -
|
||||
@ -59,7 +58,7 @@ jobs:
|
||||
mv temp_wheel_dir cu${{ inputs.cu }}_python_deps
|
||||
tar cf cu${{ inputs.cu }}_python_deps.tar cu${{ inputs.cu }}_python_deps
|
||||
|
||||
- uses: actions/cache/save@v3
|
||||
- uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
cu${{ inputs.cu }}_python_deps.tar
|
||||
|
@ -7,7 +7,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "121"
|
||||
default: "124"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "1"
|
||||
default: "3"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@ -32,11 +32,11 @@ jobs:
|
||||
pull-requests: "read"
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
|
||||
- shell: bash
|
||||
@ -49,7 +49,7 @@ jobs:
|
||||
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
||||
python -m pip wheel torch torchvision torchaudio mpmath==1.3.0 numpy==1.26.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
||||
ls ../temp_wheel_dir
|
||||
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
@ -68,12 +68,12 @@ jobs:
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
|
||||
echo "..\python_embeded\python.exe .\update.py ..\ComfyUI\\
|
||||
echo "call update_comfyui.bat nopause
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > ./update/update_comfyui_and_python_dependencies.bat
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
||||
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
|
||||
|
||||
cd ComfyUI_windows_portable_nightly_pytorch
|
||||
|
@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
default: "8"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@ -32,7 +32,7 @@ jobs:
|
||||
pull-requests: "read"
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/cache/restore@v3
|
||||
- uses: actions/cache/restore@v4
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
@ -48,7 +48,7 @@ jobs:
|
||||
pwd
|
||||
ls
|
||||
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@ -82,7 +82,7 @@ jobs:
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -9,10 +9,12 @@ __pycache__/
|
||||
!custom_nodes/example_node.py.example
|
||||
extra_model_paths.yaml
|
||||
/.vs
|
||||
.vscode/
|
||||
.idea/
|
||||
venv/
|
||||
/web/extensions/*
|
||||
!/web/extensions/logging.js.example
|
||||
!/web/extensions/core/
|
||||
/tests-ui/data/object_info.json
|
||||
/user/
|
||||
/user/
|
||||
*.log
|
41
CONTRIBUTING.md
Normal file
41
CONTRIBUTING.md
Normal file
@ -0,0 +1,41 @@
|
||||
# Contributing to ComfyUI
|
||||
|
||||
Welcome, and thank you for your interest in contributing to ComfyUI!
|
||||
|
||||
There are several ways in which you can contribute, beyond writing code. The goal of this document is to provide a high-level overview of how you can get involved.
|
||||
|
||||
## Asking Questions
|
||||
|
||||
Have a question? Instead of opening an issue, please ask on [Discord](https://comfy.org/discord) or [Matrix](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) channels. Our team and the community will help you.
|
||||
|
||||
## Providing Feedback
|
||||
|
||||
Your comments and feedback are welcome, and the development team is available via a handful of different channels.
|
||||
|
||||
See the `#bug-report`, `#feature-request` and `#feedback` channels on Discord.
|
||||
|
||||
## Reporting Issues
|
||||
|
||||
Have you identified a reproducible problem in ComfyUI? Do you have a feature request? We want to hear about it! Here's how you can report your issue as effectively as possible.
|
||||
|
||||
|
||||
### Look For an Existing Issue
|
||||
|
||||
Before you create a new issue, please do a search in [open issues](https://github.com/comfyanonymous/ComfyUI/issues) to see if the issue or feature request has already been filed.
|
||||
|
||||
If you find your issue already exists, make relevant comments and add your [reaction](https://github.com/blog/2119-add-reactions-to-pull-requests-issues-and-comments). Use a reaction in place of a "+1" comment:
|
||||
|
||||
* 👍 - upvote
|
||||
* 👎 - downvote
|
||||
|
||||
If you cannot find an existing issue that describes your bug or feature, create a new issue. We have an issue template in place to organize new issues.
|
||||
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
* Please refer to the article on [creating pull requests](https://github.com/comfyanonymous/ComfyUI/wiki/How-to-Contribute-Code) and contributing to this project.
|
||||
|
||||
|
||||
## Thank You
|
||||
|
||||
Your contributions to open source, large or small, make great projects like this possible. Thank you for taking the time to contribute.
|
99
README.md
99
README.md
@ -11,16 +11,16 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
|
||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||
- Works even if you don't have a GPU with: ```--cpu``` (slow)
|
||||
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
|
||||
- Embeddings/Textual inversion
|
||||
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
|
||||
- Loading full workflows (with seeds) from generated PNG files.
|
||||
- Loading full workflows (with seeds) from generated PNG, WebP and FLAC files.
|
||||
- Saving/Loading workflows as Json files.
|
||||
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
|
||||
- [Area Composition](https://comfyanonymous.github.io/ComfyUI_examples/area_composition/)
|
||||
@ -41,29 +41,32 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
## Shortcuts
|
||||
|
||||
| Keybind | Explanation |
|
||||
|---------------------------|--------------------------------------------------------------------------------------------------------------------|
|
||||
| Ctrl + Enter | Queue up current graph for generation |
|
||||
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
||||
| Ctrl + S | Save workflow |
|
||||
| Ctrl + O | Load workflow |
|
||||
| Ctrl + A | Select all nodes |
|
||||
| Alt + C | Collapse/uncollapse selected nodes |
|
||||
| Ctrl + M | Mute/unmute selected nodes |
|
||||
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
||||
| Delete/Backspace | Delete selected nodes |
|
||||
| Ctrl + Delete/Backspace | Delete the current graph |
|
||||
| Space | Move the canvas around when held and moving the cursor |
|
||||
| Ctrl/Shift + Click | Add clicked node to selection |
|
||||
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
||||
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
||||
| Shift + Drag | Move multiple selected nodes at the same time |
|
||||
| Ctrl + D | Load default graph |
|
||||
| Q | Toggle visibility of the queue |
|
||||
| H | Toggle visibility of history |
|
||||
| R | Refresh graph |
|
||||
| Double-Click LMB | Open node quick search palette |
|
||||
| Keybind | Explanation |
|
||||
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
|
||||
| Ctrl + Enter | Queue up current graph for generation |
|
||||
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
||||
| Ctrl + S | Save workflow |
|
||||
| Ctrl + O | Load workflow |
|
||||
| Ctrl + A | Select all nodes |
|
||||
| Alt + C | Collapse/uncollapse selected nodes |
|
||||
| Ctrl + M | Mute/unmute selected nodes |
|
||||
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
||||
| Delete/Backspace | Delete selected nodes |
|
||||
| Ctrl + Backspace | Delete the current graph |
|
||||
| Space | Move the canvas around when held and moving the cursor |
|
||||
| Ctrl/Shift + Click | Add clicked node to selection |
|
||||
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
||||
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
||||
| Shift + Drag | Move multiple selected nodes at the same time |
|
||||
| Ctrl + D | Load default graph |
|
||||
| Alt + `+` | Canvas Zoom in |
|
||||
| Alt + `-` | Canvas Zoom out |
|
||||
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
||||
| Q | Toggle visibility of the queue |
|
||||
| H | Toggle visibility of history |
|
||||
| R | Refresh graph |
|
||||
| Double-Click LMB | Open node quick search palette |
|
||||
|
||||
Ctrl can also be replaced with Cmd instead for macOS users
|
||||
|
||||
@ -99,11 +102,11 @@ Put your VAE in: models/vae
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
|
||||
|
||||
### NVIDIA
|
||||
|
||||
@ -113,7 +116,7 @@ Nvidia users should install stable pytorch using this command:
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
@ -133,7 +136,16 @@ After this you should have everything installed and can proceed to running Comfy
|
||||
|
||||
### Others:
|
||||
|
||||
#### [Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476)
|
||||
#### Intel GPUs
|
||||
|
||||
Intel GPU support is available for all Intel GPUs supported by Intel's Extension for Pytorch (IPEX) with the support requirements listed in the [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) page. Choose your platform and method of install and follow the instructions. The steps are as follows:
|
||||
|
||||
1. Start by installing the drivers or kernel listed or newer in the Installation page of IPEX linked above for Windows and Linux if needed.
|
||||
1. Follow the instructions to install [Intel's oneAPI Basekit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html) for your platform.
|
||||
1. Install the packages for IPEX using the instructions provided in the Installation page for your platform.
|
||||
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux and run ComfyUI normally as described above after everything is installed.
|
||||
|
||||
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
|
||||
|
||||
#### Apple Mac silicon
|
||||
|
||||
@ -142,7 +154,7 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
|
||||
1. Install pytorch nightly. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide (make sure to install the latest pytorch nightly).
|
||||
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
|
||||
1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
|
||||
1. Launch ComfyUI by running `python main.py --force-fp16`. Note that --force-fp16 will only work if you installed the latest pytorch nightly.
|
||||
1. Launch ComfyUI by running `python main.py`
|
||||
|
||||
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
|
||||
|
||||
@ -195,30 +207,29 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
|
||||
```embedding:embedding_filename.pt```
|
||||
|
||||
|
||||
## How to increase generation speed?
|
||||
|
||||
Make sure you use the regular loaders/Load Checkpoint node to load checkpoints. It will auto pick the right settings depending on your GPU.
|
||||
|
||||
You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or pytorch attention this option does not do anything.
|
||||
|
||||
```--dont-upcast-attention```
|
||||
|
||||
## How to show high-quality previews?
|
||||
|
||||
Use ```--preview-method auto``` to enable previews.
|
||||
|
||||
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
|
||||
|
||||
## How to use TLS/SSL?
|
||||
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
|
||||
|
||||
Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`.
|
||||
|
||||
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
|
||||
<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.
|
||||
|
||||
## Support and dev channel
|
||||
|
||||
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
|
||||
|
||||
See also: [https://www.comfy.org/](https://www.comfy.org/)
|
||||
|
||||
# QA
|
||||
|
||||
### Why did you make this?
|
||||
### Which GPU should I buy for this?
|
||||
|
||||
I wanted to learn how Stable Diffusion worked in detail. I also wanted something clean and powerful that would let me experiment with SD without restrictions.
|
||||
[See this page for some recommendations](https://github.com/comfyanonymous/ComfyUI/wiki/Which-GPU-should-I-buy-for-ComfyUI)
|
||||
|
||||
### Who is this for?
|
||||
|
||||
This is for anyone that wants to make complex workflows with SD or that wants to learn more how SD works. The interface follows closely how SD works and the code should be much more simple to understand than other SD UIs.
|
||||
|
@ -2,6 +2,8 @@ import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import glob
|
||||
import shutil
|
||||
from aiohttp import web
|
||||
from comfy.cli_args import args
|
||||
from folder_paths import user_directory
|
||||
@ -56,16 +58,16 @@ class UserManager():
|
||||
if os.path.commonpath((root_dir, user_root)) != root_dir:
|
||||
return None
|
||||
|
||||
parent = user_root
|
||||
|
||||
if file is not None:
|
||||
# prevent leaving /{type}/{user}
|
||||
path = os.path.abspath(os.path.join(user_root, file))
|
||||
if os.path.commonpath((user_root, path)) != user_root:
|
||||
return None
|
||||
|
||||
parent = os.path.split(path)[0]
|
||||
|
||||
if create_dir and not os.path.exists(parent):
|
||||
os.mkdir(parent)
|
||||
os.makedirs(parent, exist_ok=True)
|
||||
|
||||
return path
|
||||
|
||||
@ -108,33 +110,96 @@ class UserManager():
|
||||
user_id = self.add_user(username)
|
||||
return web.json_response(user_id)
|
||||
|
||||
@routes.get("/userdata/{file}")
|
||||
async def getuserdata(request):
|
||||
file = request.match_info.get("file", None)
|
||||
if not file:
|
||||
@routes.get("/userdata")
|
||||
async def listuserdata(request):
|
||||
directory = request.rel_url.query.get('dir', '')
|
||||
if not directory:
|
||||
return web.Response(status=400)
|
||||
|
||||
path = self.get_request_user_filepath(request, file)
|
||||
path = self.get_request_user_filepath(request, directory)
|
||||
if not path:
|
||||
return web.Response(status=403)
|
||||
|
||||
if not os.path.exists(path):
|
||||
return web.Response(status=404)
|
||||
|
||||
return web.FileResponse(path)
|
||||
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
||||
results = glob.glob(os.path.join(
|
||||
glob.escape(path), '**/*'), recursive=recurse)
|
||||
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
|
||||
|
||||
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
||||
if split_path:
|
||||
results = [[x] + x.split(os.sep) for x in results]
|
||||
|
||||
@routes.post("/userdata/{file}")
|
||||
async def post_userdata(request):
|
||||
file = request.match_info.get("file", None)
|
||||
return web.json_response(results)
|
||||
|
||||
def get_user_data_path(request, check_exists = False, param = "file"):
|
||||
file = request.match_info.get(param, None)
|
||||
if not file:
|
||||
return web.Response(status=400)
|
||||
|
||||
path = self.get_request_user_filepath(request, file)
|
||||
if not path:
|
||||
return web.Response(status=403)
|
||||
|
||||
if check_exists and not os.path.exists(path):
|
||||
return web.Response(status=404)
|
||||
|
||||
return path
|
||||
|
||||
@routes.get("/userdata/{file}")
|
||||
async def getuserdata(request):
|
||||
path = get_user_data_path(request, check_exists=True)
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
return web.FileResponse(path)
|
||||
|
||||
@routes.post("/userdata/{file}")
|
||||
async def post_userdata(request):
|
||||
path = get_user_data_path(request)
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
overwrite = request.query["overwrite"] != "false"
|
||||
if not overwrite and os.path.exists(path):
|
||||
return web.Response(status=409)
|
||||
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
|
||||
return web.Response(status=200)
|
||||
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
|
||||
return web.json_response(resp)
|
||||
|
||||
@routes.delete("/userdata/{file}")
|
||||
async def delete_userdata(request):
|
||||
path = get_user_data_path(request, check_exists=True)
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
os.remove(path)
|
||||
|
||||
return web.Response(status=204)
|
||||
|
||||
@routes.post("/userdata/{file}/move/{dest}")
|
||||
async def move_userdata(request):
|
||||
source = get_user_data_path(request, check_exists=True)
|
||||
if not isinstance(source, str):
|
||||
return source
|
||||
|
||||
dest = get_user_data_path(request, check_exists=False, param="dest")
|
||||
if not isinstance(source, str):
|
||||
return dest
|
||||
|
||||
overwrite = request.query["overwrite"] != "false"
|
||||
if not overwrite and os.path.exists(dest):
|
||||
return web.Response(status=409)
|
||||
|
||||
print(f"moving '{source}' -> '{dest}'")
|
||||
shutil.move(source, dest)
|
||||
|
||||
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
|
||||
return web.json_response(resp)
|
||||
|
@ -13,7 +13,46 @@ from ..ldm.modules.diffusionmodules.util import (
|
||||
from ..ldm.modules.attention import SpatialTransformer
|
||||
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
|
||||
from ..ldm.util import exists
|
||||
from collections import OrderedDict
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
class OptimizedAttention(nn.Module):
|
||||
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.heads = nhead
|
||||
self.c = c
|
||||
|
||||
self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
|
||||
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.in_proj(x)
|
||||
q, k, v = x.split(self.c, dim=2)
|
||||
out = optimized_attention(q, k, v, self.heads)
|
||||
return self.out_proj(out)
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
class ResBlockUnionControlnet(nn.Module):
|
||||
def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
|
||||
self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.mlp = nn.Sequential(
|
||||
OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
|
||||
("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
|
||||
self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
|
||||
def attention(self, x: torch.Tensor):
|
||||
return self.attn(x)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attention(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
class ControlledUnetModel(UNetModel):
|
||||
#implemented in the ldm unet
|
||||
@ -52,6 +91,8 @@ class ControlNet(nn.Module):
|
||||
adm_in_channels=None,
|
||||
transformer_depth_middle=None,
|
||||
transformer_depth_output=None,
|
||||
attn_precision=None,
|
||||
union_controlnet=False,
|
||||
device=None,
|
||||
operations=comfy.ops.disable_weight_init,
|
||||
**kwargs,
|
||||
@ -202,7 +243,7 @@ class ControlNet(nn.Module):
|
||||
SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
@ -262,7 +303,7 @@ class ControlNet(nn.Module):
|
||||
mid_block += [SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
@ -279,6 +320,65 @@ class ControlNet(nn.Module):
|
||||
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
|
||||
self._feature_size += ch
|
||||
|
||||
if union_controlnet:
|
||||
self.num_control_type = 6
|
||||
num_trans_channel = 320
|
||||
num_trans_head = 8
|
||||
num_trans_layer = 1
|
||||
num_proj_channel = 320
|
||||
# task_scale_factor = num_trans_channel ** 0.5
|
||||
self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
|
||||
|
||||
self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
|
||||
self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
|
||||
#-----------------------------------------------------------------------------------------------------
|
||||
|
||||
control_add_embed_dim = 256
|
||||
class ControlAddEmbedding(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_control_type = num_control_type
|
||||
self.in_dim = in_dim
|
||||
self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
|
||||
def forward(self, control_type, dtype, device):
|
||||
c_type = torch.zeros((self.num_control_type,), device=device)
|
||||
c_type[control_type] = 1.0
|
||||
c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
|
||||
return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
|
||||
|
||||
self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.task_embedding = None
|
||||
self.control_add_embedding = None
|
||||
|
||||
def union_controlnet_merge(self, hint, control_type, emb, context):
|
||||
# Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
|
||||
inputs = []
|
||||
condition_list = []
|
||||
|
||||
for idx in range(min(1, len(control_type))):
|
||||
controlnet_cond = self.input_hint_block(hint[idx], emb, context)
|
||||
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
|
||||
if idx < len(control_type):
|
||||
feat_seq += self.task_embedding[control_type[idx]]
|
||||
|
||||
inputs.append(feat_seq.unsqueeze(1))
|
||||
condition_list.append(controlnet_cond)
|
||||
|
||||
x = torch.cat(inputs, dim=1)
|
||||
x = self.transformer_layes(x)
|
||||
controlnet_cond_fuser = None
|
||||
for idx in range(len(control_type)):
|
||||
alpha = self.spatial_ch_projs(x[:, idx])
|
||||
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
|
||||
o = condition_list[idx] + alpha
|
||||
if controlnet_cond_fuser is None:
|
||||
controlnet_cond_fuser = o
|
||||
else:
|
||||
controlnet_cond_fuser += o
|
||||
return controlnet_cond_fuser
|
||||
|
||||
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
|
||||
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
|
||||
|
||||
@ -286,9 +386,21 @@ class ControlNet(nn.Module):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
guided_hint = self.input_hint_block(hint, emb, context)
|
||||
guided_hint = None
|
||||
if self.control_add_embedding is not None: #Union Controlnet
|
||||
control_type = kwargs.get("control_type", [])
|
||||
|
||||
outs = []
|
||||
emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
|
||||
if len(control_type) > 0:
|
||||
if len(hint.shape) < 5:
|
||||
hint = hint.unsqueeze(dim=0)
|
||||
guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
|
||||
|
||||
if guided_hint is None:
|
||||
guided_hint = self.input_hint_block(hint, emb, context)
|
||||
|
||||
out_output = []
|
||||
out_middle = []
|
||||
|
||||
hs = []
|
||||
if self.num_classes is not None:
|
||||
@ -303,10 +415,10 @@ class ControlNet(nn.Module):
|
||||
guided_hint = None
|
||||
else:
|
||||
h = module(h, emb, context)
|
||||
outs.append(zero_conv(h, emb, context))
|
||||
out_output.append(zero_conv(h, emb, context))
|
||||
|
||||
h = self.middle_block(h, emb, context)
|
||||
outs.append(self.middle_block_out(h, emb, context))
|
||||
out_middle.append(self.middle_block_out(h, emb, context))
|
||||
|
||||
return outs
|
||||
return {"middle": out_middle, "output": out_output}
|
||||
|
||||
|
77
comfy/cldm/mmdit.py
Normal file
77
comfy/cldm/mmdit.py
Normal file
@ -0,0 +1,77 @@
|
||||
import torch
|
||||
from typing import Dict, Optional
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
|
||||
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks = None,
|
||||
dtype = None,
|
||||
device = None,
|
||||
operations = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = torch.nn.ModuleList([])
|
||||
for _ in range(len(self.joint_blocks)):
|
||||
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
||||
|
||||
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
||||
None,
|
||||
self.patch_size,
|
||||
self.in_channels,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
strict_img_size=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
hint = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
#weird sd3 controlnet specific stuff
|
||||
y = torch.zeros_like(y)
|
||||
|
||||
if self.context_processor is not None:
|
||||
context = self.context_processor(context)
|
||||
|
||||
hw = x.shape[-2:]
|
||||
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
|
||||
x += self.pos_embed_input(hint)
|
||||
|
||||
c = self.t_embedder(timesteps, dtype=x.dtype)
|
||||
if y is not None and self.y_embedder is not None:
|
||||
y = self.y_embedder(y)
|
||||
c = c + y
|
||||
|
||||
if context is not None:
|
||||
context = self.context_embedder(context)
|
||||
|
||||
output = []
|
||||
|
||||
blocks = len(self.joint_blocks)
|
||||
for i in range(blocks):
|
||||
context, x = self.joint_blocks[i](
|
||||
context,
|
||||
x,
|
||||
c=c,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
)
|
||||
|
||||
out = self.controlnet_blocks[i](x)
|
||||
count = self.depth // blocks
|
||||
if i == blocks - 1:
|
||||
count -= 1
|
||||
for j in range(count):
|
||||
output.append(out)
|
||||
|
||||
return {"output": output}
|
@ -35,6 +35,8 @@ parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
||||
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
||||
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
|
||||
|
||||
@ -49,7 +51,6 @@ cm_group = parser.add_mutually_exclusive_group()
|
||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
||||
|
||||
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
||||
|
||||
fp_group = parser.add_mutually_exclusive_group()
|
||||
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
||||
@ -74,6 +75,7 @@ fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store
|
||||
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
|
||||
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
||||
|
||||
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
||||
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
|
||||
@ -94,6 +96,11 @@ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", he
|
||||
|
||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||
|
||||
upcast = parser.add_mutually_exclusive_group()
|
||||
upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
|
||||
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
|
||||
|
||||
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||
@ -111,9 +118,13 @@ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test
|
||||
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
||||
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
||||
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
@ -124,3 +135,10 @@ if args.windows_standalone_build:
|
||||
|
||||
if args.disable_auto_launch:
|
||||
args.auto_launch = False
|
||||
|
||||
import logging
|
||||
logging_level = logging.INFO
|
||||
if args.verbose:
|
||||
logging_level = logging.DEBUG
|
||||
|
||||
logging.basicConfig(format="%(message)s", level=logging_level)
|
||||
|
@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
x = self.embeddings(input_tokens)
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
|
||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
@ -119,6 +119,9 @@ class CLIPTextModel(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_hidden_layers"]
|
||||
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
@ -128,7 +131,10 @@ class CLIPTextModel(torch.nn.Module):
|
||||
self.text_model.embeddings.token_embedding = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.text_model(*args, **kwargs)
|
||||
x = self.text_model(*args, **kwargs)
|
||||
out = self.text_projection(x[2])
|
||||
return (x[0], x[1], out, x[2])
|
||||
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
|
||||
|
@ -2,6 +2,7 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_repl
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
import comfy.model_patcher
|
||||
@ -99,7 +100,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
clip = ClipVisionModel(json_config)
|
||||
m, u = clip.load_sd(sd)
|
||||
if len(m) > 0:
|
||||
print("missing clip vision:", m)
|
||||
logging.warning("missing clip vision: {}".format(m))
|
||||
u = set(u)
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
|
@ -29,7 +29,12 @@ class CONDRegular:
|
||||
|
||||
class CONDNoiseShape(CONDRegular):
|
||||
def process_cond(self, batch_size, device, area, **kwargs):
|
||||
data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
data = self.cond
|
||||
if area is not None:
|
||||
dims = len(area) // 2
|
||||
for i in range(dims):
|
||||
data = data.narrow(i + 2, area[i + dims], area[i])
|
||||
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
|
||||
|
||||
|
||||
|
@ -1,14 +1,18 @@
|
||||
import torch
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
import comfy.model_patcher
|
||||
import comfy.ops
|
||||
import comfy.latent_formats
|
||||
|
||||
import comfy.cldm.cldm
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
@ -35,18 +39,24 @@ class ControlBase:
|
||||
self.cond_hint = None
|
||||
self.strength = 1.0
|
||||
self.timestep_percent_range = (0.0, 1.0)
|
||||
self.latent_format = None
|
||||
self.vae = None
|
||||
self.global_average_pooling = False
|
||||
self.timestep_range = None
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
||||
self.cond_hint_original = cond_hint
|
||||
self.strength = strength
|
||||
self.timestep_percent_range = timestep_percent_range
|
||||
if self.latent_format is not None:
|
||||
self.vae = vae
|
||||
return self
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
@ -77,43 +87,37 @@ class ControlBase:
|
||||
c.strength = self.strength
|
||||
c.timestep_percent_range = self.timestep_percent_range
|
||||
c.global_average_pooling = self.global_average_pooling
|
||||
c.compression_ratio = self.compression_ratio
|
||||
c.upscale_algorithm = self.upscale_algorithm
|
||||
c.latent_format = self.latent_format
|
||||
c.vae = self.vae
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
if self.previous_controlnet is not None:
|
||||
return self.previous_controlnet.inference_memory_requirements(dtype)
|
||||
return 0
|
||||
|
||||
def control_merge(self, control_input, control_output, control_prev, output_dtype):
|
||||
def control_merge(self, control, control_prev, output_dtype):
|
||||
out = {'input':[], 'middle':[], 'output': []}
|
||||
|
||||
if control_input is not None:
|
||||
for i in range(len(control_input)):
|
||||
key = 'input'
|
||||
x = control_input[i]
|
||||
if x is not None:
|
||||
x *= self.strength
|
||||
if x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
out[key].insert(0, x)
|
||||
|
||||
if control_output is not None:
|
||||
for key in control:
|
||||
control_output = control[key]
|
||||
applied_to = set()
|
||||
for i in range(len(control_output)):
|
||||
if i == (len(control_output) - 1):
|
||||
key = 'middle'
|
||||
index = 0
|
||||
else:
|
||||
key = 'output'
|
||||
index = i
|
||||
x = control_output[i]
|
||||
if x is not None:
|
||||
if self.global_average_pooling:
|
||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||
|
||||
x *= self.strength
|
||||
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
|
||||
applied_to.add(x)
|
||||
x *= self.strength
|
||||
|
||||
if x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
|
||||
out[key].append(x)
|
||||
|
||||
if control_prev is not None:
|
||||
for x in ['input', 'middle', 'output']:
|
||||
o = out[x]
|
||||
@ -128,18 +132,22 @@ class ControlBase:
|
||||
if o[i].shape[0] < prev_val.shape[0]:
|
||||
o[i] = prev_val + o[i]
|
||||
else:
|
||||
o[i] += prev_val
|
||||
o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
|
||||
return out
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
|
||||
super().__init__(device)
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
if control_model is not None:
|
||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
|
||||
self.compression_ratio = compression_ratio
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.model_sampling_current = None
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
self.latent_format = latent_format
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
control_prev = None
|
||||
@ -158,11 +166,21 @@ class ControlNet(ControlBase):
|
||||
dtype = self.manual_cast_dtype
|
||||
|
||||
output_dtype = x_noisy.dtype
|
||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
|
||||
compression_ratio = self.compression_ratio
|
||||
if self.vae is not None:
|
||||
compression_ratio *= self.vae.downscale_ratio
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
if self.vae is not None:
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
|
||||
comfy.model_management.load_models_gpu(loaded_models)
|
||||
if self.latent_format is not None:
|
||||
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
||||
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
@ -174,10 +192,12 @@ class ControlNet(ControlBase):
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
||||
return self.control_merge(None, control, control_prev, output_dtype)
|
||||
return self.control_merge(control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
c.control_model_wrapped = self.control_model_wrapped
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
@ -195,7 +215,7 @@ class ControlNet(ControlBase):
|
||||
super().cleanup()
|
||||
|
||||
class ControlLoraOps:
|
||||
class Linear(torch.nn.Module):
|
||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
@ -214,7 +234,7 @@ class ControlLoraOps:
|
||||
else:
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
class Conv2d(torch.nn.Module):
|
||||
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
@ -287,13 +307,13 @@ class ControlLora(ControlNet):
|
||||
for k in sd:
|
||||
weight = sd[k]
|
||||
try:
|
||||
comfy.utils.set_attr(self.control_model, k, weight)
|
||||
comfy.utils.set_attr_param(self.control_model, k, weight)
|
||||
except:
|
||||
pass
|
||||
|
||||
for k in self.control_weights:
|
||||
if k not in {"lora_controlnet"}:
|
||||
comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
|
||||
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
|
||||
|
||||
def copy(self):
|
||||
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
|
||||
@ -312,15 +332,49 @@ class ControlLora(ControlNet):
|
||||
def inference_memory_requirements(self, dtype):
|
||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
|
||||
controlnet_config = model_config.unet_config
|
||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
operations = comfy.ops.manual_cast
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
|
||||
missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
|
||||
|
||||
if len(missing) > 0:
|
||||
logging.warning("missing controlnet keys: {}".format(missing))
|
||||
|
||||
if len(unexpected) > 0:
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
|
||||
latent_format = comfy.latent_formats.SD3()
|
||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
if "lora_controlnet" in controlnet_data:
|
||||
return ControlLora(controlnet_data)
|
||||
|
||||
controlnet_config = None
|
||||
supported_inference_dtypes = None
|
||||
|
||||
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
||||
unet_dtype = comfy.model_management.unet_dtype()
|
||||
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
||||
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
|
||||
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
||||
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
||||
@ -359,10 +413,18 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if k in controlnet_data:
|
||||
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
|
||||
|
||||
if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
|
||||
controlnet_config["union_controlnet"] = True
|
||||
for k in list(controlnet_data.keys()):
|
||||
new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
|
||||
new_sd[new_k] = controlnet_data.pop(k)
|
||||
|
||||
leftover_keys = controlnet_data.keys()
|
||||
if len(leftover_keys) > 0:
|
||||
print("leftover keys:", leftover_keys)
|
||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||
controlnet_data = new_sd
|
||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
||||
return load_controlnet_mmdit(controlnet_data)
|
||||
|
||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||
pth = False
|
||||
@ -376,16 +438,24 @@ def load_controlnet(ckpt_path, model=None):
|
||||
else:
|
||||
net = load_t2i_adapter(controlnet_data)
|
||||
if net is None:
|
||||
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
|
||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||
return net
|
||||
|
||||
if controlnet_config is None:
|
||||
unet_dtype = comfy.model_management.unet_dtype()
|
||||
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
||||
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
controlnet_config = model_config.unet_config
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
if supported_inference_dtypes is None:
|
||||
unet_dtype = comfy.model_management.unet_dtype()
|
||||
else:
|
||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
||||
controlnet_config["dtype"] = unet_dtype
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||
@ -403,7 +473,7 @@ def load_controlnet(ckpt_path, model=None):
|
||||
cd = controlnet_data[x]
|
||||
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
||||
else:
|
||||
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
||||
logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
@ -412,7 +482,12 @@ def load_controlnet(ckpt_path, model=None):
|
||||
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
||||
else:
|
||||
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
||||
print(missing, unexpected)
|
||||
|
||||
if len(missing) > 0:
|
||||
logging.warning("missing controlnet keys: {}".format(missing))
|
||||
|
||||
if len(unexpected) > 0:
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
|
||||
global_average_pooling = False
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
@ -423,11 +498,13 @@ def load_controlnet(ckpt_path, model=None):
|
||||
return control
|
||||
|
||||
class T2IAdapter(ControlBase):
|
||||
def __init__(self, t2i_model, channels_in, device=None):
|
||||
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
||||
super().__init__(device)
|
||||
self.t2i_model = t2i_model
|
||||
self.channels_in = channels_in
|
||||
self.control_input = None
|
||||
self.compression_ratio = compression_ratio
|
||||
self.upscale_algorithm = upscale_algorithm
|
||||
|
||||
def scale_image_to(self, width, height):
|
||||
unshuffle_amount = self.t2i_model.unshuffle_amount
|
||||
@ -447,13 +524,13 @@ class T2IAdapter(ControlBase):
|
||||
else:
|
||||
return None
|
||||
|
||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.control_input = None
|
||||
self.cond_hint = None
|
||||
width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device)
|
||||
width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio)
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device)
|
||||
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
||||
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
@ -464,19 +541,21 @@ class T2IAdapter(ControlBase):
|
||||
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
|
||||
self.t2i_model.cpu()
|
||||
|
||||
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
|
||||
mid = None
|
||||
if self.t2i_model.xl == True:
|
||||
mid = control_input[-1:]
|
||||
control_input = control_input[:-1]
|
||||
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
|
||||
control_input = {}
|
||||
for k in self.control_input:
|
||||
control_input[k] = list(map(lambda a: None if a is None else a.clone(), self.control_input[k]))
|
||||
|
||||
return self.control_merge(control_input, control_prev, x_noisy.dtype)
|
||||
|
||||
def copy(self):
|
||||
c = T2IAdapter(self.t2i_model, self.channels_in)
|
||||
c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def load_t2i_adapter(t2i_data):
|
||||
compression_ratio = 8
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
|
||||
if 'adapter' in t2i_data:
|
||||
t2i_data = t2i_data['adapter']
|
||||
if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
|
||||
@ -504,13 +583,22 @@ def load_t2i_adapter(t2i_data):
|
||||
if cin == 256 or cin == 768:
|
||||
xl = True
|
||||
model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
|
||||
elif "backbone.0.0.weight" in keys:
|
||||
model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
|
||||
compression_ratio = 32
|
||||
upscale_algorithm = 'bilinear'
|
||||
elif "backbone.10.blocks.0.weight" in keys:
|
||||
model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
|
||||
compression_ratio = 1
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
else:
|
||||
return None
|
||||
|
||||
missing, unexpected = model_ad.load_state_dict(t2i_data)
|
||||
if len(missing) > 0:
|
||||
print("t2i missing", missing)
|
||||
logging.warning("t2i missing {}".format(missing))
|
||||
|
||||
if len(unexpected) > 0:
|
||||
print("t2i unexpected", unexpected)
|
||||
logging.debug("t2i unexpected {}".format(unexpected))
|
||||
|
||||
return T2IAdapter(model_ad, model_ad.input_channels)
|
||||
return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import re
|
||||
import torch
|
||||
import logging
|
||||
|
||||
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
||||
|
||||
@ -177,7 +178,7 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
for k, v in new_state_dict.items():
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
print(f"Reshaping {k} for SD format")
|
||||
logging.debug(f"Reshaping {k} for SD format")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||
return new_state_dict
|
||||
|
||||
@ -205,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
||||
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||
|
||||
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
|
||||
def cat_tensors(tensors):
|
||||
x = 0
|
||||
for t in tensors:
|
||||
x += t.shape[0]
|
||||
|
||||
shape = [x] + list(tensors[0].shape)[1:]
|
||||
out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
|
||||
|
||||
x = 0
|
||||
for t in tensors:
|
||||
out[x:x + t.shape[0]] = t
|
||||
x += t.shape[0]
|
||||
|
||||
return out
|
||||
|
||||
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
||||
new_state_dict = {}
|
||||
@ -237,20 +253,24 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
||||
capture_qkv_bias[k_pre][code2idx[k_code]] = v
|
||||
continue
|
||||
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
|
||||
new_state_dict[relabelled_key] = v
|
||||
text_proj = "transformer.text_projection.weight"
|
||||
if k.endswith(text_proj):
|
||||
new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous()
|
||||
else:
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
|
||||
new_state_dict[relabelled_key] = v
|
||||
|
||||
for k_pre, tensors in capture_qkv_weight.items():
|
||||
if None in tensors:
|
||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
|
||||
new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)
|
||||
|
||||
for k_pre, tensors in capture_qkv_bias.items():
|
||||
if None in tensors:
|
||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
|
||||
new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
@ -358,9 +358,6 @@ class UniPC:
|
||||
thresholding=False,
|
||||
max_val=1.,
|
||||
variant='bh1',
|
||||
noise_mask=None,
|
||||
masked_image=None,
|
||||
noise=None,
|
||||
):
|
||||
"""Construct a UniPC.
|
||||
|
||||
@ -372,9 +369,6 @@ class UniPC:
|
||||
self.predict_x0 = predict_x0
|
||||
self.thresholding = thresholding
|
||||
self.max_val = max_val
|
||||
self.noise_mask = noise_mask
|
||||
self.masked_image = masked_image
|
||||
self.noise = noise
|
||||
|
||||
def dynamic_thresholding_fn(self, x0, t=None):
|
||||
"""
|
||||
@ -391,10 +385,7 @@ class UniPC:
|
||||
"""
|
||||
Return the noise prediction model.
|
||||
"""
|
||||
if self.noise_mask is not None:
|
||||
return self.model(x, t) * self.noise_mask
|
||||
else:
|
||||
return self.model(x, t)
|
||||
return self.model(x, t)
|
||||
|
||||
def data_prediction_fn(self, x, t):
|
||||
"""
|
||||
@ -409,8 +400,6 @@ class UniPC:
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
if self.noise_mask is not None:
|
||||
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
|
||||
return x0
|
||||
|
||||
def model_fn(self, x, t):
|
||||
@ -723,8 +712,6 @@ class UniPC:
|
||||
assert timesteps.shape[0] - 1 == steps
|
||||
# with torch.no_grad():
|
||||
for step_index in trange(steps, disable=disable_pbar):
|
||||
if self.noise_mask is not None:
|
||||
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
|
||||
if step_index == 0:
|
||||
vec_t = timesteps[0].expand((x.shape[0]))
|
||||
model_prev_list = [self.model_fn(x, vec_t)]
|
||||
@ -766,7 +753,7 @@ class UniPC:
|
||||
model_x = self.model_fn(x, vec_t)
|
||||
model_prev_list[-1] = model_x
|
||||
if callback is not None:
|
||||
callback(step_index, model_prev_list[-1], x, steps)
|
||||
callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]})
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
# if denoise_to_zero:
|
||||
@ -858,7 +845,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs):
|
||||
return (input - model(input, sigma_in, **kwargs)) / sigma
|
||||
|
||||
|
||||
def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
||||
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
|
||||
timesteps = sigmas.clone()
|
||||
if sigmas[-1] == 0:
|
||||
timesteps = sigmas[:]
|
||||
@ -867,16 +854,7 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
|
||||
timesteps = sigmas.clone()
|
||||
ns = SigmaConvert()
|
||||
|
||||
if image is not None:
|
||||
img = image * ns.marginal_alpha(timesteps[0])
|
||||
if max_denoise:
|
||||
noise_mult = 1.0
|
||||
else:
|
||||
noise_mult = ns.marginal_std(timesteps[0])
|
||||
img += noise * noise_mult
|
||||
else:
|
||||
img = noise
|
||||
|
||||
noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
|
||||
model_type = "noise"
|
||||
|
||||
model_fn = model_wrapper(
|
||||
@ -888,7 +866,10 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
|
||||
)
|
||||
|
||||
order = min(3, len(timesteps) - 2)
|
||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
|
||||
x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
||||
x /= ns.marginal_alpha(timesteps[-1])
|
||||
return x
|
||||
|
||||
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
|
||||
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
|
@ -2,7 +2,8 @@ import torch
|
||||
from torch import nn
|
||||
from .ldm.modules.attention import CrossAttention
|
||||
from inspect import isfunction
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.manual_cast
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
@ -22,7 +23,7 @@ def default(val, d):
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
self.proj = ops.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
@ -35,14 +36,14 @@ class FeedForward(nn.Module):
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
ops.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
ops.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -57,11 +58,12 @@ class GatedCrossAttentionDense(nn.Module):
|
||||
query_dim=query_dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head)
|
||||
dim_head=d_head,
|
||||
operations=ops)
|
||||
self.ff = FeedForward(query_dim, glu=True)
|
||||
|
||||
self.norm1 = nn.LayerNorm(query_dim)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
self.norm1 = ops.LayerNorm(query_dim)
|
||||
self.norm2 = ops.LayerNorm(query_dim)
|
||||
|
||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||
@ -87,17 +89,18 @@ class GatedSelfAttentionDense(nn.Module):
|
||||
|
||||
# we need a linear projection since we need cat visual feature and obj
|
||||
# feature
|
||||
self.linear = nn.Linear(context_dim, query_dim)
|
||||
self.linear = ops.Linear(context_dim, query_dim)
|
||||
|
||||
self.attn = CrossAttention(
|
||||
query_dim=query_dim,
|
||||
context_dim=query_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head)
|
||||
dim_head=d_head,
|
||||
operations=ops)
|
||||
self.ff = FeedForward(query_dim, glu=True)
|
||||
|
||||
self.norm1 = nn.LayerNorm(query_dim)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
self.norm1 = ops.LayerNorm(query_dim)
|
||||
self.norm2 = ops.LayerNorm(query_dim)
|
||||
|
||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||
@ -126,14 +129,14 @@ class GatedSelfAttentionDense2(nn.Module):
|
||||
|
||||
# we need a linear projection since we need cat visual feature and obj
|
||||
# feature
|
||||
self.linear = nn.Linear(context_dim, query_dim)
|
||||
self.linear = ops.Linear(context_dim, query_dim)
|
||||
|
||||
self.attn = CrossAttention(
|
||||
query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
|
||||
query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
|
||||
self.ff = FeedForward(query_dim, glu=True)
|
||||
|
||||
self.norm1 = nn.LayerNorm(query_dim)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
self.norm1 = ops.LayerNorm(query_dim)
|
||||
self.norm2 = ops.LayerNorm(query_dim)
|
||||
|
||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||
@ -201,11 +204,11 @@ class PositionNet(nn.Module):
|
||||
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
|
||||
|
||||
self.linears = nn.Sequential(
|
||||
nn.Linear(self.in_dim + self.position_dim, 512),
|
||||
ops.Linear(self.in_dim + self.position_dim, 512),
|
||||
nn.SiLU(),
|
||||
nn.Linear(512, 512),
|
||||
ops.Linear(512, 512),
|
||||
nn.SiLU(),
|
||||
nn.Linear(512, out_dim),
|
||||
ops.Linear(512, out_dim),
|
||||
)
|
||||
|
||||
self.null_positive_feature = torch.nn.Parameter(
|
||||
@ -215,16 +218,15 @@ class PositionNet(nn.Module):
|
||||
|
||||
def forward(self, boxes, masks, positive_embeddings):
|
||||
B, N, _ = boxes.shape
|
||||
dtype = self.linears[0].weight.dtype
|
||||
masks = masks.unsqueeze(-1).to(dtype)
|
||||
positive_embeddings = positive_embeddings.to(dtype)
|
||||
masks = masks.unsqueeze(-1)
|
||||
positive_embeddings = positive_embeddings
|
||||
|
||||
# embedding position (it may includes padding as placeholder)
|
||||
xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C
|
||||
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
|
||||
|
||||
# learnable null embedding
|
||||
positive_null = self.null_positive_feature.view(1, 1, -1)
|
||||
xyxy_null = self.null_position_feature.view(1, 1, -1)
|
||||
positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
|
||||
xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
|
||||
|
||||
# replace padding with learnable null embedding
|
||||
positive_embeddings = positive_embeddings * \
|
||||
@ -251,7 +253,7 @@ class Gligen(nn.Module):
|
||||
def func(x, extra_options):
|
||||
key = extra_options["transformer_index"]
|
||||
module = self.module_list[key]
|
||||
return module(x, objs)
|
||||
return module(x, objs.to(device=x.device, dtype=x.dtype))
|
||||
return func
|
||||
|
||||
def set_position(self, latent_image_shape, position_params, device):
|
||||
|
121
comfy/k_diffusion/deis.py
Normal file
121
comfy/k_diffusion/deis.py
Normal file
@ -0,0 +1,121 @@
|
||||
#Taken from: https://github.com/zju-pi/diff-sampler/blob/main/gits-main/solver_utils.py
|
||||
#under Apache 2 license
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# A pytorch reimplementation of DEIS (https://github.com/qsh-zh/deis).
|
||||
#############################
|
||||
### Utils for DEIS solver ###
|
||||
#############################
|
||||
#----------------------------------------------------------------------------
|
||||
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
|
||||
|
||||
def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
|
||||
vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
|
||||
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
|
||||
vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
|
||||
vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d
|
||||
t_steps = vp_sigma_inv(vp_beta_d.clone().detach().cpu(), vp_beta_min.clone().detach().cpu())(edm_steps.clone().detach().cpu())
|
||||
return t_steps, vp_beta_min, vp_beta_d + vp_beta_min
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def cal_poly(prev_t, j, taus):
|
||||
poly = 1
|
||||
for k in range(prev_t.shape[0]):
|
||||
if k == j:
|
||||
continue
|
||||
poly *= (taus - prev_t[k]) / (prev_t[j] - prev_t[k])
|
||||
return poly
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Transfer from t to alpha_t.
|
||||
|
||||
def t2alpha_fn(beta_0, beta_1, t):
|
||||
return torch.exp(-0.5 * t ** 2 * (beta_1 - beta_0) - t * beta_0)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def cal_intergrand(beta_0, beta_1, taus):
|
||||
with torch.inference_mode(mode=False):
|
||||
taus = taus.clone()
|
||||
beta_0 = beta_0.clone()
|
||||
beta_1 = beta_1.clone()
|
||||
with torch.enable_grad():
|
||||
taus.requires_grad_(True)
|
||||
alpha = t2alpha_fn(beta_0, beta_1, taus)
|
||||
log_alpha = alpha.log()
|
||||
log_alpha.sum().backward()
|
||||
d_log_alpha_dtau = taus.grad
|
||||
integrand = -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
|
||||
return integrand
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def get_deis_coeff_list(t_steps, max_order, N=10000, deis_mode='tab'):
|
||||
"""
|
||||
Get the coefficient list for DEIS sampling.
|
||||
|
||||
Args:
|
||||
t_steps: A pytorch tensor. The time steps for sampling.
|
||||
max_order: A `int`. Maximum order of the solver. 1 <= max_order <= 4
|
||||
N: A `int`. Use how many points to perform the numerical integration when deis_mode=='tab'.
|
||||
deis_mode: A `str`. Select between 'tab' and 'rhoab'. Type of DEIS.
|
||||
Returns:
|
||||
A pytorch tensor. A batch of generated samples or sampling trajectories if return_inters=True.
|
||||
"""
|
||||
if deis_mode == 'tab':
|
||||
t_steps, beta_0, beta_1 = edm2t(t_steps)
|
||||
C = []
|
||||
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
|
||||
order = min(i+1, max_order)
|
||||
if order == 1:
|
||||
C.append([])
|
||||
else:
|
||||
taus = torch.linspace(t_cur, t_next, N) # split the interval for integral appximation
|
||||
dtau = (t_next - t_cur) / N
|
||||
prev_t = t_steps[[i - k for k in range(order)]]
|
||||
coeff_temp = []
|
||||
integrand = cal_intergrand(beta_0, beta_1, taus)
|
||||
for j in range(order):
|
||||
poly = cal_poly(prev_t, j, taus)
|
||||
coeff_temp.append(torch.sum(integrand * poly) * dtau)
|
||||
C.append(coeff_temp)
|
||||
|
||||
elif deis_mode == 'rhoab':
|
||||
# Analytical solution, second order
|
||||
def get_def_intergral_2(a, b, start, end, c):
|
||||
coeff = (end**3 - start**3) / 3 - (end**2 - start**2) * (a + b) / 2 + (end - start) * a * b
|
||||
return coeff / ((c - a) * (c - b))
|
||||
|
||||
# Analytical solution, third order
|
||||
def get_def_intergral_3(a, b, c, start, end, d):
|
||||
coeff = (end**4 - start**4) / 4 - (end**3 - start**3) * (a + b + c) / 3 \
|
||||
+ (end**2 - start**2) * (a*b + a*c + b*c) / 2 - (end - start) * a * b * c
|
||||
return coeff / ((d - a) * (d - b) * (d - c))
|
||||
|
||||
C = []
|
||||
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
|
||||
order = min(i, max_order)
|
||||
if order == 0:
|
||||
C.append([])
|
||||
else:
|
||||
prev_t = t_steps[[i - k for k in range(order+1)]]
|
||||
if order == 1:
|
||||
coeff_cur = ((t_next - prev_t[1])**2 - (t_cur - prev_t[1])**2) / (2 * (t_cur - prev_t[1]))
|
||||
coeff_prev1 = (t_next - t_cur)**2 / (2 * (prev_t[1] - t_cur))
|
||||
coeff_temp = [coeff_cur, coeff_prev1]
|
||||
elif order == 2:
|
||||
coeff_cur = get_def_intergral_2(prev_t[1], prev_t[2], t_cur, t_next, t_cur)
|
||||
coeff_prev1 = get_def_intergral_2(t_cur, prev_t[2], t_cur, t_next, prev_t[1])
|
||||
coeff_prev2 = get_def_intergral_2(t_cur, prev_t[1], t_cur, t_next, prev_t[2])
|
||||
coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2]
|
||||
elif order == 3:
|
||||
coeff_cur = get_def_intergral_3(prev_t[1], prev_t[2], prev_t[3], t_cur, t_next, t_cur)
|
||||
coeff_prev1 = get_def_intergral_3(t_cur, prev_t[2], prev_t[3], t_cur, t_next, prev_t[1])
|
||||
coeff_prev2 = get_def_intergral_3(t_cur, prev_t[1], prev_t[3], t_cur, t_next, prev_t[2])
|
||||
coeff_prev3 = get_def_intergral_3(t_cur, prev_t[1], prev_t[2], t_cur, t_next, prev_t[3])
|
||||
coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3]
|
||||
C.append(coeff_temp)
|
||||
return C
|
||||
|
@ -7,7 +7,8 @@ import torchsde
|
||||
from tqdm.auto import trange, tqdm
|
||||
|
||||
from . import utils
|
||||
|
||||
from . import deis
|
||||
import comfy.model_patcher
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
@ -129,8 +130,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if s_churn > 0:
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
else:
|
||||
gamma = 0
|
||||
sigma_hat = sigmas[i]
|
||||
|
||||
if gamma > 0:
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
@ -170,7 +176,13 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
if s_churn > 0:
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
else:
|
||||
gamma = 0
|
||||
sigma_hat = sigmas[i]
|
||||
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
@ -199,8 +211,13 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if s_churn > 0:
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
else:
|
||||
gamma = 0
|
||||
sigma_hat = sigmas[i]
|
||||
|
||||
if gamma > 0:
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
@ -527,6 +544,9 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
"""DPM-Solver++ (stochastic)."""
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
@ -595,6 +615,8 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
"""DPM-Solver++(2M) SDE."""
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
if solver_type not in {'heun', 'midpoint'}:
|
||||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||
@ -642,6 +664,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""DPM-Solver++(3M) SDE."""
|
||||
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
@ -690,18 +715,27 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
||||
@ -748,7 +782,7 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
|
||||
|
||||
x = denoised
|
||||
if sigmas[i + 1] > 0:
|
||||
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
|
||||
x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
|
||||
return x
|
||||
|
||||
|
||||
@ -808,3 +842,209 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
d_prime = w1 * d + w2 * d_2 + w3 * d_3
|
||||
x = x + d_prime * dt
|
||||
return x
|
||||
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
x_next = x
|
||||
|
||||
buffer_model = []
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
t_cur = sigmas[i]
|
||||
t_next = sigmas[i + 1]
|
||||
|
||||
x_cur = x_next
|
||||
|
||||
denoised = model(x_cur, t_cur * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
if order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
|
||||
elif order == 3: # Use two history points.
|
||||
x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12
|
||||
elif order == 4: # Use three history points.
|
||||
x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24
|
||||
|
||||
if len(buffer_model) == max_order - 1:
|
||||
for k in range(max_order - 2):
|
||||
buffer_model[k] = buffer_model[k+1]
|
||||
buffer_model[-1] = d_cur
|
||||
else:
|
||||
buffer_model.append(d_cur)
|
||||
|
||||
return x_next
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
x_next = x
|
||||
t_steps = sigmas
|
||||
|
||||
buffer_model = []
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
t_cur = sigmas[i]
|
||||
t_next = sigmas[i + 1]
|
||||
|
||||
x_cur = x_next
|
||||
|
||||
denoised = model(x_cur, t_cur * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
if order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
h_n = (t_next - t_cur)
|
||||
h_n_1 = (t_cur - t_steps[i-1])
|
||||
coeff1 = (2 + (h_n / h_n_1)) / 2
|
||||
coeff2 = -(h_n / h_n_1) / 2
|
||||
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1])
|
||||
elif order == 3: # Use two history points.
|
||||
h_n = (t_next - t_cur)
|
||||
h_n_1 = (t_cur - t_steps[i-1])
|
||||
h_n_2 = (t_steps[i-1] - t_steps[i-2])
|
||||
temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
|
||||
coeff1 = (2 + (h_n / h_n_1)) / 2 + temp
|
||||
coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp
|
||||
coeff3 = temp * h_n_1 / h_n_2
|
||||
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2])
|
||||
elif order == 4: # Use three history points.
|
||||
h_n = (t_next - t_cur)
|
||||
h_n_1 = (t_cur - t_steps[i-1])
|
||||
h_n_2 = (t_steps[i-1] - t_steps[i-2])
|
||||
h_n_3 = (t_steps[i-2] - t_steps[i-3])
|
||||
temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
|
||||
temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \
|
||||
* (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
|
||||
coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2
|
||||
coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2
|
||||
coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2
|
||||
coeff4 = -temp2 * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * h_n_1 / h_n_2
|
||||
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + coeff4 * buffer_model[-3])
|
||||
|
||||
if len(buffer_model) == max_order - 1:
|
||||
for k in range(max_order - 2):
|
||||
buffer_model[k] = buffer_model[k+1]
|
||||
buffer_model[-1] = d_cur.detach()
|
||||
else:
|
||||
buffer_model.append(d_cur.detach())
|
||||
|
||||
return x_next
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
@torch.no_grad()
|
||||
def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
x_next = x
|
||||
t_steps = sigmas
|
||||
|
||||
coeff_list = deis.get_deis_coeff_list(t_steps, max_order, deis_mode=deis_mode)
|
||||
|
||||
buffer_model = []
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
t_cur = sigmas[i]
|
||||
t_next = sigmas[i + 1]
|
||||
|
||||
x_cur = x_next
|
||||
|
||||
denoised = model(x_cur, t_cur * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
if t_next <= 0:
|
||||
order = 1
|
||||
|
||||
if order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
coeff_cur, coeff_prev1 = coeff_list[i]
|
||||
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1]
|
||||
elif order == 3: # Use two history points.
|
||||
coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i]
|
||||
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2]
|
||||
elif order == 4: # Use three history points.
|
||||
coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i]
|
||||
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3]
|
||||
|
||||
if len(buffer_model) == max_order - 1:
|
||||
for k in range(max_order - 2):
|
||||
buffer_model[k] = buffer_model[k+1]
|
||||
buffer_model[-1] = d_cur.detach()
|
||||
else:
|
||||
buffer_model.append(d_cur.detach())
|
||||
|
||||
return x_next
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
temp = [0]
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
sigma_hat = sigmas[i]
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, temp[0])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
|
||||
temp = [0]
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], temp[0])
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = denoised + d * sigma_down
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
@ -1,6 +1,8 @@
|
||||
import torch
|
||||
|
||||
class LatentFormat:
|
||||
scale_factor = 1.0
|
||||
latent_channels = 4
|
||||
latent_rgb_factors = None
|
||||
taesd_decoder_name = None
|
||||
|
||||
@ -23,8 +25,9 @@ class SD15(LatentFormat):
|
||||
self.taesd_decoder_name = "taesd_decoder"
|
||||
|
||||
class SDXL(LatentFormat):
|
||||
scale_factor = 0.13025
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.13025
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[ 0.3920, 0.4054, 0.4549],
|
||||
@ -34,6 +37,105 @@ class SDXL(LatentFormat):
|
||||
]
|
||||
self.taesd_decoder_name = "taesdxl_decoder"
|
||||
|
||||
class SDXL_Playground_2_5(LatentFormat):
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.5
|
||||
self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
|
||||
self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
|
||||
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[ 0.3920, 0.4054, 0.4549],
|
||||
[-0.2634, -0.0196, 0.0653],
|
||||
[ 0.0568, 0.1687, -0.0755],
|
||||
[-0.3112, -0.2359, -0.2076]
|
||||
]
|
||||
self.taesd_decoder_name = "taesdxl_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return (latent - latents_mean) * self.scale_factor / latents_std
|
||||
|
||||
def process_out(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
|
||||
class SD_X4(LatentFormat):
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.08333
|
||||
self.latent_rgb_factors = [
|
||||
[-0.2340, -0.3863, -0.3257],
|
||||
[ 0.0994, 0.0885, -0.0908],
|
||||
[-0.2833, -0.2349, -0.3741],
|
||||
[ 0.2523, -0.0055, -0.1651]
|
||||
]
|
||||
|
||||
class SC_Prior(LatentFormat):
|
||||
latent_channels = 16
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latent_rgb_factors = [
|
||||
[-0.0326, -0.0204, -0.0127],
|
||||
[-0.1592, -0.0427, 0.0216],
|
||||
[ 0.0873, 0.0638, -0.0020],
|
||||
[-0.0602, 0.0442, 0.1304],
|
||||
[ 0.0800, -0.0313, -0.1796],
|
||||
[-0.0810, -0.0638, -0.1581],
|
||||
[ 0.1791, 0.1180, 0.0967],
|
||||
[ 0.0740, 0.1416, 0.0432],
|
||||
[-0.1745, -0.1888, -0.1373],
|
||||
[ 0.2412, 0.1577, 0.0928],
|
||||
[ 0.1908, 0.0998, 0.0682],
|
||||
[ 0.0209, 0.0365, -0.0092],
|
||||
[ 0.0448, -0.0650, -0.1728],
|
||||
[-0.1658, -0.1045, -0.1308],
|
||||
[ 0.0542, 0.1545, 0.1325],
|
||||
[-0.0352, -0.1672, -0.2541]
|
||||
]
|
||||
|
||||
class SC_B(LatentFormat):
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0 / 0.43
|
||||
self.latent_rgb_factors = [
|
||||
[ 0.1121, 0.2006, 0.1023],
|
||||
[-0.2093, -0.0222, -0.0195],
|
||||
[-0.3087, -0.1535, 0.0366],
|
||||
[ 0.0290, -0.1574, -0.4078]
|
||||
]
|
||||
|
||||
class SD3(LatentFormat):
|
||||
latent_channels = 16
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.5305
|
||||
self.shift_factor = 0.0609
|
||||
self.latent_rgb_factors = [
|
||||
[-0.0645, 0.0177, 0.1052],
|
||||
[ 0.0028, 0.0312, 0.0650],
|
||||
[ 0.1848, 0.0762, 0.0360],
|
||||
[ 0.0944, 0.0360, 0.0889],
|
||||
[ 0.0897, 0.0506, -0.0364],
|
||||
[-0.0020, 0.1203, 0.0284],
|
||||
[ 0.0855, 0.0118, 0.0283],
|
||||
[-0.0539, 0.0658, 0.1047],
|
||||
[-0.0057, 0.0116, 0.0700],
|
||||
[-0.0412, 0.0281, -0.0039],
|
||||
[ 0.1106, 0.1171, 0.1220],
|
||||
[-0.0248, 0.0682, -0.0481],
|
||||
[ 0.0815, 0.0846, 0.1207],
|
||||
[-0.0120, -0.0055, -0.0867],
|
||||
[-0.0749, -0.0634, -0.0456],
|
||||
[-0.1418, -0.1457, -0.1259]
|
||||
]
|
||||
self.taesd_decoder_name = "taesd3_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
return (latent - self.shift_factor) * self.scale_factor
|
||||
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
class StableAudio1(LatentFormat):
|
||||
latent_channels = 64
|
||||
|
282
comfy/ldm/audio/autoencoder.py
Normal file
282
comfy/ldm/audio/autoencoder.py
Normal file
@ -0,0 +1,282 @@
|
||||
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Literal, Dict, Any
|
||||
import math
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def vae_sample(mean, scale):
|
||||
stdev = nn.functional.softplus(scale) + 1e-4
|
||||
var = stdev * stdev
|
||||
logvar = torch.log(var)
|
||||
latents = torch.randn_like(mean) * stdev + mean
|
||||
|
||||
kl = (mean * mean + var - logvar - 1).sum(1).mean()
|
||||
|
||||
return latents, kl
|
||||
|
||||
class VAEBottleneck(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.is_discrete = False
|
||||
|
||||
def encode(self, x, return_info=False, **kwargs):
|
||||
info = {}
|
||||
|
||||
mean, scale = x.chunk(2, dim=1)
|
||||
|
||||
x, kl = vae_sample(mean, scale)
|
||||
|
||||
info["kl"] = kl
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
def snake_beta(x, alpha, beta):
|
||||
return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
|
||||
|
||||
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
|
||||
class SnakeBeta(nn.Module):
|
||||
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
|
||||
super(SnakeBeta, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
|
||||
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
|
||||
else: # linear scale alphas initialized to ones
|
||||
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
|
||||
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
|
||||
|
||||
# self.alpha.requires_grad = alpha_trainable
|
||||
# self.beta.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T]
|
||||
beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
x = snake_beta(x, alpha, beta)
|
||||
|
||||
return x
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
try:
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||
except:
|
||||
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
try:
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||
except:
|
||||
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||
|
||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||
if activation == "elu":
|
||||
act = torch.nn.ELU()
|
||||
elif activation == "snake":
|
||||
act = SnakeBeta(channels)
|
||||
elif activation == "none":
|
||||
act = torch.nn.Identity()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation {activation}")
|
||||
|
||||
if antialias:
|
||||
act = Activation1d(act)
|
||||
|
||||
return act
|
||||
|
||||
|
||||
class ResidualUnit(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
|
||||
super().__init__()
|
||||
|
||||
self.dilation = dilation
|
||||
|
||||
padding = (dilation * (7-1)) // 2
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||
kernel_size=7, dilation=dilation, padding=padding),
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||
WNConv1d(in_channels=out_channels, out_channels=out_channels,
|
||||
kernel_size=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
res = x
|
||||
|
||||
#x = checkpoint(self.layers, x)
|
||||
x = self.layers(x)
|
||||
|
||||
return x + res
|
||||
|
||||
class EncoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
|
||||
super().__init__()
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
ResidualUnit(in_channels=in_channels,
|
||||
out_channels=in_channels, dilation=1, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=in_channels,
|
||||
out_channels=in_channels, dilation=3, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=in_channels,
|
||||
out_channels=in_channels, dilation=9, use_snake=use_snake),
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
|
||||
super().__init__()
|
||||
|
||||
if use_nearest_upsample:
|
||||
upsample_layer = nn.Sequential(
|
||||
nn.Upsample(scale_factor=stride, mode="nearest"),
|
||||
WNConv1d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=2*stride,
|
||||
stride=1,
|
||||
bias=False,
|
||||
padding='same')
|
||||
)
|
||||
else:
|
||||
upsample_layer = WNConvTranspose1d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||
upsample_layer,
|
||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||
dilation=1, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||
dilation=3, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||
dilation=9, use_snake=use_snake),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
class OobleckEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=2,
|
||||
channels=128,
|
||||
latent_dim=32,
|
||||
c_mults = [1, 2, 4, 8],
|
||||
strides = [2, 4, 8, 8],
|
||||
use_snake=False,
|
||||
antialias_activation=False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
c_mults = [1] + c_mults
|
||||
|
||||
self.depth = len(c_mults)
|
||||
|
||||
layers = [
|
||||
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
|
||||
]
|
||||
|
||||
for i in range(self.depth-1):
|
||||
layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
|
||||
|
||||
layers += [
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
|
||||
WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
|
||||
]
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
class OobleckDecoder(nn.Module):
|
||||
def __init__(self,
|
||||
out_channels=2,
|
||||
channels=128,
|
||||
latent_dim=32,
|
||||
c_mults = [1, 2, 4, 8],
|
||||
strides = [2, 4, 8, 8],
|
||||
use_snake=False,
|
||||
antialias_activation=False,
|
||||
use_nearest_upsample=False,
|
||||
final_tanh=True):
|
||||
super().__init__()
|
||||
|
||||
c_mults = [1] + c_mults
|
||||
|
||||
self.depth = len(c_mults)
|
||||
|
||||
layers = [
|
||||
WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
|
||||
]
|
||||
|
||||
for i in range(self.depth-1, 0, -1):
|
||||
layers += [DecoderBlock(
|
||||
in_channels=c_mults[i]*channels,
|
||||
out_channels=c_mults[i-1]*channels,
|
||||
stride=strides[i-1],
|
||||
use_snake=use_snake,
|
||||
antialias_activation=antialias_activation,
|
||||
use_nearest_upsample=use_nearest_upsample
|
||||
)
|
||||
]
|
||||
|
||||
layers += [
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
|
||||
WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
|
||||
nn.Tanh() if final_tanh else nn.Identity()
|
||||
]
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
class AudioOobleckVAE(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=2,
|
||||
channels=128,
|
||||
latent_dim=64,
|
||||
c_mults = [1, 2, 4, 8, 16],
|
||||
strides = [2, 4, 4, 8, 8],
|
||||
use_snake=True,
|
||||
antialias_activation=False,
|
||||
use_nearest_upsample=False,
|
||||
final_tanh=False):
|
||||
super().__init__()
|
||||
self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
|
||||
self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
|
||||
use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
|
||||
self.bottleneck = VAEBottleneck()
|
||||
|
||||
def encode(self, x):
|
||||
return self.bottleneck.encode(self.encoder(x))
|
||||
|
||||
def decode(self, x):
|
||||
return self.decoder(self.bottleneck.decode(x))
|
||||
|
888
comfy/ldm/audio/dit.py
Normal file
888
comfy/ldm/audio/dit.py
Normal file
@ -0,0 +1,888 @@
|
||||
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
import math
|
||||
|
||||
class FourierFeatures(nn.Module):
|
||||
def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
|
||||
super().__init__()
|
||||
assert out_features % 2 == 0
|
||||
self.weight = nn.Parameter(torch.empty(
|
||||
[out_features // 2, in_features], dtype=dtype, device=device))
|
||||
|
||||
def forward(self, input):
|
||||
f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
|
||||
return torch.cat([f.cos(), f.sin()], dim=-1)
|
||||
|
||||
# norms
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
|
||||
"""
|
||||
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
|
||||
|
||||
if bias:
|
||||
self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
|
||||
else:
|
||||
self.beta = None
|
||||
|
||||
def forward(self, x):
|
||||
beta = self.beta
|
||||
if self.beta is not None:
|
||||
beta = beta.to(dtype=x.dtype, device=x.device)
|
||||
return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)
|
||||
|
||||
class GLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in,
|
||||
dim_out,
|
||||
activation,
|
||||
use_conv = False,
|
||||
conv_kernel_size = 3,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.act = activation
|
||||
self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2), dtype=dtype, device=device)
|
||||
self.use_conv = use_conv
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_conv:
|
||||
x = rearrange(x, 'b n d -> b d n')
|
||||
x = self.proj(x)
|
||||
x = rearrange(x, 'b d n -> b n d')
|
||||
else:
|
||||
x = self.proj(x)
|
||||
|
||||
x, gate = x.chunk(2, dim = -1)
|
||||
return x * self.act(gate)
|
||||
|
||||
class AbsolutePositionalEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_seq_len):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.max_seq_len = max_seq_len
|
||||
self.emb = nn.Embedding(max_seq_len, dim)
|
||||
|
||||
def forward(self, x, pos = None, seq_start_pos = None):
|
||||
seq_len, device = x.shape[1], x.device
|
||||
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
|
||||
|
||||
if pos is None:
|
||||
pos = torch.arange(seq_len, device = device)
|
||||
|
||||
if seq_start_pos is not None:
|
||||
pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
|
||||
|
||||
pos_emb = self.emb(pos)
|
||||
pos_emb = pos_emb * self.scale
|
||||
return pos_emb
|
||||
|
||||
class ScaledSinusoidalEmbedding(nn.Module):
|
||||
def __init__(self, dim, theta = 10000):
|
||||
super().__init__()
|
||||
assert (dim % 2) == 0, 'dimension must be divisible by 2'
|
||||
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
|
||||
|
||||
half_dim = dim // 2
|
||||
freq_seq = torch.arange(half_dim).float() / half_dim
|
||||
inv_freq = theta ** -freq_seq
|
||||
self.register_buffer('inv_freq', inv_freq, persistent = False)
|
||||
|
||||
def forward(self, x, pos = None, seq_start_pos = None):
|
||||
seq_len, device = x.shape[1], x.device
|
||||
|
||||
if pos is None:
|
||||
pos = torch.arange(seq_len, device = device)
|
||||
|
||||
if seq_start_pos is not None:
|
||||
pos = pos - seq_start_pos[..., None]
|
||||
|
||||
emb = torch.einsum('i, j -> i j', pos, self.inv_freq)
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
|
||||
return emb * self.scale
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
use_xpos = False,
|
||||
scale_base = 512,
|
||||
interpolation_factor = 1.,
|
||||
base = 10000,
|
||||
base_rescale_factor = 1.
|
||||
):
|
||||
super().__init__()
|
||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||
# has some connection to NTK literature
|
||||
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||
base *= base_rescale_factor ** (dim / (dim - 2))
|
||||
|
||||
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer('inv_freq', inv_freq)
|
||||
|
||||
assert interpolation_factor >= 1.
|
||||
self.interpolation_factor = interpolation_factor
|
||||
|
||||
if not use_xpos:
|
||||
self.register_buffer('scale', None)
|
||||
return
|
||||
|
||||
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
||||
|
||||
self.scale_base = scale_base
|
||||
self.register_buffer('scale', scale)
|
||||
|
||||
def forward_from_seq_len(self, seq_len, device, dtype):
|
||||
# device = self.inv_freq.device
|
||||
|
||||
t = torch.arange(seq_len, device=device, dtype=dtype)
|
||||
return self.forward(t)
|
||||
|
||||
def forward(self, t):
|
||||
# device = self.inv_freq.device
|
||||
device = t.device
|
||||
dtype = t.dtype
|
||||
|
||||
# t = t.to(torch.float32)
|
||||
|
||||
t = t / self.interpolation_factor
|
||||
|
||||
freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
|
||||
freqs = torch.cat((freqs, freqs), dim = -1)
|
||||
|
||||
if self.scale is None:
|
||||
return freqs, 1.
|
||||
|
||||
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
||||
scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
|
||||
scale = torch.cat((scale, scale), dim = -1)
|
||||
|
||||
return freqs, scale
|
||||
|
||||
def rotate_half(x):
|
||||
x = rearrange(x, '... (j d) -> ... j d', j = 2)
|
||||
x1, x2 = x.unbind(dim = -2)
|
||||
return torch.cat((-x2, x1), dim = -1)
|
||||
|
||||
def apply_rotary_pos_emb(t, freqs, scale = 1):
|
||||
out_dtype = t.dtype
|
||||
|
||||
# cast to float32 if necessary for numerical stability
|
||||
dtype = t.dtype #reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
|
||||
rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
|
||||
freqs, t = freqs.to(dtype), t.to(dtype)
|
||||
freqs = freqs[-seq_len:, :]
|
||||
|
||||
if t.ndim == 4 and freqs.ndim == 3:
|
||||
freqs = rearrange(freqs, 'b n d -> b 1 n d')
|
||||
|
||||
# partial rotary embeddings, Wang et al. GPT-J
|
||||
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
|
||||
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
||||
|
||||
t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
|
||||
|
||||
return torch.cat((t, t_unrotated), dim = -1)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out = None,
|
||||
mult = 4,
|
||||
no_bias = False,
|
||||
glu = True,
|
||||
use_conv = False,
|
||||
conv_kernel_size = 3,
|
||||
zero_init_output = True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
|
||||
# Default to SwiGLU
|
||||
|
||||
activation = nn.SiLU()
|
||||
|
||||
dim_out = dim if dim_out is None else dim_out
|
||||
|
||||
if glu:
|
||||
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
linear_in = nn.Sequential(
|
||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
|
||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
activation
|
||||
)
|
||||
|
||||
linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device)
|
||||
|
||||
# # init last linear layer to 0
|
||||
# if zero_init_output:
|
||||
# nn.init.zeros_(linear_out.weight)
|
||||
# if not no_bias:
|
||||
# nn.init.zeros_(linear_out.bias)
|
||||
|
||||
|
||||
self.ff = nn.Sequential(
|
||||
linear_in,
|
||||
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
||||
linear_out,
|
||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.ff(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_heads = 64,
|
||||
dim_context = None,
|
||||
causal = False,
|
||||
zero_init_output=True,
|
||||
qk_norm = False,
|
||||
natten_kernel_size = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_heads = dim_heads
|
||||
self.causal = causal
|
||||
|
||||
dim_kv = dim_context if dim_context is not None else dim
|
||||
|
||||
self.num_heads = dim // dim_heads
|
||||
self.kv_heads = dim_kv // dim_heads
|
||||
|
||||
if dim_context is not None:
|
||||
self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
# if zero_init_output:
|
||||
# nn.init.zeros_(self.to_out.weight)
|
||||
|
||||
self.qk_norm = qk_norm
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context = None,
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
rotary_pos_emb = None,
|
||||
causal = None
|
||||
):
|
||||
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
|
||||
|
||||
kv_input = context if has_context else x
|
||||
|
||||
if hasattr(self, 'to_q'):
|
||||
# Use separate linear projections for q and k/v
|
||||
q = self.to_q(x)
|
||||
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
|
||||
|
||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||
|
||||
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
|
||||
else:
|
||||
# Use fused linear projection
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
|
||||
|
||||
# Normalize q and k for cosine sim attention
|
||||
if self.qk_norm:
|
||||
q = F.normalize(q, dim=-1)
|
||||
k = F.normalize(k, dim=-1)
|
||||
|
||||
if rotary_pos_emb is not None and not has_context:
|
||||
freqs, _ = rotary_pos_emb
|
||||
|
||||
q_dtype = q.dtype
|
||||
k_dtype = k.dtype
|
||||
|
||||
q = q.to(torch.float32)
|
||||
k = k.to(torch.float32)
|
||||
freqs = freqs.to(torch.float32)
|
||||
|
||||
q = apply_rotary_pos_emb(q, freqs)
|
||||
k = apply_rotary_pos_emb(k, freqs)
|
||||
|
||||
q = q.to(q_dtype)
|
||||
k = k.to(k_dtype)
|
||||
|
||||
input_mask = context_mask
|
||||
|
||||
if input_mask is None and not has_context:
|
||||
input_mask = mask
|
||||
|
||||
# determine masking
|
||||
masks = []
|
||||
final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
|
||||
|
||||
if input_mask is not None:
|
||||
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
|
||||
masks.append(~input_mask)
|
||||
|
||||
# Other masks will be added here later
|
||||
|
||||
if len(masks) > 0:
|
||||
final_attn_mask = ~or_reduce(masks)
|
||||
|
||||
n, device = q.shape[-2], q.device
|
||||
|
||||
causal = self.causal if causal is None else causal
|
||||
|
||||
if n == 1 and causal:
|
||||
causal = False
|
||||
|
||||
if h != kv_h:
|
||||
# Repeat interleave kv_heads to match q_heads
|
||||
heads_per_kv_head = h // kv_h
|
||||
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True)
|
||||
out = self.to_out(out)
|
||||
|
||||
if mask is not None:
|
||||
mask = rearrange(mask, 'b n -> b n 1')
|
||||
out = out.masked_fill(~mask, 0.)
|
||||
|
||||
return out
|
||||
|
||||
class ConformerModule(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
norm_kwargs = {},
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
|
||||
self.in_norm = LayerNorm(dim, **norm_kwargs)
|
||||
self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
||||
self.glu = GLU(dim, dim, nn.SiLU())
|
||||
self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
|
||||
self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
|
||||
self.swish = nn.SiLU()
|
||||
self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.in_norm(x)
|
||||
x = rearrange(x, 'b n d -> b d n')
|
||||
x = self.pointwise_conv(x)
|
||||
x = rearrange(x, 'b d n -> b n d')
|
||||
x = self.glu(x)
|
||||
x = rearrange(x, 'b n d -> b d n')
|
||||
x = self.depthwise_conv(x)
|
||||
x = rearrange(x, 'b d n -> b n d')
|
||||
x = self.mid_norm(x)
|
||||
x = self.swish(x)
|
||||
x = rearrange(x, 'b n d -> b d n')
|
||||
x = self.pointwise_conv_2(x)
|
||||
x = rearrange(x, 'b d n -> b n d')
|
||||
|
||||
return x
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_heads = 64,
|
||||
cross_attend = False,
|
||||
dim_context = None,
|
||||
global_cond_dim = None,
|
||||
causal = False,
|
||||
zero_init_branch_outputs = True,
|
||||
conformer = False,
|
||||
layer_ix = -1,
|
||||
remove_norms = False,
|
||||
attn_kwargs = {},
|
||||
ff_kwargs = {},
|
||||
norm_kwargs = {},
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_heads = dim_heads
|
||||
self.cross_attend = cross_attend
|
||||
self.dim_context = dim_context
|
||||
self.causal = causal
|
||||
|
||||
self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
|
||||
|
||||
self.self_attn = Attention(
|
||||
dim,
|
||||
dim_heads = dim_heads,
|
||||
causal = causal,
|
||||
zero_init_output=zero_init_branch_outputs,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
**attn_kwargs
|
||||
)
|
||||
|
||||
if cross_attend:
|
||||
self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
|
||||
self.cross_attn = Attention(
|
||||
dim,
|
||||
dim_heads = dim_heads,
|
||||
dim_context=dim_context,
|
||||
causal = causal,
|
||||
zero_init_output=zero_init_branch_outputs,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
**attn_kwargs
|
||||
)
|
||||
|
||||
self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
|
||||
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs)
|
||||
|
||||
self.layer_ix = layer_ix
|
||||
|
||||
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
|
||||
|
||||
self.global_cond_dim = global_cond_dim
|
||||
|
||||
if global_cond_dim is not None:
|
||||
self.to_scale_shift_gate = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(global_cond_dim, dim * 6, bias=False)
|
||||
)
|
||||
|
||||
nn.init.zeros_(self.to_scale_shift_gate[1].weight)
|
||||
#nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context = None,
|
||||
global_cond=None,
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
rotary_pos_emb = None
|
||||
):
|
||||
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
|
||||
|
||||
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
|
||||
|
||||
# self-attention with adaLN
|
||||
residual = x
|
||||
x = self.pre_norm(x)
|
||||
x = x * (1 + scale_self) + shift_self
|
||||
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||
x = x * torch.sigmoid(1 - gate_self)
|
||||
x = x + residual
|
||||
|
||||
if context is not None:
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer(x)
|
||||
|
||||
# feedforward with adaLN
|
||||
residual = x
|
||||
x = self.ff_norm(x)
|
||||
x = x * (1 + scale_ff) + shift_ff
|
||||
x = self.ff(x)
|
||||
x = x * torch.sigmoid(1 - gate_ff)
|
||||
x = x + residual
|
||||
|
||||
else:
|
||||
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||
|
||||
if context is not None:
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer(x)
|
||||
|
||||
x = x + self.ff(self.ff_norm(x))
|
||||
|
||||
return x
|
||||
|
||||
class ContinuousTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
*,
|
||||
dim_in = None,
|
||||
dim_out = None,
|
||||
dim_heads = 64,
|
||||
cross_attend=False,
|
||||
cond_token_dim=None,
|
||||
global_cond_dim=None,
|
||||
causal=False,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_outputs=True,
|
||||
conformer=False,
|
||||
use_sinusoidal_emb=False,
|
||||
use_abs_pos_emb=False,
|
||||
abs_pos_emb_max_length=10000,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
self.causal = causal
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity()
|
||||
self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
|
||||
|
||||
if rotary_pos_emb:
|
||||
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
|
||||
else:
|
||||
self.rotary_pos_emb = None
|
||||
|
||||
self.use_sinusoidal_emb = use_sinusoidal_emb
|
||||
if use_sinusoidal_emb:
|
||||
self.pos_emb = ScaledSinusoidalEmbedding(dim)
|
||||
|
||||
self.use_abs_pos_emb = use_abs_pos_emb
|
||||
if use_abs_pos_emb:
|
||||
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
|
||||
|
||||
for i in range(depth):
|
||||
self.layers.append(
|
||||
TransformerBlock(
|
||||
dim,
|
||||
dim_heads = dim_heads,
|
||||
cross_attend = cross_attend,
|
||||
dim_context = cond_token_dim,
|
||||
global_cond_dim = global_cond_dim,
|
||||
causal = causal,
|
||||
zero_init_branch_outputs = zero_init_branch_outputs,
|
||||
conformer=conformer,
|
||||
layer_ix=i,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
**kwargs
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
mask = None,
|
||||
prepend_embeds = None,
|
||||
prepend_mask = None,
|
||||
global_cond = None,
|
||||
return_info = False,
|
||||
**kwargs
|
||||
):
|
||||
batch, seq, device = *x.shape[:2], x.device
|
||||
|
||||
info = {
|
||||
"hidden_states": [],
|
||||
}
|
||||
|
||||
x = self.project_in(x)
|
||||
|
||||
if prepend_embeds is not None:
|
||||
prepend_length, prepend_dim = prepend_embeds.shape[1:]
|
||||
|
||||
assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
|
||||
|
||||
x = torch.cat((prepend_embeds, x), dim = -2)
|
||||
|
||||
if prepend_mask is not None or mask is not None:
|
||||
mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
|
||||
prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
|
||||
|
||||
mask = torch.cat((prepend_mask, mask), dim = -1)
|
||||
|
||||
# Attention layers
|
||||
|
||||
if self.rotary_pos_emb is not None:
|
||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
rotary_pos_emb = None
|
||||
|
||||
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
||||
x = x + self.pos_emb(x)
|
||||
|
||||
# Iterate over the transformer layers
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
|
||||
if return_info:
|
||||
info["hidden_states"].append(x)
|
||||
|
||||
x = self.project_out(x)
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
|
||||
return x
|
||||
|
||||
class AudioDiffusionTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
io_channels=64,
|
||||
patch_size=1,
|
||||
embed_dim=1536,
|
||||
cond_token_dim=768,
|
||||
project_cond_tokens=False,
|
||||
global_cond_dim=1536,
|
||||
project_global_cond=True,
|
||||
input_concat_dim=0,
|
||||
prepend_cond_dim=0,
|
||||
depth=24,
|
||||
num_heads=24,
|
||||
transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
|
||||
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
|
||||
audio_model="",
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.dtype = dtype
|
||||
self.cond_token_dim = cond_token_dim
|
||||
|
||||
# Timestep embeddings
|
||||
timestep_features_dim = 256
|
||||
|
||||
self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)
|
||||
|
||||
self.to_timestep_embed = nn.Sequential(
|
||||
operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
if cond_token_dim > 0:
|
||||
# Conditioning tokens
|
||||
|
||||
cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
|
||||
self.to_cond_embed = nn.Sequential(
|
||||
operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device)
|
||||
)
|
||||
else:
|
||||
cond_embed_dim = 0
|
||||
|
||||
if global_cond_dim > 0:
|
||||
# Global conditioning
|
||||
global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
|
||||
self.to_global_embed = nn.Sequential(
|
||||
operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
if prepend_cond_dim > 0:
|
||||
# Prepend conditioning
|
||||
self.to_prepend_embed = nn.Sequential(
|
||||
operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
self.input_concat_dim = input_concat_dim
|
||||
|
||||
dim_in = io_channels + self.input_concat_dim
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
# Transformer
|
||||
|
||||
self.transformer_type = transformer_type
|
||||
|
||||
self.global_cond_type = global_cond_type
|
||||
|
||||
if self.transformer_type == "continuous_transformer":
|
||||
|
||||
global_dim = None
|
||||
|
||||
if self.global_cond_type == "adaLN":
|
||||
# The global conditioning is projected to the embed_dim already at this point
|
||||
global_dim = embed_dim
|
||||
|
||||
self.transformer = ContinuousTransformer(
|
||||
dim=embed_dim,
|
||||
depth=depth,
|
||||
dim_heads=embed_dim // num_heads,
|
||||
dim_in=dim_in * patch_size,
|
||||
dim_out=io_channels * patch_size,
|
||||
cross_attend = cond_token_dim > 0,
|
||||
cond_token_dim = cond_embed_dim,
|
||||
global_cond_dim=global_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown transformer type: {self.transformer_type}")
|
||||
|
||||
self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device)
|
||||
self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
mask=None,
|
||||
cross_attn_cond=None,
|
||||
cross_attn_cond_mask=None,
|
||||
input_concat_cond=None,
|
||||
global_embed=None,
|
||||
prepend_cond=None,
|
||||
prepend_cond_mask=None,
|
||||
return_info=False,
|
||||
**kwargs):
|
||||
|
||||
if cross_attn_cond is not None:
|
||||
cross_attn_cond = self.to_cond_embed(cross_attn_cond)
|
||||
|
||||
if global_embed is not None:
|
||||
# Project the global conditioning to the embedding dimension
|
||||
global_embed = self.to_global_embed(global_embed)
|
||||
|
||||
prepend_inputs = None
|
||||
prepend_mask = None
|
||||
prepend_length = 0
|
||||
if prepend_cond is not None:
|
||||
# Project the prepend conditioning to the embedding dimension
|
||||
prepend_cond = self.to_prepend_embed(prepend_cond)
|
||||
|
||||
prepend_inputs = prepend_cond
|
||||
if prepend_cond_mask is not None:
|
||||
prepend_mask = prepend_cond_mask
|
||||
|
||||
if input_concat_cond is not None:
|
||||
|
||||
# Interpolate input_concat_cond to the same length as x
|
||||
if input_concat_cond.shape[2] != x.shape[2]:
|
||||
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
|
||||
|
||||
x = torch.cat([x, input_concat_cond], dim=1)
|
||||
|
||||
# Get the batch of timestep embeddings
|
||||
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype)) # (b, embed_dim)
|
||||
|
||||
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
|
||||
if global_embed is not None:
|
||||
global_embed = global_embed + timestep_embed
|
||||
else:
|
||||
global_embed = timestep_embed
|
||||
|
||||
# Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
|
||||
if self.global_cond_type == "prepend":
|
||||
if prepend_inputs is None:
|
||||
# Prepend inputs are just the global embed, and the mask is all ones
|
||||
prepend_inputs = global_embed.unsqueeze(1)
|
||||
prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
|
||||
else:
|
||||
# Prepend inputs are the prepend conditioning + the global embed
|
||||
prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
|
||||
prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
|
||||
|
||||
prepend_length = prepend_inputs.shape[1]
|
||||
|
||||
x = self.preprocess_conv(x) + x
|
||||
|
||||
x = rearrange(x, "b c t -> b t c")
|
||||
|
||||
extra_args = {}
|
||||
|
||||
if self.global_cond_type == "adaLN":
|
||||
extra_args["global_cond"] = global_embed
|
||||
|
||||
if self.patch_size > 1:
|
||||
x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
|
||||
|
||||
if self.transformer_type == "x-transformers":
|
||||
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
|
||||
elif self.transformer_type == "continuous_transformer":
|
||||
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
|
||||
|
||||
if return_info:
|
||||
output, info = output
|
||||
elif self.transformer_type == "mm_transformer":
|
||||
output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)
|
||||
|
||||
output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
|
||||
|
||||
if self.patch_size > 1:
|
||||
output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
|
||||
|
||||
output = self.postprocess_conv(output) + output
|
||||
|
||||
if return_info:
|
||||
return output, info
|
||||
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
context=None,
|
||||
context_mask=None,
|
||||
input_concat_cond=None,
|
||||
global_embed=None,
|
||||
negative_global_embed=None,
|
||||
prepend_cond=None,
|
||||
prepend_cond_mask=None,
|
||||
mask=None,
|
||||
return_info=False,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
**kwargs):
|
||||
return self._forward(
|
||||
x,
|
||||
timestep,
|
||||
cross_attn_cond=context,
|
||||
cross_attn_cond_mask=context_mask,
|
||||
input_concat_cond=input_concat_cond,
|
||||
global_embed=global_embed,
|
||||
prepend_cond=prepend_cond,
|
||||
prepend_cond_mask=prepend_cond_mask,
|
||||
mask=mask,
|
||||
return_info=return_info,
|
||||
**kwargs
|
||||
)
|
108
comfy/ldm/audio/embedders.py
Normal file
108
comfy/ldm/audio/embedders.py
Normal file
@ -0,0 +1,108 @@
|
||||
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor, einsum
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
||||
from einops import rearrange
|
||||
import math
|
||||
import comfy.ops
|
||||
|
||||
class LearnedPositionalEmbedding(nn.Module):
|
||||
"""Used for continuous time"""
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
assert (dim % 2) == 0
|
||||
half_dim = dim // 2
|
||||
self.weights = nn.Parameter(torch.empty(half_dim))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = rearrange(x, "b -> b 1")
|
||||
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
|
||||
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
|
||||
fouriered = torch.cat((x, fouriered), dim=-1)
|
||||
return fouriered
|
||||
|
||||
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
|
||||
return nn.Sequential(
|
||||
LearnedPositionalEmbedding(dim),
|
||||
comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
|
||||
)
|
||||
|
||||
|
||||
class NumberEmbedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
features: int,
|
||||
dim: int = 256,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
|
||||
|
||||
def forward(self, x: Union[List[float], Tensor]) -> Tensor:
|
||||
if not torch.is_tensor(x):
|
||||
device = next(self.embedding.parameters()).device
|
||||
x = torch.tensor(x, device=device)
|
||||
assert isinstance(x, Tensor)
|
||||
shape = x.shape
|
||||
x = rearrange(x, "... -> (...)")
|
||||
embedding = self.embedding(x)
|
||||
x = embedding.view(*shape, self.features)
|
||||
return x # type: ignore
|
||||
|
||||
|
||||
class Conditioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
output_dim: int,
|
||||
project_out: bool = False
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.output_dim = output_dim
|
||||
self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
raise NotImplementedError()
|
||||
|
||||
class NumberConditioner(Conditioner):
|
||||
'''
|
||||
Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
|
||||
'''
|
||||
def __init__(self,
|
||||
output_dim: int,
|
||||
min_val: float=0,
|
||||
max_val: float=1
|
||||
):
|
||||
super().__init__(output_dim, output_dim)
|
||||
|
||||
self.min_val = min_val
|
||||
self.max_val = max_val
|
||||
|
||||
self.embedder = NumberEmbedder(features=output_dim)
|
||||
|
||||
def forward(self, floats, device=None):
|
||||
# Cast the inputs to floats
|
||||
floats = [float(x) for x in floats]
|
||||
|
||||
if device is None:
|
||||
device = next(self.embedder.parameters()).device
|
||||
|
||||
floats = torch.tensor(floats).to(device)
|
||||
|
||||
floats = floats.clamp(self.min_val, self.max_val)
|
||||
|
||||
normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
|
||||
|
||||
# Cast floats to same type as embedder
|
||||
embedder_dtype = next(self.embedder.parameters()).dtype
|
||||
normalized_floats = normalized_floats.to(embedder_dtype)
|
||||
|
||||
float_embeds = self.embedder(normalized_floats).unsqueeze(1)
|
||||
|
||||
return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
|
161
comfy/ldm/cascade/common.py
Normal file
161
comfy/ldm/cascade/common.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
class Linear(torch.nn.Linear):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class OptimizedAttention(nn.Module):
|
||||
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.heads = nhead
|
||||
|
||||
self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, q, k, v):
|
||||
q = self.to_q(q)
|
||||
k = self.to_k(k)
|
||||
v = self.to_v(v)
|
||||
|
||||
out = optimized_attention(q, k, v, self.heads)
|
||||
|
||||
return self.out_proj(out)
|
||||
|
||||
class Attention2D(nn.Module):
|
||||
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
|
||||
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, kv, self_attn=False):
|
||||
orig_shape = x.shape
|
||||
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
|
||||
if self_attn:
|
||||
kv = torch.cat([x, kv], dim=1)
|
||||
# x = self.attn(x, kv, kv, need_weights=False)[0]
|
||||
x = self.attn(x, kv, kv)
|
||||
x = x.permute(0, 2, 1).view(*orig_shape)
|
||||
return x
|
||||
|
||||
|
||||
def LayerNorm2d_op(operations):
|
||||
class LayerNorm2d(operations.LayerNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
return LayerNorm2d
|
||||
|
||||
class GlobalResponseNorm(nn.Module):
|
||||
"from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
|
||||
def __init__(self, dim, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
|
||||
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x):
|
||||
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
||||
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2):
|
||||
super().__init__()
|
||||
self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
|
||||
# self.depthwise = SAMBlock(c, num_heads, expansion)
|
||||
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.channelwise = nn.Sequential(
|
||||
operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
|
||||
nn.GELU(),
|
||||
GlobalResponseNorm(c * 4, dtype=dtype, device=device),
|
||||
nn.Dropout(dropout),
|
||||
operations.Linear(c * 4, c, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, x_skip=None):
|
||||
x_res = x
|
||||
x = self.norm(self.depthwise(x))
|
||||
if x_skip is not None:
|
||||
x = torch.cat([x, x_skip], dim=1)
|
||||
x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
return x + x_res
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.self_attn = self_attn
|
||||
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
|
||||
self.kv_mapper = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_cond, c, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, kv):
|
||||
kv = self.kv_mapper(kv)
|
||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
|
||||
return x
|
||||
|
||||
|
||||
class FeedForwardBlock(nn.Module):
|
||||
def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.channelwise = nn.Sequential(
|
||||
operations.Linear(c, c * 4, dtype=dtype, device=device),
|
||||
nn.GELU(),
|
||||
GlobalResponseNorm(c * 4, dtype=dtype, device=device),
|
||||
nn.Dropout(dropout),
|
||||
operations.Linear(c * 4, c, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
|
||||
self.conds = conds
|
||||
for cname in conds:
|
||||
setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, t):
|
||||
t = t.chunk(len(self.conds) + 1, dim=1)
|
||||
a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
|
||||
for i, c in enumerate(self.conds):
|
||||
ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
|
||||
a, b = a + ac, b + bc
|
||||
return x * (1 + a) + b
|
93
comfy/ldm/cascade/controlnet.py
Normal file
93
comfy/ldm/cascade/controlnet.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from .common import LayerNorm2d_op
|
||||
|
||||
|
||||
class CNetResBlock(nn.Module):
|
||||
def __init__(self, c, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.blocks = nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
|
||||
nn.GELU(),
|
||||
operations.Conv2d(c, c, kernel_size=3, padding=1),
|
||||
LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
|
||||
nn.GELU(),
|
||||
operations.Conv2d(c, c, kernel_size=3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.blocks(x)
|
||||
|
||||
|
||||
class ControlNet(nn.Module):
|
||||
def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn):
|
||||
super().__init__()
|
||||
if bottleneck_mode is None:
|
||||
bottleneck_mode = 'effnet'
|
||||
self.proj_blocks = proj_blocks
|
||||
if bottleneck_mode == 'effnet':
|
||||
embd_channels = 1280
|
||||
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
|
||||
if c_in != 3:
|
||||
in_weights = self.backbone[0][0].weight.data
|
||||
self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device)
|
||||
if c_in > 3:
|
||||
# nn.init.constant_(self.backbone[0][0].weight, 0)
|
||||
self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
|
||||
else:
|
||||
self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
|
||||
elif bottleneck_mode == 'simple':
|
||||
embd_channels = c_in
|
||||
self.backbone = nn.Sequential(
|
||||
operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
)
|
||||
elif bottleneck_mode == 'large':
|
||||
self.backbone = nn.Sequential(
|
||||
operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device),
|
||||
*[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)],
|
||||
operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device),
|
||||
)
|
||||
embd_channels = 1280
|
||||
else:
|
||||
raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
|
||||
self.projections = nn.ModuleList()
|
||||
for _ in range(len(proj_blocks)):
|
||||
self.projections.append(nn.Sequential(
|
||||
operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device),
|
||||
))
|
||||
# nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
|
||||
self.xl = False
|
||||
self.input_channels = c_in
|
||||
self.unshuffle_amount = 8
|
||||
|
||||
def forward(self, x):
|
||||
x = self.backbone(x)
|
||||
proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
|
||||
for i, idx in enumerate(self.proj_blocks):
|
||||
proj_outputs[idx] = self.projections[i](x)
|
||||
return {"input": proj_outputs[::-1]}
|
255
comfy/ldm/cascade/stage_a.py
Normal file
255
comfy/ldm/cascade/stage_a.py
Normal file
@ -0,0 +1,255 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
|
||||
class vector_quantize(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, codebook):
|
||||
with torch.no_grad():
|
||||
codebook_sqr = torch.sum(codebook ** 2, dim=1)
|
||||
x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)
|
||||
|
||||
dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
|
||||
_, indices = dist.min(dim=1)
|
||||
|
||||
ctx.save_for_backward(indices, codebook)
|
||||
ctx.mark_non_differentiable(indices)
|
||||
|
||||
nn = torch.index_select(codebook, 0, indices)
|
||||
return nn, indices
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output, grad_indices):
|
||||
grad_inputs, grad_codebook = None, None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_inputs = grad_output.clone()
|
||||
if ctx.needs_input_grad[1]:
|
||||
# Gradient wrt. the codebook
|
||||
indices, codebook = ctx.saved_tensors
|
||||
|
||||
grad_codebook = torch.zeros_like(codebook)
|
||||
grad_codebook.index_add_(0, indices, grad_output)
|
||||
|
||||
return (grad_inputs, grad_codebook)
|
||||
|
||||
|
||||
class VectorQuantize(nn.Module):
|
||||
def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
|
||||
"""
|
||||
Takes an input of variable size (as long as the last dimension matches the embedding size).
|
||||
Returns one tensor containing the nearest neigbour embeddings to each of the inputs,
|
||||
with the same size as the input, vq and commitment components for the loss as a touple
|
||||
in the second output and the indices of the quantized vectors in the third:
|
||||
quantized, (vq_loss, commit_loss), indices
|
||||
"""
|
||||
super(VectorQuantize, self).__init__()
|
||||
|
||||
self.codebook = nn.Embedding(k, embedding_size)
|
||||
self.codebook.weight.data.uniform_(-1./k, 1./k)
|
||||
self.vq = vector_quantize.apply
|
||||
|
||||
self.ema_decay = ema_decay
|
||||
self.ema_loss = ema_loss
|
||||
if ema_loss:
|
||||
self.register_buffer('ema_element_count', torch.ones(k))
|
||||
self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))
|
||||
|
||||
def _laplace_smoothing(self, x, epsilon):
|
||||
n = torch.sum(x)
|
||||
return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
|
||||
|
||||
def _updateEMA(self, z_e_x, indices):
|
||||
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
|
||||
elem_count = mask.sum(dim=0)
|
||||
weight_sum = torch.mm(mask.t(), z_e_x)
|
||||
|
||||
self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
|
||||
self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
|
||||
self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
|
||||
|
||||
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
|
||||
|
||||
def idx2vq(self, idx, dim=-1):
|
||||
q_idx = self.codebook(idx)
|
||||
if dim != -1:
|
||||
q_idx = q_idx.movedim(-1, dim)
|
||||
return q_idx
|
||||
|
||||
def forward(self, x, get_losses=True, dim=-1):
|
||||
if dim != -1:
|
||||
x = x.movedim(dim, -1)
|
||||
z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
|
||||
z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
|
||||
vq_loss, commit_loss = None, None
|
||||
if self.ema_loss and self.training:
|
||||
self._updateEMA(z_e_x.detach(), indices.detach())
|
||||
# pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
|
||||
z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
|
||||
if get_losses:
|
||||
vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
|
||||
commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()
|
||||
|
||||
z_q_x = z_q_x.view(x.shape)
|
||||
if dim != -1:
|
||||
z_q_x = z_q_x.movedim(-1, dim)
|
||||
return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, c, c_hidden):
|
||||
super().__init__()
|
||||
# depthwise/attention
|
||||
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.depthwise = nn.Sequential(
|
||||
nn.ReplicationPad2d(1),
|
||||
nn.Conv2d(c, c, kernel_size=3, groups=c)
|
||||
)
|
||||
|
||||
# channelwise
|
||||
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.channelwise = nn.Sequential(
|
||||
nn.Linear(c, c_hidden),
|
||||
nn.GELU(),
|
||||
nn.Linear(c_hidden, c),
|
||||
)
|
||||
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
||||
|
||||
# Init weights
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
|
||||
def _norm(self, x, norm):
|
||||
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
def forward(self, x):
|
||||
mods = self.gammas
|
||||
|
||||
x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
|
||||
try:
|
||||
x = x + self.depthwise(x_temp) * mods[2]
|
||||
except: #operation not implemented for bf16
|
||||
x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
|
||||
x = x + self.depthwise[1](x_temp) * mods[2]
|
||||
|
||||
x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
|
||||
x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class StageA(nn.Module):
|
||||
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
|
||||
super().__init__()
|
||||
self.c_latent = c_latent
|
||||
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
|
||||
|
||||
# Encoder blocks
|
||||
self.in_block = nn.Sequential(
|
||||
nn.PixelUnshuffle(2),
|
||||
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
||||
)
|
||||
down_blocks = []
|
||||
for i in range(levels):
|
||||
if i > 0:
|
||||
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
||||
block = ResBlock(c_levels[i], c_levels[i] * 4)
|
||||
down_blocks.append(block)
|
||||
down_blocks.append(nn.Sequential(
|
||||
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
|
||||
))
|
||||
self.down_blocks = nn.Sequential(*down_blocks)
|
||||
self.down_blocks[0]
|
||||
|
||||
self.codebook_size = codebook_size
|
||||
self.vquantizer = VectorQuantize(c_latent, k=codebook_size)
|
||||
|
||||
# Decoder blocks
|
||||
up_blocks = [nn.Sequential(
|
||||
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
||||
)]
|
||||
for i in range(levels):
|
||||
for j in range(bottleneck_blocks if i == 0 else 1):
|
||||
block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
|
||||
up_blocks.append(block)
|
||||
if i < levels - 1:
|
||||
up_blocks.append(
|
||||
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
||||
padding=1))
|
||||
self.up_blocks = nn.Sequential(*up_blocks)
|
||||
self.out_block = nn.Sequential(
|
||||
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
||||
nn.PixelShuffle(2),
|
||||
)
|
||||
|
||||
def encode(self, x, quantize=False):
|
||||
x = self.in_block(x)
|
||||
x = self.down_blocks(x)
|
||||
if quantize:
|
||||
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
|
||||
return qe, x, indices, vq_loss + commit_loss * 0.25
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
x = self.up_blocks(x)
|
||||
x = self.out_block(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, quantize=False):
|
||||
qe, x, _, vq_loss = self.encode(x, quantize)
|
||||
x = self.decode(qe)
|
||||
return x, vq_loss
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
|
||||
super().__init__()
|
||||
d = max(depth - 3, 3)
|
||||
layers = [
|
||||
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
||||
nn.LeakyReLU(0.2),
|
||||
]
|
||||
for i in range(depth - 1):
|
||||
c_in = c_hidden // (2 ** max((d - i), 0))
|
||||
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
|
||||
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
||||
layers.append(nn.InstanceNorm2d(c_out))
|
||||
layers.append(nn.LeakyReLU(0.2))
|
||||
self.encoder = nn.Sequential(*layers)
|
||||
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
||||
self.logits = nn.Sigmoid()
|
||||
|
||||
def forward(self, x, cond=None):
|
||||
x = self.encoder(x)
|
||||
if cond is not None:
|
||||
cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
|
||||
x = torch.cat([x, cond], dim=1)
|
||||
x = self.shuffle(x)
|
||||
x = self.logits(x)
|
||||
return x
|
256
comfy/ldm/cascade/stage_b.py
Normal file
256
comfy/ldm/cascade/stage_b.py
Normal file
@ -0,0 +1,256 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
|
||||
|
||||
class StageB(nn.Module):
|
||||
def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
|
||||
nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
|
||||
block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
|
||||
c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
|
||||
t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.c_r = c_r
|
||||
self.t_conds = t_conds
|
||||
self.c_clip_seq = c_clip_seq
|
||||
if not isinstance(dropout, list):
|
||||
dropout = [dropout] * len(c_hidden)
|
||||
if not isinstance(self_attn, list):
|
||||
self_attn = [self_attn] * len(c_hidden)
|
||||
|
||||
# CONDITIONING
|
||||
self.effnet_mapper = nn.Sequential(
|
||||
operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
|
||||
nn.GELU(),
|
||||
operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
)
|
||||
self.pixels_mapper = nn.Sequential(
|
||||
operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
|
||||
nn.GELU(),
|
||||
operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
)
|
||||
self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
|
||||
self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.embedding = nn.Sequential(
|
||||
nn.PixelUnshuffle(patch_size),
|
||||
operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
|
||||
if block_type == 'C':
|
||||
return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'A':
|
||||
return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'F':
|
||||
return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'T':
|
||||
return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
raise Exception(f'Block type {block_type} not supported')
|
||||
|
||||
# BLOCKS
|
||||
# -- down blocks
|
||||
self.down_blocks = nn.ModuleList()
|
||||
self.down_downscalers = nn.ModuleList()
|
||||
self.down_repeat_mappers = nn.ModuleList()
|
||||
for i in range(len(c_hidden)):
|
||||
if i > 0:
|
||||
self.down_downscalers.append(nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
|
||||
operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
|
||||
))
|
||||
else:
|
||||
self.down_downscalers.append(nn.Identity())
|
||||
down_block = nn.ModuleList()
|
||||
for _ in range(blocks[0][i]):
|
||||
for block_type in level_config[i]:
|
||||
block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
|
||||
down_block.append(block)
|
||||
self.down_blocks.append(down_block)
|
||||
if block_repeat is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(block_repeat[0][i] - 1):
|
||||
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
|
||||
self.down_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# -- up blocks
|
||||
self.up_blocks = nn.ModuleList()
|
||||
self.up_upscalers = nn.ModuleList()
|
||||
self.up_repeat_mappers = nn.ModuleList()
|
||||
for i in reversed(range(len(c_hidden))):
|
||||
if i > 0:
|
||||
self.up_upscalers.append(nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
|
||||
operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
|
||||
))
|
||||
else:
|
||||
self.up_upscalers.append(nn.Identity())
|
||||
up_block = nn.ModuleList()
|
||||
for j in range(blocks[1][::-1][i]):
|
||||
for k, block_type in enumerate(level_config[i]):
|
||||
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
|
||||
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
|
||||
self_attn=self_attn[i])
|
||||
up_block.append(block)
|
||||
self.up_blocks.append(up_block)
|
||||
if block_repeat is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(block_repeat[1][::-1][i] - 1):
|
||||
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
|
||||
self.up_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# OUTPUT
|
||||
self.clf = nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
|
||||
operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
|
||||
nn.PixelShuffle(patch_size),
|
||||
)
|
||||
|
||||
# --- WEIGHT INIT ---
|
||||
# self.apply(self._init_weights) # General init
|
||||
# nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
|
||||
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||
# nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||
#
|
||||
# # blocks
|
||||
# for level_block in self.down_blocks + self.up_blocks:
|
||||
# for block in level_block:
|
||||
# if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
|
||||
# block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
|
||||
# elif isinstance(block, TimestepBlock):
|
||||
# for layer in block.modules():
|
||||
# if isinstance(layer, nn.Linear):
|
||||
# nn.init.constant_(layer.weight, 0)
|
||||
#
|
||||
# def _init_weights(self, m):
|
||||
# if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
# torch.nn.init.xavier_uniform_(m.weight)
|
||||
# if m.bias is not None:
|
||||
# nn.init.constant_(m.bias, 0)
|
||||
|
||||
def gen_r_embedding(self, r, max_positions=10000):
|
||||
r = r * max_positions
|
||||
half_dim = self.c_r // 2
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
|
||||
emb = r[:, None] * emb[None, :]
|
||||
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
|
||||
if self.c_r % 2 == 1: # zero pad
|
||||
emb = nn.functional.pad(emb, (0, 1), mode='constant')
|
||||
return emb
|
||||
|
||||
def gen_c_embeddings(self, clip):
|
||||
if len(clip.shape) == 2:
|
||||
clip = clip.unsqueeze(1)
|
||||
clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
x = downscaler(x)
|
||||
for i in range(len(repmap) + 1):
|
||||
for block in down_block:
|
||||
if isinstance(block, ResBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
ResBlock)):
|
||||
x = block(x)
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if i < len(repmap):
|
||||
x = repmap[i](x)
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
for j in range(len(repmap) + 1):
|
||||
for k, block in enumerate(up_block):
|
||||
if isinstance(block, ResBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
ResBlock)):
|
||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||
x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
x = block(x, skip)
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if j < len(repmap):
|
||||
x = repmap[j](x)
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
|
||||
if pixels is None:
|
||||
pixels = x.new_zeros(x.size(0), 3, 8, 8)
|
||||
|
||||
# Process the conditioning embeddings
|
||||
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
|
||||
for c in self.t_conds:
|
||||
t_cond = kwargs.get(c, torch.zeros_like(r))
|
||||
r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
|
||||
clip = self.gen_c_embeddings(clip)
|
||||
|
||||
# Model Blocks
|
||||
x = self.embedding(x)
|
||||
x = x + self.effnet_mapper(
|
||||
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
|
||||
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
level_outputs = self._down_encode(x, r_embed, clip)
|
||||
x = self._up_decode(level_outputs, r_embed, clip)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
for self_params, src_params in zip(self.parameters(), src_model.parameters()):
|
||||
self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
|
||||
for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
|
||||
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
|
273
comfy/ldm/cascade/stage_c.py
Normal file
273
comfy/ldm/cascade/stage_c.py
Normal file
@ -0,0 +1,273 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import math
|
||||
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
|
||||
# from .controlnet import ControlNetDeliverer
|
||||
|
||||
class UpDownBlock2d(nn.Module):
|
||||
def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert mode in ['up', 'down']
|
||||
interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
|
||||
align_corners=True) if enabled else nn.Identity()
|
||||
mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
|
||||
self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
class StageC(nn.Module):
|
||||
def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
|
||||
blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
|
||||
c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
|
||||
dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.c_r = c_r
|
||||
self.t_conds = t_conds
|
||||
self.c_clip_seq = c_clip_seq
|
||||
if not isinstance(dropout, list):
|
||||
dropout = [dropout] * len(c_hidden)
|
||||
if not isinstance(self_attn, list):
|
||||
self_attn = [self_attn] * len(c_hidden)
|
||||
|
||||
# CONDITIONING
|
||||
self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
|
||||
self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
|
||||
self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
|
||||
self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.embedding = nn.Sequential(
|
||||
nn.PixelUnshuffle(patch_size),
|
||||
operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
|
||||
)
|
||||
|
||||
def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
|
||||
if block_type == 'C':
|
||||
return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'A':
|
||||
return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'F':
|
||||
return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'T':
|
||||
return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
raise Exception(f'Block type {block_type} not supported')
|
||||
|
||||
# BLOCKS
|
||||
# -- down blocks
|
||||
self.down_blocks = nn.ModuleList()
|
||||
self.down_downscalers = nn.ModuleList()
|
||||
self.down_repeat_mappers = nn.ModuleList()
|
||||
for i in range(len(c_hidden)):
|
||||
if i > 0:
|
||||
self.down_downscalers.append(nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
|
||||
UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
|
||||
))
|
||||
else:
|
||||
self.down_downscalers.append(nn.Identity())
|
||||
down_block = nn.ModuleList()
|
||||
for _ in range(blocks[0][i]):
|
||||
for block_type in level_config[i]:
|
||||
block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
|
||||
down_block.append(block)
|
||||
self.down_blocks.append(down_block)
|
||||
if block_repeat is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(block_repeat[0][i] - 1):
|
||||
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
|
||||
self.down_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# -- up blocks
|
||||
self.up_blocks = nn.ModuleList()
|
||||
self.up_upscalers = nn.ModuleList()
|
||||
self.up_repeat_mappers = nn.ModuleList()
|
||||
for i in reversed(range(len(c_hidden))):
|
||||
if i > 0:
|
||||
self.up_upscalers.append(nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
|
||||
UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
|
||||
))
|
||||
else:
|
||||
self.up_upscalers.append(nn.Identity())
|
||||
up_block = nn.ModuleList()
|
||||
for j in range(blocks[1][::-1][i]):
|
||||
for k, block_type in enumerate(level_config[i]):
|
||||
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
|
||||
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
|
||||
self_attn=self_attn[i])
|
||||
up_block.append(block)
|
||||
self.up_blocks.append(up_block)
|
||||
if block_repeat is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(block_repeat[1][::-1][i] - 1):
|
||||
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
|
||||
self.up_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# OUTPUT
|
||||
self.clf = nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
|
||||
operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
|
||||
nn.PixelShuffle(patch_size),
|
||||
)
|
||||
|
||||
# --- WEIGHT INIT ---
|
||||
# self.apply(self._init_weights) # General init
|
||||
# nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
|
||||
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||
# nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||
#
|
||||
# # blocks
|
||||
# for level_block in self.down_blocks + self.up_blocks:
|
||||
# for block in level_block:
|
||||
# if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
|
||||
# block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
|
||||
# elif isinstance(block, TimestepBlock):
|
||||
# for layer in block.modules():
|
||||
# if isinstance(layer, nn.Linear):
|
||||
# nn.init.constant_(layer.weight, 0)
|
||||
#
|
||||
# def _init_weights(self, m):
|
||||
# if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
# torch.nn.init.xavier_uniform_(m.weight)
|
||||
# if m.bias is not None:
|
||||
# nn.init.constant_(m.bias, 0)
|
||||
|
||||
def gen_r_embedding(self, r, max_positions=10000):
|
||||
r = r * max_positions
|
||||
half_dim = self.c_r // 2
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
|
||||
emb = r[:, None] * emb[None, :]
|
||||
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
|
||||
if self.c_r % 2 == 1: # zero pad
|
||||
emb = nn.functional.pad(emb, (0, 1), mode='constant')
|
||||
return emb
|
||||
|
||||
def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
|
||||
clip_txt = self.clip_txt_mapper(clip_txt)
|
||||
if len(clip_txt_pooled.shape) == 2:
|
||||
clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
|
||||
if len(clip_img.shape) == 2:
|
||||
clip_img = clip_img.unsqueeze(1)
|
||||
clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
|
||||
clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
|
||||
clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip, cnet=None):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
x = downscaler(x)
|
||||
for i in range(len(repmap) + 1):
|
||||
for block in down_block:
|
||||
if isinstance(block, ResBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
ResBlock)):
|
||||
if cnet is not None:
|
||||
next_cnet = cnet.pop()
|
||||
if next_cnet is not None:
|
||||
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
|
||||
align_corners=True).to(x.dtype)
|
||||
x = block(x)
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if i < len(repmap):
|
||||
x = repmap[i](x)
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
for j in range(len(repmap) + 1):
|
||||
for k, block in enumerate(up_block):
|
||||
if isinstance(block, ResBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
ResBlock)):
|
||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||
x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
if cnet is not None:
|
||||
next_cnet = cnet.pop()
|
||||
if next_cnet is not None:
|
||||
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
|
||||
align_corners=True).to(x.dtype)
|
||||
x = block(x, skip)
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if j < len(repmap):
|
||||
x = repmap[j](x)
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
|
||||
# Process the conditioning embeddings
|
||||
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
|
||||
for c in self.t_conds:
|
||||
t_cond = kwargs.get(c, torch.zeros_like(r))
|
||||
r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
|
||||
clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
|
||||
|
||||
if control is not None:
|
||||
cnet = control.get("input")
|
||||
else:
|
||||
cnet = None
|
||||
|
||||
# Model Blocks
|
||||
x = self.embedding(x)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, cnet)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, cnet)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
for self_params, src_params in zip(self.parameters(), src_model.parameters()):
|
||||
self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
|
||||
for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
|
||||
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
|
95
comfy/ldm/cascade/stage_c_coder.py
Normal file
95
comfy/ldm/cascade/stage_c_coder.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
|
||||
|
||||
# EfficientNet
|
||||
class EfficientNetEncoder(nn.Module):
|
||||
def __init__(self, c_latent=16):
|
||||
super().__init__()
|
||||
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
|
||||
self.mapper = nn.Sequential(
|
||||
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
|
||||
)
|
||||
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
|
||||
self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 0.5 + 0.5
|
||||
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
|
||||
o = self.mapper(self.backbone(x))
|
||||
return o
|
||||
|
||||
|
||||
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
|
||||
class Previewer(nn.Module):
|
||||
def __init__(self, c_in=16, c_hidden=512, c_out=3):
|
||||
super().__init__()
|
||||
self.blocks = nn.Sequential(
|
||||
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden),
|
||||
|
||||
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 2),
|
||||
|
||||
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 2),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return (self.blocks(x) - 0.5) * 2.0
|
||||
|
||||
class StageC_coder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.previewer = Previewer()
|
||||
self.encoder = EfficientNetEncoder()
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
def decode(self, x):
|
||||
return self.previewer(x)
|
@ -1,6 +1,4 @@
|
||||
import torch
|
||||
# import pytorch_lightning as pl
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
@ -3,9 +3,10 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from typing import Optional, Any
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
|
||||
from .diffusionmodules.util import AlphaBlender, timestep_embedding
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
from comfy import model_management
|
||||
@ -18,13 +19,14 @@ from comfy.cli_args import args
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
# CrossAttn precision handling
|
||||
if args.dont_upcast_attention:
|
||||
print("disabling upcasting of attention")
|
||||
_ATTN_PRECISION = "fp16"
|
||||
else:
|
||||
_ATTN_PRECISION = "fp32"
|
||||
FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
|
||||
|
||||
def get_attn_precision(attn_precision):
|
||||
if args.dont_upcast_attention:
|
||||
return None
|
||||
if FORCE_UPCAST_ATTENTION_DTYPE is not None:
|
||||
return FORCE_UPCAST_ATTENTION_DTYPE
|
||||
return attn_precision
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
@ -84,23 +86,35 @@ class FeedForward(nn.Module):
|
||||
def Normalize(in_channels, dtype=None, device=None):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
|
||||
def attention_basic(q, k, v, heads, mask=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
attn_precision = get_attn_precision(attn_precision)
|
||||
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
|
||||
h = heads
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, -1, heads, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * heads, -1, dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
else:
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, -1, heads, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * heads, -1, dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
if attn_precision == torch.float32:
|
||||
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
|
||||
else:
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||
@ -114,7 +128,12 @@ def attention_basic(q, k, v, heads, mask=None):
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
else:
|
||||
sim += mask
|
||||
if len(mask.shape) == 2:
|
||||
bs = 1
|
||||
else:
|
||||
bs = mask.shape[0]
|
||||
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
|
||||
sim.add_(mask)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
@ -129,18 +148,29 @@ def attention_basic(q, k, v, heads, mask=None):
|
||||
return out
|
||||
|
||||
|
||||
def attention_sub_quad(query, key, value, heads, mask=None):
|
||||
b, _, dim_head = query.shape
|
||||
dim_head //= heads
|
||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
attn_precision = get_attn_precision(attn_precision)
|
||||
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = query.shape
|
||||
else:
|
||||
b, _, dim_head = query.shape
|
||||
dim_head //= heads
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||
|
||||
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
|
||||
if skip_reshape:
|
||||
query = query.reshape(b * heads, -1, dim_head)
|
||||
value = value.reshape(b * heads, -1, dim_head)
|
||||
key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
|
||||
else:
|
||||
query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
|
||||
|
||||
|
||||
dtype = query.dtype
|
||||
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
|
||||
upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
|
||||
if upcast_attention:
|
||||
bytes_per_token = torch.finfo(torch.float32).bits//8
|
||||
else:
|
||||
@ -165,6 +195,13 @@ def attention_sub_quad(query, key, value, heads, mask=None):
|
||||
if query_chunk_size is None:
|
||||
query_chunk_size = 512
|
||||
|
||||
if mask is not None:
|
||||
if len(mask.shape) == 2:
|
||||
bs = 1
|
||||
else:
|
||||
bs = mask.shape[0]
|
||||
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
|
||||
|
||||
hidden_states = efficient_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
@ -182,29 +219,43 @@ def attention_sub_quad(query, key, value, heads, mask=None):
|
||||
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||
return hidden_states
|
||||
|
||||
def attention_split(q, k, v, heads, mask=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
attn_precision = get_attn_precision(attn_precision)
|
||||
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
|
||||
h = heads
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, -1, heads, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * heads, -1, dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
else:
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, -1, heads, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * heads, -1, dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
mem_free_total = model_management.get_free_memory(q.device)
|
||||
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
if attn_precision == torch.float32:
|
||||
element_size = 4
|
||||
upcast = True
|
||||
else:
|
||||
element_size = q.element_size()
|
||||
upcast = False
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
|
||||
@ -223,6 +274,13 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
||||
|
||||
if mask is not None:
|
||||
if len(mask.shape) == 2:
|
||||
bs = 1
|
||||
else:
|
||||
bs = mask.shape[0]
|
||||
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
|
||||
|
||||
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
|
||||
first_op_done = False
|
||||
cleared_cache = False
|
||||
@ -231,7 +289,7 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
if upcast:
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
|
||||
else:
|
||||
@ -255,12 +313,12 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
model_management.soft_empty_cache(True)
|
||||
if cleared_cache == False:
|
||||
cleared_cache = True
|
||||
print("out of memory error, emptying cache and trying again")
|
||||
logging.warning("out of memory error, emptying cache and trying again")
|
||||
continue
|
||||
steps *= 2
|
||||
if steps > 64:
|
||||
raise e
|
||||
print("out of memory error, increasing steps and trying again", steps)
|
||||
logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
|
||||
else:
|
||||
raise e
|
||||
|
||||
@ -277,26 +335,41 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
BROKEN_XFORMERS = False
|
||||
try:
|
||||
x_vers = xformers.__version__
|
||||
#I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
|
||||
BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
|
||||
# XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
|
||||
BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
|
||||
except:
|
||||
pass
|
||||
|
||||
def attention_xformers(q, k, v, heads, mask=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
disabled_xformers = False
|
||||
|
||||
if BROKEN_XFORMERS:
|
||||
if b * heads > 65535:
|
||||
return attention_pytorch(q, k, v, heads, mask)
|
||||
disabled_xformers = True
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, -1, heads, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * heads, -1, dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
if not disabled_xformers:
|
||||
if torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||
disabled_xformers = True
|
||||
|
||||
if disabled_xformers:
|
||||
return attention_pytorch(q, k, v, heads, mask)
|
||||
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
else:
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
pad = 8 - q.shape[1] % 8
|
||||
@ -306,21 +379,30 @@ def attention_xformers(q, k, v, heads, mask=None):
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
if skip_reshape:
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
out = (
|
||||
out.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = (
|
||||
@ -332,17 +414,17 @@ def attention_pytorch(q, k, v, heads, mask=None):
|
||||
optimized_attention = attention_basic
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
print("Using xformers cross attention")
|
||||
logging.info("Using xformers cross attention")
|
||||
optimized_attention = attention_xformers
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
print("Using pytorch cross attention")
|
||||
logging.info("Using pytorch cross attention")
|
||||
optimized_attention = attention_pytorch
|
||||
else:
|
||||
if args.use_split_cross_attention:
|
||||
print("Using split optimization for cross attention")
|
||||
logging.info("Using split optimization for cross attention")
|
||||
optimized_attention = attention_split
|
||||
else:
|
||||
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
optimized_attention = attention_sub_quad
|
||||
|
||||
optimized_attention_masked = optimized_attention
|
||||
@ -364,10 +446,11 @@ def optimized_attention_for_device(device, mask=False, small_input=False):
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
@ -389,15 +472,15 @@ class CrossAttention(nn.Module):
|
||||
v = self.to_v(context)
|
||||
|
||||
if mask is None:
|
||||
out = optimized_attention(q, k, v, self.heads)
|
||||
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
else:
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask)
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
|
||||
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
|
||||
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
|
||||
super().__init__()
|
||||
|
||||
self.ff_in = ff_in or inner_dim is not None
|
||||
@ -405,6 +488,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
inner_dim = dim
|
||||
|
||||
self.is_res = inner_dim == dim
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
if self.ff_in:
|
||||
self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
@ -412,7 +496,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
|
||||
context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if disable_temporal_crossattention:
|
||||
@ -426,20 +510,16 @@ class BasicTransformerBlock(nn.Module):
|
||||
context_dim_attn2 = context_dim
|
||||
|
||||
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.checkpoint = checkpoint
|
||||
self.n_heads = n_heads
|
||||
self.d_head = d_head
|
||||
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
|
||||
|
||||
def forward(self, x, context=None, transformer_options={}):
|
||||
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None, transformer_options={}):
|
||||
extra_options = {}
|
||||
block = transformer_options.get("block", None)
|
||||
block_index = transformer_options.get("block_index", 0)
|
||||
@ -456,6 +536,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
extra_options["n_heads"] = self.n_heads
|
||||
extra_options["dim_head"] = self.d_head
|
||||
extra_options["attn_precision"] = self.attn_precision
|
||||
|
||||
if self.ff_in:
|
||||
x_skip = x
|
||||
@ -566,7 +647,7 @@ class SpatialTransformer(nn.Module):
|
||||
def __init__(self, in_channels, n_heads, d_head,
|
||||
depth=1, dropout=0., context_dim=None,
|
||||
disable_self_attn=False, use_linear=False,
|
||||
use_checkpoint=True, dtype=None, device=None, operations=ops):
|
||||
use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
|
||||
super().__init__()
|
||||
if exists(context_dim) and not isinstance(context_dim, list):
|
||||
context_dim = [context_dim] * depth
|
||||
@ -584,7 +665,7 @@ class SpatialTransformer(nn.Module):
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
||||
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
|
||||
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
for d in range(depth)]
|
||||
)
|
||||
if not use_linear:
|
||||
@ -605,7 +686,7 @@ class SpatialTransformer(nn.Module):
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
||||
x = x.movedim(1, 3).flatten(1, 2).contiguous()
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
@ -613,7 +694,7 @@ class SpatialTransformer(nn.Module):
|
||||
x = block(x, context=context[i], transformer_options=transformer_options)
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
||||
x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous()
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
@ -640,6 +721,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
disable_self_attn=False,
|
||||
disable_temporal_crossattention=False,
|
||||
max_time_embed_period: int = 10000,
|
||||
attn_precision=None,
|
||||
dtype=None, device=None, operations=ops
|
||||
):
|
||||
super().__init__(
|
||||
@ -652,6 +734,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
context_dim=context_dim,
|
||||
use_linear=use_linear,
|
||||
disable_self_attn=disable_self_attn,
|
||||
attn_precision=attn_precision,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.time_depth = time_depth
|
||||
@ -681,6 +764,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
inner_dim=time_mix_inner_dim,
|
||||
disable_self_attn=disable_self_attn,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
attn_precision=attn_precision,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
|
978
comfy/ldm/modules/diffusionmodules/mmdit.py
Normal file
978
comfy/ldm/modules/diffusionmodules/mmdit.py
Normal file
@ -0,0 +1,978 @@
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .. import attention
|
||||
from einops import rearrange, repeat
|
||||
|
||||
def default(x, y):
|
||||
if x is not None:
|
||||
return x
|
||||
return y
|
||||
|
||||
class Mlp(nn.Module):
|
||||
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=None,
|
||||
bias=True,
|
||||
drop=0.,
|
||||
use_conv=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
drop_probs = drop
|
||||
linear_layer = partial(operations.Conv2d, kernel_size=1) if use_conv else operations.Linear
|
||||
|
||||
self.fc1 = linear_layer(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
|
||||
self.act = act_layer()
|
||||
self.drop1 = nn.Dropout(drop_probs)
|
||||
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
||||
self.fc2 = linear_layer(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
|
||||
self.drop2 = nn.Dropout(drop_probs)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop1(x)
|
||||
x = self.norm(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" 2D Image to Patch Embedding
|
||||
"""
|
||||
dynamic_img_pad: torch.jit.Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: Optional[int] = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
flatten: bool = True,
|
||||
bias: bool = True,
|
||||
strict_img_size: bool = True,
|
||||
dynamic_img_pad: bool = True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = (patch_size, patch_size)
|
||||
if img_size is not None:
|
||||
self.img_size = (img_size, img_size)
|
||||
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
else:
|
||||
self.img_size = None
|
||||
self.grid_size = None
|
||||
self.num_patches = None
|
||||
|
||||
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
||||
self.flatten = flatten
|
||||
self.strict_img_size = strict_img_size
|
||||
self.dynamic_img_pad = dynamic_img_pad
|
||||
|
||||
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
# if self.img_size is not None:
|
||||
# if self.strict_img_size:
|
||||
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
|
||||
# _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
|
||||
# elif not self.dynamic_img_pad:
|
||||
# _assert(
|
||||
# H % self.patch_size[0] == 0,
|
||||
# f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
||||
# )
|
||||
# _assert(
|
||||
# W % self.patch_size[1] == 0,
|
||||
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
||||
# )
|
||||
if self.dynamic_img_pad:
|
||||
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
|
||||
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
|
||||
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
if shift is None:
|
||||
shift = torch.zeros_like(scale)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Sine/Cosine Positional Embedding Functions #
|
||||
#################################################################################
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(
|
||||
embed_dim,
|
||||
grid_size,
|
||||
cls_token=False,
|
||||
extra_tokens=0,
|
||||
scaling_factor=None,
|
||||
offset=None,
|
||||
):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
if scaling_factor is not None:
|
||||
grid = grid / scaling_factor
|
||||
if offset is not None:
|
||||
grid = grid - offset
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate(
|
||||
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
||||
)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=torch.float32):
|
||||
omega = torch.arange(embed_dim // 2, device=device, dtype=dtype)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
emb_sin = torch.sin(out) # (M, D/2)
|
||||
emb_cos = torch.cos(out) # (M, D/2)
|
||||
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
||||
return emb
|
||||
|
||||
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude=7.5, device=None, dtype=torch.float32):
|
||||
small = min(h, w)
|
||||
val_h = (h / small) * val_magnitude
|
||||
val_w = (w / small) * val_magnitude
|
||||
grid_h, grid_w = torch.meshgrid(torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), indexing='ij')
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
||||
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Embedding Layers for Timesteps and Class Labels #
|
||||
#################################################################################
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
|
||||
/ half
|
||||
)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(dtype=t.dtype)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class VectorEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds a flat vector of dimension input_dim
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.mlp(x)
|
||||
return emb
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Core DiT Model #
|
||||
#################################################################################
|
||||
|
||||
|
||||
def split_qkv(qkv, head_dim):
|
||||
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||
return qkv[0], qkv[1], qkv[2]
|
||||
|
||||
def optimized_attention(qkv, num_heads):
|
||||
return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: Optional[float] = None,
|
||||
proj_drop: float = 0.0,
|
||||
attn_mode: str = "xformers",
|
||||
pre_only: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
rmsnorm: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if not pre_only:
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
self.attn_mode = attn_mode
|
||||
self.pre_only = pre_only
|
||||
|
||||
if qk_norm == "rms":
|
||||
self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
|
||||
self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
|
||||
elif qk_norm == "ln":
|
||||
self.ln_q = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
|
||||
self.ln_k = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
|
||||
elif qk_norm is None:
|
||||
self.ln_q = nn.Identity()
|
||||
self.ln_k = nn.Identity()
|
||||
else:
|
||||
raise ValueError(qk_norm)
|
||||
|
||||
def pre_attention(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, L, C = x.shape
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = split_qkv(qkv, self.head_dim)
|
||||
q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
|
||||
k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
|
||||
return (q, k, v)
|
||||
|
||||
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
qkv = self.pre_attention(x)
|
||||
x = optimized_attention(
|
||||
qkv, num_heads=self.num_heads
|
||||
)
|
||||
x = self.post_attention(x)
|
||||
return x
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(
|
||||
self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (nn.Parameter): Learnable scaling parameter.
|
||||
"""
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.learnable_scale = elementwise_affine
|
||||
if self.learnable_scale:
|
||||
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
"""
|
||||
x = self._norm(x)
|
||||
if self.learnable_scale:
|
||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class SwiGLUFeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the FeedForward module.
|
||||
|
||||
Args:
|
||||
dim (int): Input dimension.
|
||||
hidden_dim (int): Hidden dimension of the feedforward layer.
|
||||
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
||||
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
|
||||
|
||||
Attributes:
|
||||
w1 (ColumnParallelLinear): Linear transformation for the first layer.
|
||||
w2 (RowParallelLinear): Linear transformation for the second layer.
|
||||
w3 (ColumnParallelLinear): Linear transformation for the third layer.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class DismantledBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with gated adaptive layer norm (adaLN) conditioning.
|
||||
"""
|
||||
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
attn_mode: str = "xformers",
|
||||
qkv_bias: bool = False,
|
||||
pre_only: bool = False,
|
||||
rmsnorm: bool = False,
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
if not rmsnorm:
|
||||
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
else:
|
||||
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=pre_only,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
if not pre_only:
|
||||
if not rmsnorm:
|
||||
self.norm2 = operations.LayerNorm(
|
||||
hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
else:
|
||||
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
if not pre_only:
|
||||
if not swiglu:
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=lambda: nn.GELU(approximate="tanh"),
|
||||
drop=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
else:
|
||||
self.mlp = SwiGLUFeedForward(
|
||||
dim=hidden_size,
|
||||
hidden_dim=mlp_hidden_dim,
|
||||
multiple_of=256,
|
||||
)
|
||||
self.scale_mod_only = scale_mod_only
|
||||
if not scale_mod_only:
|
||||
n_mods = 6 if not pre_only else 2
|
||||
else:
|
||||
n_mods = 4 if not pre_only else 1
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
self.pre_only = pre_only
|
||||
|
||||
def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
if not self.pre_only:
|
||||
if not self.scale_mod_only:
|
||||
(
|
||||
shift_msa,
|
||||
scale_msa,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
) = self.adaLN_modulation(c).chunk(6, dim=1)
|
||||
else:
|
||||
shift_msa = None
|
||||
shift_mlp = None
|
||||
(
|
||||
scale_msa,
|
||||
gate_msa,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
) = self.adaLN_modulation(
|
||||
c
|
||||
).chunk(4, dim=1)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, (
|
||||
x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
)
|
||||
else:
|
||||
if not self.scale_mod_only:
|
||||
(
|
||||
shift_msa,
|
||||
scale_msa,
|
||||
) = self.adaLN_modulation(
|
||||
c
|
||||
).chunk(2, dim=1)
|
||||
else:
|
||||
shift_msa = None
|
||||
scale_msa = self.adaLN_modulation(c)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, None
|
||||
|
||||
def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
|
||||
assert not self.pre_only
|
||||
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
qkv, intermediates = self.pre_attention(x, c)
|
||||
attn = optimized_attention(
|
||||
qkv,
|
||||
num_heads=self.attn.num_heads,
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
|
||||
def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
if use_checkpoint:
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
_block_mixing, *args, use_reentrant=False, **kwargs
|
||||
)
|
||||
else:
|
||||
return _block_mixing(*args, **kwargs)
|
||||
|
||||
|
||||
def _block_mixing(context, x, context_block, x_block, c):
|
||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||
|
||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||
|
||||
o = []
|
||||
for t in range(3):
|
||||
o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
|
||||
qkv = tuple(o)
|
||||
|
||||
attn = optimized_attention(
|
||||
qkv,
|
||||
num_heads=x_block.attn.num_heads,
|
||||
)
|
||||
context_attn, x_attn = (
|
||||
attn[:, : context_qkv[0].shape[1]],
|
||||
attn[:, context_qkv[0].shape[1] :],
|
||||
)
|
||||
|
||||
if not context_block.pre_only:
|
||||
context = context_block.post_attention(context_attn, *context_intermediates)
|
||||
|
||||
else:
|
||||
context = None
|
||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||
return context, x
|
||||
|
||||
|
||||
class JointBlock(nn.Module):
|
||||
"""just a small wrapper to serve as a fsdp unit"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
pre_only = kwargs.pop("pre_only")
|
||||
qk_norm = kwargs.pop("qk_norm", None)
|
||||
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
|
||||
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return block_mixing(
|
||||
*args, context_block=self.context_block, x_block=self.x_block, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
patch_size: int,
|
||||
out_channels: int,
|
||||
total_out_channels: Optional[int] = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = (
|
||||
operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
if (total_out_channels is None)
|
||||
else operations.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class SelfAttentionContext(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
dim_head = dim // heads
|
||||
inner_dim = dim
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.proj = operations.Linear(inner_dim, dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = split_qkv(qkv, self.dim_head)
|
||||
x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
|
||||
return self.proj(x)
|
||||
|
||||
class ContextProcessorBlock(nn.Module):
|
||||
def __init__(self, context_size, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = SelfAttentionContext(context_size, dtype=dtype, device=device, operations=operations)
|
||||
self.norm2 = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.mlp = Mlp(in_features=context_size, hidden_features=(context_size * 4), act_layer=lambda: nn.GELU(approximate="tanh"), drop=0, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x):
|
||||
x += self.attn(self.norm1(x))
|
||||
x += self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
class ContextProcessor(nn.Module):
|
||||
def __init__(self, context_size, num_layers, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList([ContextProcessorBlock(context_size, dtype=dtype, device=device, operations=operations) for i in range(num_layers)])
|
||||
self.norm = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
for i, l in enumerate(self.layers):
|
||||
x = l(x)
|
||||
return self.norm(x)
|
||||
|
||||
class MMDiT(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int = 32,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 4,
|
||||
depth: int = 28,
|
||||
# hidden_size: Optional[int] = None,
|
||||
# num_heads: Optional[int] = None,
|
||||
mlp_ratio: float = 4.0,
|
||||
learn_sigma: bool = False,
|
||||
adm_in_channels: Optional[int] = None,
|
||||
context_embedder_config: Optional[Dict] = None,
|
||||
compile_core: bool = False,
|
||||
use_checkpoint: bool = False,
|
||||
register_length: int = 0,
|
||||
attn_mode: str = "torch",
|
||||
rmsnorm: bool = False,
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
pos_embed_scaling_factor: Optional[float] = None,
|
||||
pos_embed_offset: Optional[float] = None,
|
||||
pos_embed_max_size: Optional[int] = None,
|
||||
num_patches = None,
|
||||
qk_norm: Optional[str] = None,
|
||||
qkv_bias: bool = True,
|
||||
context_processor_layers = None,
|
||||
context_size = 4096,
|
||||
num_blocks = None,
|
||||
final_layer = True,
|
||||
dtype = None, #TODO
|
||||
device = None,
|
||||
operations = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
default_out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.out_channels = default(out_channels, default_out_channels)
|
||||
self.patch_size = patch_size
|
||||
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||
self.pos_embed_offset = pos_embed_offset
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
|
||||
# hidden_size = default(hidden_size, 64 * depth)
|
||||
# num_heads = default(num_heads, hidden_size // 64)
|
||||
|
||||
# apply magic --> this defines a head_size of 64
|
||||
self.hidden_size = 64 * depth
|
||||
num_heads = depth
|
||||
if num_blocks is None:
|
||||
num_blocks = depth
|
||||
|
||||
self.depth = depth
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.x_embedder = PatchEmbed(
|
||||
input_size,
|
||||
patch_size,
|
||||
in_channels,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
strict_img_size=self.pos_embed_max_size is None,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.y_embedder = None
|
||||
if adm_in_channels is not None:
|
||||
assert isinstance(adm_in_channels, int)
|
||||
self.y_embedder = VectorEmbedder(adm_in_channels, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if context_processor_layers is not None:
|
||||
self.context_processor = ContextProcessor(context_size, context_processor_layers, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.context_processor = None
|
||||
|
||||
self.context_embedder = nn.Identity()
|
||||
if context_embedder_config is not None:
|
||||
if context_embedder_config["target"] == "torch.nn.Linear":
|
||||
self.context_embedder = operations.Linear(**context_embedder_config["params"], dtype=dtype, device=device)
|
||||
|
||||
self.register_length = register_length
|
||||
if self.register_length > 0:
|
||||
self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size, dtype=dtype, device=device))
|
||||
|
||||
# num_patches = self.x_embedder.num_patches
|
||||
# Will use fixed sin-cos embedding:
|
||||
# just use a buffer already
|
||||
if num_patches is not None:
|
||||
self.register_buffer(
|
||||
"pos_embed",
|
||||
torch.empty(1, num_patches, self.hidden_size, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.joint_blocks = nn.ModuleList(
|
||||
[
|
||||
JointBlock(
|
||||
self.hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=(i == num_blocks - 1) and final_layer,
|
||||
rmsnorm=rmsnorm,
|
||||
scale_mod_only=scale_mod_only,
|
||||
swiglu=swiglu,
|
||||
qk_norm=qk_norm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
for i in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if compile_core:
|
||||
assert False
|
||||
self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
|
||||
|
||||
def cropped_pos_embed(self, hw, device=None):
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h, w = hw
|
||||
# patched size
|
||||
h = (h + 1) // p
|
||||
w = (w + 1) // p
|
||||
if self.pos_embed is None:
|
||||
return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
|
||||
assert self.pos_embed_max_size is not None
|
||||
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
|
||||
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
|
||||
top = (self.pos_embed_max_size - h) // 2
|
||||
left = (self.pos_embed_max_size - w) // 2
|
||||
spatial_pos_embed = rearrange(
|
||||
self.pos_embed,
|
||||
"1 (h w) c -> 1 h w c",
|
||||
h=self.pos_embed_max_size,
|
||||
w=self.pos_embed_max_size,
|
||||
)
|
||||
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
||||
spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
|
||||
# print(spatial_pos_embed, top, left, h, w)
|
||||
# # t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.875, 7.875, device=device) #matches exactly for 1024 res
|
||||
# t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.5, 7.5, device=device) #scales better
|
||||
# # print(t)
|
||||
# return t
|
||||
return spatial_pos_embed
|
||||
|
||||
def unpatchify(self, x, hw=None):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
if hw is None:
|
||||
h = w = int(x.shape[1] ** 0.5)
|
||||
else:
|
||||
h, w = hw
|
||||
h = (h + 1) // p
|
||||
w = (w + 1) // p
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum("nhwpqc->nchpwq", x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
|
||||
def forward_core_with_concat(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c_mod: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
) -> torch.Tensor:
|
||||
if self.register_length > 0:
|
||||
context = torch.cat(
|
||||
(
|
||||
repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
|
||||
default(context, torch.Tensor([]).type_as(x)),
|
||||
),
|
||||
1,
|
||||
)
|
||||
|
||||
# context is B, L', D
|
||||
# x is B, L, D
|
||||
blocks = len(self.joint_blocks)
|
||||
for i in range(blocks):
|
||||
context, x = self.joint_blocks[i](
|
||||
context,
|
||||
x,
|
||||
c=c_mod,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
)
|
||||
if control is not None:
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
x += add
|
||||
|
||||
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N,) tensor of class labels
|
||||
"""
|
||||
|
||||
if self.context_processor is not None:
|
||||
context = self.context_processor(context)
|
||||
|
||||
hw = x.shape[-2:]
|
||||
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
|
||||
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
if y is not None and self.y_embedder is not None:
|
||||
y = self.y_embedder(y) # (N, D)
|
||||
c = c + y # (N, D)
|
||||
|
||||
if context is not None:
|
||||
context = self.context_embedder(context)
|
||||
|
||||
x = self.forward_core_with_concat(x, c, context, control)
|
||||
|
||||
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||
return x[:,:,:hw[-2],:hw[-1]]
|
||||
|
||||
|
||||
class OpenAISignatureMMDITWrapper(MMDiT):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return super().forward(x, timesteps, context=context, y=y, control=control)
|
||||
|
@ -3,8 +3,8 @@ import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
from typing import Optional, Any
|
||||
import logging
|
||||
|
||||
from comfy import model_management
|
||||
import comfy.ops
|
||||
@ -190,7 +190,7 @@ def slice_attention(q, k, v):
|
||||
steps *= 2
|
||||
if steps > 128:
|
||||
raise e
|
||||
print("out of memory error, increasing steps and trying again", steps)
|
||||
logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
|
||||
|
||||
return r1
|
||||
|
||||
@ -235,7 +235,7 @@ def pytorch_attention(q, k, v):
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(B, C, H, W)
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
print("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
return out
|
||||
|
||||
@ -268,13 +268,13 @@ class AttnBlock(nn.Module):
|
||||
padding=0)
|
||||
|
||||
if model_management.xformers_enabled_vae():
|
||||
print("Using xformers attention in VAE")
|
||||
logging.info("Using xformers attention in VAE")
|
||||
self.optimized_attention = xformers_attention
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
print("Using pytorch attention in VAE")
|
||||
logging.info("Using pytorch attention in VAE")
|
||||
self.optimized_attention = pytorch_attention
|
||||
else:
|
||||
print("Using split attention in VAE")
|
||||
logging.info("Using split attention in VAE")
|
||||
self.optimized_attention = normal_attention
|
||||
|
||||
def forward(self, x):
|
||||
@ -562,7 +562,7 @@ class Decoder(nn.Module):
|
||||
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(
|
||||
logging.debug("Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
|
@ -4,6 +4,7 @@ import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import logging
|
||||
|
||||
from .util import (
|
||||
checkpoint,
|
||||
@ -257,7 +258,7 @@ class ResBlock(TimestepBlock):
|
||||
else:
|
||||
if emb_out is not None:
|
||||
if self.exchange_temb_dims:
|
||||
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
|
||||
emb_out = emb_out.movedim(1, 2)
|
||||
h = h + emb_out
|
||||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
@ -359,7 +360,7 @@ def apply_control(h, control, name):
|
||||
try:
|
||||
h += ctrl
|
||||
except:
|
||||
print("warning control could not be applied", h.shape, ctrl.shape)
|
||||
logging.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
|
||||
return h
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
@ -430,6 +431,7 @@ class UNetModel(nn.Module):
|
||||
video_kernel_size=None,
|
||||
disable_temporal_crossattention=False,
|
||||
max_ddpm_temb_period=10000,
|
||||
attn_precision=None,
|
||||
device=None,
|
||||
operations=ops,
|
||||
):
|
||||
@ -484,7 +486,6 @@ class UNetModel(nn.Module):
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
self.default_num_video_frames = None
|
||||
self.default_image_only_indicator = None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
@ -497,7 +498,7 @@ class UNetModel(nn.Module):
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
|
||||
elif self.num_classes == "continuous":
|
||||
print("setting up linear c_adm embedding layer")
|
||||
logging.debug("setting up linear c_adm embedding layer")
|
||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||
elif self.num_classes == "sequential":
|
||||
assert adm_in_channels is not None
|
||||
@ -550,13 +551,14 @@ class UNetModel(nn.Module):
|
||||
disable_self_attn=disable_self_attn,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
max_time_embed_period=max_ddpm_temb_period,
|
||||
attn_precision=attn_precision,
|
||||
dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
return SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
|
||||
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def get_resblock(
|
||||
@ -708,27 +710,30 @@ class UNetModel(nn.Module):
|
||||
device=device,
|
||||
operations=operations
|
||||
)]
|
||||
if transformer_depth_middle >= 0:
|
||||
mid_block += [get_attention_layer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
|
||||
),
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=None,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)]
|
||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||
|
||||
self.middle_block = None
|
||||
if transformer_depth_middle >= -1:
|
||||
if transformer_depth_middle >= 0:
|
||||
mid_block += [get_attention_layer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
|
||||
),
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=None,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)]
|
||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
@ -827,7 +832,7 @@ class UNetModel(nn.Module):
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
|
||||
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
|
||||
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
|
||||
image_only_indicator = kwargs.get("image_only_indicator", None)
|
||||
time_context = kwargs.get("time_context", None)
|
||||
|
||||
assert (y is not None) == (
|
||||
@ -858,7 +863,8 @@ class UNetModel(nn.Module):
|
||||
h = p(h, transformer_options)
|
||||
|
||||
transformer_options["block"] = ("middle", 0)
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
if self.middle_block is not None:
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'middle')
|
||||
|
||||
|
||||
|
@ -46,23 +46,25 @@ class AlphaBlender(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
|
||||
def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor:
|
||||
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
|
||||
if self.merge_strategy == "fixed":
|
||||
# make shape compatible
|
||||
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
|
||||
alpha = self.mix_factor.to(image_only_indicator.device)
|
||||
alpha = self.mix_factor.to(device)
|
||||
elif self.merge_strategy == "learned":
|
||||
alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
|
||||
alpha = torch.sigmoid(self.mix_factor.to(device))
|
||||
# make shape compatible
|
||||
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
|
||||
elif self.merge_strategy == "learned_with_images":
|
||||
assert image_only_indicator is not None, "need image_only_indicator ..."
|
||||
alpha = torch.where(
|
||||
image_only_indicator.bool(),
|
||||
torch.ones(1, 1, device=image_only_indicator.device),
|
||||
rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
|
||||
)
|
||||
if image_only_indicator is None:
|
||||
alpha = rearrange(torch.sigmoid(self.mix_factor.to(device)), "... -> ... 1")
|
||||
else:
|
||||
alpha = torch.where(
|
||||
image_only_indicator.bool(),
|
||||
torch.ones(1, 1, device=image_only_indicator.device),
|
||||
rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
|
||||
)
|
||||
alpha = rearrange(alpha, self.rearrange_pattern)
|
||||
# make shape compatible
|
||||
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
|
||||
@ -76,7 +78,7 @@ class AlphaBlender(nn.Module):
|
||||
x_temporal,
|
||||
image_only_indicator=None,
|
||||
) -> torch.Tensor:
|
||||
alpha = self.get_alpha(image_only_indicator)
|
||||
alpha = self.get_alpha(image_only_indicator, x_spatial.device)
|
||||
x = (
|
||||
alpha.to(x_spatial.dtype) * x_spatial
|
||||
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
|
||||
|
@ -14,6 +14,7 @@ import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import math
|
||||
import logging
|
||||
|
||||
try:
|
||||
from typing import Optional, NamedTuple, List, Protocol
|
||||
@ -170,7 +171,7 @@ def _get_attention_scores_no_kv_chunking(
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
except model_management.OOM_EXCEPTION:
|
||||
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
||||
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
||||
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
|
||||
torch.exp(attn_scores, out=attn_scores)
|
||||
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import comfy.utils
|
||||
import logging
|
||||
|
||||
LORA_CLIP_MAP = {
|
||||
"mlp.fc1": "mlp_fc1",
|
||||
@ -20,8 +21,16 @@ def load_lora(lora, to_load):
|
||||
alpha = lora[alpha_name].item()
|
||||
loaded_keys.add(alpha_name)
|
||||
|
||||
dora_scale_name = "{}.dora_scale".format(x)
|
||||
dora_scale = None
|
||||
if dora_scale_name in lora.keys():
|
||||
dora_scale = lora[dora_scale_name]
|
||||
loaded_keys.add(dora_scale_name)
|
||||
|
||||
regular_lora = "{}.lora_up.weight".format(x)
|
||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||
diffusers3_lora = "{}.lora.up.weight".format(x)
|
||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||
A_name = None
|
||||
|
||||
@ -33,6 +42,14 @@ def load_lora(lora, to_load):
|
||||
A_name = diffusers_lora
|
||||
B_name = "{}_lora.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif diffusers2_lora in lora.keys():
|
||||
A_name = diffusers2_lora
|
||||
B_name = "{}.lora_A.weight".format(x)
|
||||
mid_name = None
|
||||
elif diffusers3_lora in lora.keys():
|
||||
A_name = diffusers3_lora
|
||||
B_name = "{}.lora.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif transformers_lora in lora.keys():
|
||||
A_name = transformers_lora
|
||||
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
||||
@ -43,7 +60,7 @@ def load_lora(lora, to_load):
|
||||
if mid_name is not None and mid_name in lora.keys():
|
||||
mid = lora[mid_name]
|
||||
loaded_keys.add(mid_name)
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
|
||||
@ -64,7 +81,7 @@ def load_lora(lora, to_load):
|
||||
loaded_keys.add(hada_t1_name)
|
||||
loaded_keys.add(hada_t2_name)
|
||||
|
||||
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
|
||||
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
|
||||
loaded_keys.add(hada_w1_a_name)
|
||||
loaded_keys.add(hada_w1_b_name)
|
||||
loaded_keys.add(hada_w2_a_name)
|
||||
@ -116,7 +133,7 @@ def load_lora(lora, to_load):
|
||||
loaded_keys.add(lokr_t2_name)
|
||||
|
||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
|
||||
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
|
||||
|
||||
#glora
|
||||
a1_name = "{}.a1.weight".format(x)
|
||||
@ -124,7 +141,7 @@ def load_lora(lora, to_load):
|
||||
b1_name = "{}.b1.weight".format(x)
|
||||
b2_name = "{}.b2.weight".format(x)
|
||||
if a1_name in lora:
|
||||
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
|
||||
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
|
||||
loaded_keys.add(a1_name)
|
||||
loaded_keys.add(a2_name)
|
||||
loaded_keys.add(b1_name)
|
||||
@ -156,7 +173,8 @@ def load_lora(lora, to_load):
|
||||
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
print("lora key not loaded", x)
|
||||
logging.warning("lora key not loaded: {}".format(x))
|
||||
|
||||
return patch_dict
|
||||
|
||||
def model_lora_keys_clip(model, key_map={}):
|
||||
@ -197,16 +215,36 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
key_map[lora_key] = k
|
||||
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
|
||||
key_map[lora_key] = k
|
||||
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map[lora_key] = k
|
||||
|
||||
for k in sdk: #OneTrainer SD3 lora
|
||||
if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
|
||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
|
||||
k = "clip_g.transformer.text_projection.weight"
|
||||
if k in sdk:
|
||||
key_map["lora_prior_te_text_projection"] = k #cascade lora?
|
||||
# key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
|
||||
key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora
|
||||
|
||||
k = "clip_l.transformer.text_projection.weight"
|
||||
if k in sdk:
|
||||
key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning
|
||||
|
||||
return key_map
|
||||
|
||||
def model_lora_keys_unet(model, key_map={}):
|
||||
sdk = model.state_dict().keys()
|
||||
sd = model.state_dict()
|
||||
sdk = sd.keys()
|
||||
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
|
||||
for k in diffusers_keys:
|
||||
@ -221,4 +259,19 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
if diffusers_lora_key.endswith(".to_out.0"):
|
||||
diffusers_lora_key = diffusers_lora_key[:-2]
|
||||
key_map[diffusers_lora_key] = unet_key
|
||||
|
||||
if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
|
||||
diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
|
||||
key_map[key_lora] = to
|
||||
|
||||
key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
|
||||
key_map[key_lora] = to
|
||||
|
||||
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
|
||||
key_map[key_lora] = to
|
||||
|
||||
return key_map
|
||||
|
@ -1,20 +1,32 @@
|
||||
import torch
|
||||
import logging
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
from comfy.ldm.cascade.stage_b import StageB
|
||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.ops
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
import comfy.latent_formats
|
||||
import math
|
||||
|
||||
class ModelType(Enum):
|
||||
EPS = 1
|
||||
V_PREDICTION = 2
|
||||
V_PREDICTION_EDM = 3
|
||||
STABLE_CASCADE = 4
|
||||
EDM = 5
|
||||
FLOW = 6
|
||||
V_PREDICTION_CONTINUOUS = 7
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@ -27,6 +39,18 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.FLOW:
|
||||
c = comfy.model_sampling.CONST
|
||||
s = comfy.model_sampling.ModelSamplingDiscreteFlow
|
||||
elif model_type == ModelType.STABLE_CASCADE:
|
||||
c = EPS
|
||||
s = StableCascadeSampling
|
||||
elif model_type == ModelType.EDM:
|
||||
c = EDM
|
||||
s = ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousV
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@ -35,7 +59,7 @@ def model_sampling(model_config, model_type):
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
|
||||
super().__init__()
|
||||
|
||||
unet_config = model_config.unet_config
|
||||
@ -48,16 +72,20 @@ class BaseModel(torch.nn.Module):
|
||||
operations = comfy.ops.manual_cast
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
if comfy.model_management.force_channels_last():
|
||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||
logging.debug("using channels last mode for diffusion model")
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||
if self.adm_channels is None:
|
||||
self.adm_channels = 0
|
||||
self.inpaint_model = False
|
||||
print("model_type", model_type.name)
|
||||
print("adm", self.adm_channels)
|
||||
|
||||
self.concat_keys = ()
|
||||
logging.info("model_type {}".format(model_type.name))
|
||||
logging.debug("adm {}".format(self.adm_channels))
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
@ -96,8 +124,7 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
if self.inpaint_model:
|
||||
concat_keys = ("mask", "masked_image")
|
||||
if len(self.concat_keys) > 0:
|
||||
cond_concat = []
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
concat_latent_image = kwargs.get("concat_latent_image", None)
|
||||
@ -114,24 +141,16 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
|
||||
|
||||
if len(denoise_mask.shape) == len(noise.shape):
|
||||
denoise_mask = denoise_mask[:,:1]
|
||||
if denoise_mask is not None:
|
||||
if len(denoise_mask.shape) == len(noise.shape):
|
||||
denoise_mask = denoise_mask[:,:1]
|
||||
|
||||
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
|
||||
if denoise_mask.shape[-2:] != noise.shape[-2:]:
|
||||
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
|
||||
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
|
||||
if denoise_mask.shape[-2:] != noise.shape[-2:]:
|
||||
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
|
||||
|
||||
def blank_inpaint_image_like(latent_image):
|
||||
blank_image = torch.ones_like(latent_image)
|
||||
# these are the values for "zero" in pixel space translated to latent space
|
||||
blank_image[:,0] *= 0.8223
|
||||
blank_image[:,1] *= -0.6876
|
||||
blank_image[:,2] *= 0.6364
|
||||
blank_image[:,3] *= 0.1380
|
||||
return blank_image
|
||||
|
||||
for ck in concat_keys:
|
||||
for ck in self.concat_keys:
|
||||
if denoise_mask is not None:
|
||||
if ck == "mask":
|
||||
cond_concat.append(denoise_mask.to(device))
|
||||
@ -141,7 +160,7 @@ class BaseModel(torch.nn.Module):
|
||||
if ck == "mask":
|
||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(blank_inpaint_image_like(noise))
|
||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
||||
|
||||
@ -157,6 +176,10 @@ class BaseModel(torch.nn.Module):
|
||||
if cross_attn_cnet is not None:
|
||||
out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet)
|
||||
|
||||
c_concat = kwargs.get("noise_concat", None)
|
||||
if c_concat is not None:
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(c_concat)
|
||||
|
||||
return out
|
||||
|
||||
def load_model_weights(self, sd, unet_prefix=""):
|
||||
@ -169,10 +192,10 @@ class BaseModel(torch.nn.Module):
|
||||
to_load = self.model_config.process_unet_state_dict(to_load)
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||
if len(m) > 0:
|
||||
print("unet missing:", m)
|
||||
logging.warning("unet missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
print("unet unexpected:", u)
|
||||
logging.warning("unet unexpected: {}".format(u))
|
||||
del to_load
|
||||
return self
|
||||
|
||||
@ -194,9 +217,6 @@ class BaseModel(torch.nn.Module):
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
if self.get_dtype() == torch.float16:
|
||||
extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
unet_state_dict["v_pred"] = torch.tensor([])
|
||||
|
||||
@ -206,7 +226,16 @@ class BaseModel(torch.nn.Module):
|
||||
return unet_state_dict
|
||||
|
||||
def set_inpaint(self):
|
||||
self.inpaint_model = True
|
||||
self.concat_keys = ("mask", "masked_image")
|
||||
def blank_inpaint_image_like(latent_image):
|
||||
blank_image = torch.ones_like(latent_image)
|
||||
# these are the values for "zero" in pixel space translated to latent space
|
||||
blank_image[:,0] *= 0.8223
|
||||
blank_image[:,1] *= -0.6876
|
||||
blank_image[:,2] *= 0.6364
|
||||
blank_image[:,3] *= 0.1380
|
||||
return blank_image
|
||||
self.blank_inpaint_image_like = blank_inpaint_image_like
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||
@ -214,11 +243,11 @@ class BaseModel(torch.nn.Module):
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
#TODO: this needs to be tweaked
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
area = input_shape[0] * math.prod(input_shape[2:])
|
||||
return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
|
||||
else:
|
||||
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
area = input_shape[0] * math.prod(input_shape[2:])
|
||||
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
||||
|
||||
|
||||
@ -362,10 +391,39 @@ class SVD_img2vid(BaseModel):
|
||||
if "time_conditioning" in kwargs:
|
||||
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
|
||||
|
||||
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
|
||||
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
|
||||
return out
|
||||
|
||||
class SV3D_u(SVD_img2vid):
|
||||
def encode_adm(self, **kwargs):
|
||||
augmentation = kwargs.get("augmentation_level", 0)
|
||||
|
||||
out = []
|
||||
out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
|
||||
|
||||
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
|
||||
return flat
|
||||
|
||||
class SV3D_p(SVD_img2vid):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.embedder_512 = Timestep(512)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
augmentation = kwargs.get("augmentation_level", 0)
|
||||
elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here
|
||||
azimuth = kwargs.get("azimuth", 0)
|
||||
noise = kwargs.get("noise", None)
|
||||
|
||||
out = []
|
||||
out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
|
||||
out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0))))
|
||||
out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0))))
|
||||
|
||||
out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out))
|
||||
return torch.cat(out, dim=1)
|
||||
|
||||
|
||||
class Stable_Zero123(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
@ -427,3 +485,154 @@ class SD_X4Upscaler(BaseModel):
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
|
||||
out['y'] = comfy.conds.CONDRegular(noise_level)
|
||||
return out
|
||||
|
||||
class IP2P:
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros_like(noise)
|
||||
|
||||
if image.shape[1:] != noise.shape[1:]:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = comfy.conds.CONDRegular(adm)
|
||||
return out
|
||||
|
||||
class SD15_instructpix2pix(IP2P, BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.process_ip2p_image_in = lambda image: image
|
||||
|
||||
class SDXL_instructpix2pix(IP2P, SDXL):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
if model_type == ModelType.V_PREDICTION_EDM:
|
||||
self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image) #cosxl ip2p
|
||||
else:
|
||||
self.process_ip2p_image_in = lambda image: image #diffusers ip2p
|
||||
|
||||
|
||||
class StableCascade_C(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
||||
self.diffusion_model.eval().requires_grad_(False)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
clip_text_pooled = kwargs["pooled_output"]
|
||||
if clip_text_pooled is not None:
|
||||
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
if "unclip_conditioning" in kwargs:
|
||||
embeds = []
|
||||
for unclip_cond in kwargs["unclip_conditioning"]:
|
||||
weight = unclip_cond["strength"]
|
||||
embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight)
|
||||
clip_img = torch.cat(embeds, dim=1)
|
||||
else:
|
||||
clip_img = torch.zeros((1, 1, 768))
|
||||
out["clip_img"] = comfy.conds.CONDRegular(clip_img)
|
||||
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
||||
out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
|
||||
return out
|
||||
|
||||
|
||||
class StableCascade_B(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageB)
|
||||
self.diffusion_model.eval().requires_grad_(False)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
noise = kwargs.get("noise", None)
|
||||
|
||||
clip_text_pooled = kwargs["pooled_output"]
|
||||
if clip_text_pooled is not None:
|
||||
out['clip'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
|
||||
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
|
||||
|
||||
out["effnet"] = comfy.conds.CONDRegular(prior)
|
||||
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
||||
return out
|
||||
|
||||
|
||||
class SD3(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=OpenAISignatureMMDITWrapper)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||
dtype = self.get_dtype()
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
#TODO: this probably needs to be tweaked
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
return (area * comfy.model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
|
||||
else:
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
return (area * 0.3) * (1024 * 1024)
|
||||
|
||||
|
||||
class StableAudio1(BaseModel):
|
||||
def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
|
||||
self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
|
||||
self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
|
||||
self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights)
|
||||
self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
seconds_start = kwargs.get("seconds_start", 0)
|
||||
seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 21.53))
|
||||
|
||||
seconds_start_embed = self.seconds_start_embedder([seconds_start])[0].to(device)
|
||||
seconds_total_embed = self.seconds_total_embedder([seconds_total])[0].to(device)
|
||||
|
||||
global_embed = torch.cat([seconds_start_embed, seconds_total_embed], dim=-1).reshape((1, -1))
|
||||
out['global_embed'] = comfy.conds.CONDRegular(global_embed)
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
cross_attn = torch.cat([cross_attn.to(device), seconds_start_embed.repeat((cross_attn.shape[0], 1, 1)), seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1)
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
|
||||
for k in d:
|
||||
s = d[k]
|
||||
for l in s:
|
||||
sd["{}{}".format(k, l)] = s[l]
|
||||
return sd
|
||||
|
@ -1,5 +1,9 @@
|
||||
import comfy.supported_models
|
||||
import comfy.supported_models_base
|
||||
import comfy.utils
|
||||
import math
|
||||
import logging
|
||||
import torch
|
||||
|
||||
def count_blocks(state_dict_keys, prefix_string):
|
||||
count = 0
|
||||
@ -25,12 +29,82 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
|
||||
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
|
||||
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
|
||||
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
|
||||
time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict
|
||||
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
|
||||
return None
|
||||
|
||||
def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
def detect_unet_config(state_dict, key_prefix):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
|
||||
unet_config = {}
|
||||
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
|
||||
patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
|
||||
unet_config["patch_size"] = patch_size
|
||||
final_layer = '{}final_layer.linear.weight'.format(key_prefix)
|
||||
if final_layer in state_dict:
|
||||
unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)
|
||||
|
||||
unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
|
||||
unet_config["input_size"] = None
|
||||
y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
|
||||
if y_key in state_dict_keys:
|
||||
unet_config["adm_in_channels"] = state_dict[y_key].shape[1]
|
||||
|
||||
context_key = '{}context_embedder.weight'.format(key_prefix)
|
||||
if context_key in state_dict_keys:
|
||||
in_features = state_dict[context_key].shape[1]
|
||||
out_features = state_dict[context_key].shape[0]
|
||||
unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
|
||||
num_patches_key = '{}pos_embed'.format(key_prefix)
|
||||
if num_patches_key in state_dict_keys:
|
||||
num_patches = state_dict[num_patches_key].shape[1]
|
||||
unet_config["num_patches"] = num_patches
|
||||
unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))
|
||||
|
||||
rms_qk = '{}joint_blocks.0.context_block.attn.ln_q.weight'.format(key_prefix)
|
||||
if rms_qk in state_dict_keys:
|
||||
unet_config["qk_norm"] = "rms"
|
||||
|
||||
unet_config["pos_embed_scaling_factor"] = None #unused for inference
|
||||
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
|
||||
if context_processor in state_dict_keys:
|
||||
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
|
||||
return unet_config
|
||||
|
||||
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
||||
unet_config = {}
|
||||
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
|
||||
if text_mapper_name in state_dict_keys:
|
||||
unet_config['stable_cascade_stage'] = 'c'
|
||||
w = state_dict[text_mapper_name]
|
||||
if w.shape[0] == 1536: #stage c lite
|
||||
unet_config['c_cond'] = 1536
|
||||
unet_config['c_hidden'] = [1536, 1536]
|
||||
unet_config['nhead'] = [24, 24]
|
||||
unet_config['blocks'] = [[4, 12], [12, 4]]
|
||||
elif w.shape[0] == 2048: #stage c full
|
||||
unet_config['c_cond'] = 2048
|
||||
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
|
||||
unet_config['stable_cascade_stage'] = 'b'
|
||||
w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
|
||||
if w.shape[-1] == 640:
|
||||
unet_config['c_hidden'] = [320, 640, 1280, 1280]
|
||||
unet_config['nhead'] = [-1, -1, 20, 20]
|
||||
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
|
||||
elif w.shape[-1] == 576: #stage b lite
|
||||
unet_config['c_hidden'] = [320, 576, 1152, 1152]
|
||||
unet_config['nhead'] = [-1, 9, 18, 18]
|
||||
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
|
||||
return unet_config
|
||||
|
||||
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
|
||||
unet_config = {}
|
||||
unet_config["audio_model"] = "dit1.0"
|
||||
return unet_config
|
||||
|
||||
unet_config = {
|
||||
"use_checkpoint": False,
|
||||
"image_size": 32,
|
||||
@ -45,7 +119,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
else:
|
||||
unet_config["adm_in_channels"] = None
|
||||
|
||||
unet_config["dtype"] = dtype
|
||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||
|
||||
@ -64,6 +137,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
use_linear_in_transformer = False
|
||||
|
||||
video_model = False
|
||||
video_model_cross = False
|
||||
|
||||
current_res = 1
|
||||
count = 0
|
||||
@ -107,6 +181,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
context_dim = out[1]
|
||||
use_linear_in_transformer = out[2]
|
||||
video_model = out[3]
|
||||
video_model_cross = out[4]
|
||||
else:
|
||||
transformer_depth.append(0)
|
||||
|
||||
@ -123,8 +198,10 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
channel_mult.append(last_channel_mult)
|
||||
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
|
||||
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
|
||||
else:
|
||||
elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
|
||||
transformer_depth_middle = -1
|
||||
else:
|
||||
transformer_depth_middle = -2
|
||||
|
||||
unet_config["in_channels"] = in_channels
|
||||
unet_config["out_channels"] = out_channels
|
||||
@ -145,28 +222,36 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
unet_config["video_kernel_size"] = [3, 1, 1]
|
||||
unet_config["use_temporal_resblock"] = True
|
||||
unet_config["use_temporal_attention"] = True
|
||||
unet_config["disable_temporal_crossattention"] = not video_model_cross
|
||||
else:
|
||||
unet_config["use_temporal_resblock"] = False
|
||||
unet_config["use_temporal_attention"] = False
|
||||
|
||||
return unet_config
|
||||
|
||||
def model_config_from_unet_config(unet_config):
|
||||
def model_config_from_unet_config(unet_config, state_dict=None):
|
||||
for model_config in comfy.supported_models.models:
|
||||
if model_config.matches(unet_config):
|
||||
if model_config.matches(unet_config, state_dict):
|
||||
return model_config(unet_config)
|
||||
|
||||
print("no match", unet_config)
|
||||
logging.error("no match {}".format(unet_config))
|
||||
return None
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
|
||||
model_config = model_config_from_unet_config(unet_config)
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
||||
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return comfy.supported_models_base.BASE(unet_config)
|
||||
else:
|
||||
return model_config
|
||||
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
if "model.model.postprocess_conv.weight" in state_dict: #audio models
|
||||
unet_key_prefix = "model.model."
|
||||
else:
|
||||
unet_key_prefix = "model.diffusion_model."
|
||||
return unet_key_prefix
|
||||
|
||||
def convert_config(unet_config):
|
||||
new_config = unet_config.copy()
|
||||
num_res_blocks = new_config.get("num_res_blocks", None)
|
||||
@ -206,7 +291,7 @@ def convert_config(unet_config):
|
||||
return new_config
|
||||
|
||||
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
match = {}
|
||||
transformer_depth = []
|
||||
|
||||
@ -214,6 +299,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
down_blocks = count_blocks(state_dict, "down_blocks.{}")
|
||||
for i in range(down_blocks):
|
||||
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
|
||||
res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
|
||||
for ab in range(attn_blocks):
|
||||
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
|
||||
transformer_depth.append(transformer_count)
|
||||
@ -222,8 +308,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
|
||||
attn_res *= 2
|
||||
if attn_blocks == 0:
|
||||
transformer_depth.append(0)
|
||||
transformer_depth.append(0)
|
||||
for i in range(res_blocks):
|
||||
transformer_depth.append(0)
|
||||
|
||||
match["transformer_depth"] = transformer_depth
|
||||
|
||||
@ -289,6 +375,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
|
||||
@ -301,7 +393,32 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega]
|
||||
KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5],
|
||||
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6],
|
||||
'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
|
||||
'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True,
|
||||
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
|
||||
|
||||
SD_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
|
||||
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
|
||||
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
@ -313,8 +430,44 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
return convert_config(unet_config)
|
||||
return None
|
||||
|
||||
def model_config_from_diffusers_unet(state_dict, dtype):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
|
||||
def model_config_from_diffusers_unet(state_dict):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict)
|
||||
if unet_config is not None:
|
||||
return model_config_from_unet_config(unet_config)
|
||||
return None
|
||||
|
||||
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
if num_blocks > 0:
|
||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||
out_sd = {}
|
||||
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||
for k in sd_map:
|
||||
weight = state_dict.get(k, None)
|
||||
if weight is not None:
|
||||
t = sd_map[k]
|
||||
|
||||
if not isinstance(t, str):
|
||||
if len(t) > 2:
|
||||
fun = t[2]
|
||||
else:
|
||||
fun = lambda a: a
|
||||
offset = t[1]
|
||||
if offset is not None:
|
||||
old_weight = out_sd.get(t[0], None)
|
||||
if old_weight is None:
|
||||
old_weight = torch.empty_like(weight)
|
||||
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
|
||||
|
||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||
else:
|
||||
old_weight = weight
|
||||
w = weight
|
||||
w[:] = fun(weight)
|
||||
t = t[0]
|
||||
out_sd[t] = old_weight
|
||||
else:
|
||||
out_sd[t] = weight
|
||||
state_dict.pop(k)
|
||||
|
||||
return out_sd
|
||||
|
@ -1,10 +1,12 @@
|
||||
import psutil
|
||||
import logging
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
import torch
|
||||
import sys
|
||||
import os.path
|
||||
import platform
|
||||
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@ -30,7 +32,7 @@ lowvram_available = True
|
||||
xpu_available = False
|
||||
|
||||
if args.deterministic:
|
||||
print("Using deterministic algorithms for pytorch")
|
||||
logging.info("Using deterministic algorithms for pytorch")
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
|
||||
directml_enabled = False
|
||||
@ -42,7 +44,7 @@ if args.directml is not None:
|
||||
directml_device = torch_directml.device()
|
||||
else:
|
||||
directml_device = torch_directml.device(device_index)
|
||||
print("Using directml with device:", torch_directml.device_name(device_index))
|
||||
logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
|
||||
# torch_directml.disable_tiled_resources(True)
|
||||
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
|
||||
@ -83,7 +85,7 @@ def get_torch_device():
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
if is_intel_xpu():
|
||||
return torch.device("xpu")
|
||||
return torch.device("xpu", torch.xpu.current_device())
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
@ -120,8 +122,8 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
elif is_intel_xpu():
|
||||
stats = torch.xpu.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||
mem_total_torch = mem_reserved
|
||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
@ -135,12 +137,13 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
return mem_total
|
||||
|
||||
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||
total_ram = get_total_memory(torch.device("cpu")) / (1024 * 1024)
|
||||
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
if not args.normalvram and not args.cpu:
|
||||
if lowvram_available and total_vram <= 4096:
|
||||
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
||||
set_vram_to = VRAMState.LOW_VRAM
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
|
||||
try:
|
||||
logging.info("pytorch version: {}".format(torch.version.__version__))
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
||||
@ -162,12 +165,10 @@ else:
|
||||
pass
|
||||
try:
|
||||
XFORMERS_VERSION = xformers.version.__version__
|
||||
print("xformers version:", XFORMERS_VERSION)
|
||||
logging.info("xformers version: {}".format(XFORMERS_VERSION))
|
||||
if XFORMERS_VERSION.startswith("0.0.18"):
|
||||
print()
|
||||
print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
|
||||
print("Please downgrade or upgrade xformers to a different version.")
|
||||
print()
|
||||
logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
|
||||
logging.warning("Please downgrade or upgrade xformers to a different version.\n")
|
||||
XFORMERS_ENABLED_VAE = False
|
||||
except:
|
||||
pass
|
||||
@ -186,7 +187,7 @@ if args.use_pytorch_cross_attention:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
|
||||
VAE_DTYPE = torch.float32
|
||||
VAE_DTYPES = [torch.float32]
|
||||
|
||||
try:
|
||||
if is_nvidia():
|
||||
@ -195,7 +196,7 @@ try:
|
||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
|
||||
VAE_DTYPE = torch.bfloat16
|
||||
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
|
||||
if is_intel_xpu():
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
@ -203,17 +204,10 @@ except:
|
||||
pass
|
||||
|
||||
if is_intel_xpu():
|
||||
VAE_DTYPE = torch.bfloat16
|
||||
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
|
||||
|
||||
if args.cpu_vae:
|
||||
VAE_DTYPE = torch.float32
|
||||
|
||||
if args.fp16_vae:
|
||||
VAE_DTYPE = torch.float16
|
||||
elif args.bf16_vae:
|
||||
VAE_DTYPE = torch.bfloat16
|
||||
elif args.fp32_vae:
|
||||
VAE_DTYPE = torch.float32
|
||||
VAE_DTYPES = [torch.float32]
|
||||
|
||||
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
@ -232,11 +226,11 @@ elif args.highvram or args.gpu_only:
|
||||
FORCE_FP32 = False
|
||||
FORCE_FP16 = False
|
||||
if args.force_fp32:
|
||||
print("Forcing FP32, if this improves things please report it.")
|
||||
logging.info("Forcing FP32, if this improves things please report it.")
|
||||
FORCE_FP32 = True
|
||||
|
||||
if args.force_fp16:
|
||||
print("Forcing FP16.")
|
||||
logging.info("Forcing FP16.")
|
||||
FORCE_FP16 = True
|
||||
|
||||
if lowvram_available:
|
||||
@ -250,12 +244,12 @@ if cpu_state != CPUState.GPU:
|
||||
if cpu_state == CPUState.MPS:
|
||||
vram_state = VRAMState.SHARED
|
||||
|
||||
print(f"Set vram state to: {vram_state.name}")
|
||||
logging.info(f"Set vram state to: {vram_state.name}")
|
||||
|
||||
DISABLE_SMART_MEMORY = args.disable_smart_memory
|
||||
|
||||
if DISABLE_SMART_MEMORY:
|
||||
print("Disabling smart memory management")
|
||||
logging.info("Disabling smart memory management")
|
||||
|
||||
def get_torch_device_name(device):
|
||||
if hasattr(device, 'type'):
|
||||
@ -273,11 +267,10 @@ def get_torch_device_name(device):
|
||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||
|
||||
try:
|
||||
print("Device:", get_torch_device_name(get_torch_device()))
|
||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
except:
|
||||
print("Could not pick default device.")
|
||||
logging.warning("Could not pick default device.")
|
||||
|
||||
print("VAE dtype:", VAE_DTYPE)
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
@ -292,8 +285,10 @@ def module_size(module):
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self.model_accelerated = False
|
||||
self.device = model.load_device
|
||||
self.weights_loaded = False
|
||||
self.real_model = None
|
||||
self.currently_used = True
|
||||
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
@ -304,55 +299,40 @@ class LoadedModel:
|
||||
else:
|
||||
return self.model_memory()
|
||||
|
||||
def model_load(self, lowvram_model_memory=0):
|
||||
patch_model_to = None
|
||||
if lowvram_model_memory == 0:
|
||||
patch_model_to = self.device
|
||||
def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
|
||||
patch_model_to = self.device
|
||||
|
||||
self.model.model_patches_to(self.device)
|
||||
self.model.model_patches_to(self.model.model_dtype())
|
||||
|
||||
load_weights = not self.weights_loaded
|
||||
|
||||
try:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
|
||||
if lowvram_model_memory > 0 and load_weights:
|
||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
else:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
||||
except Exception as e:
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model_unload()
|
||||
raise e
|
||||
|
||||
if lowvram_model_memory > 0:
|
||||
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
|
||||
mem_counter = 0
|
||||
for m in self.real_model.modules():
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
module_mem = module_size(m)
|
||||
if mem_counter + module_mem < lowvram_model_memory:
|
||||
m.to(self.device)
|
||||
mem_counter += module_mem
|
||||
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
|
||||
m.to(self.device)
|
||||
mem_counter += module_size(m)
|
||||
print("lowvram: loaded module regularly", m)
|
||||
|
||||
self.model_accelerated = True
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize:
|
||||
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
||||
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
|
||||
|
||||
self.weights_loaded = True
|
||||
return self.real_model
|
||||
|
||||
def model_unload(self):
|
||||
if self.model_accelerated:
|
||||
for m in self.real_model.modules():
|
||||
if hasattr(m, "prev_comfy_cast_weights"):
|
||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
||||
del m.prev_comfy_cast_weights
|
||||
def should_reload_model(self, force_patch_weights=False):
|
||||
if force_patch_weights and self.model.lowvram_patch_counter > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
self.model_accelerated = False
|
||||
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
def model_unload(self, unpatch_weights=True):
|
||||
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
||||
self.model.model_patches_to(self.model.offload_device)
|
||||
self.weights_loaded = self.weights_loaded and not unpatch_weights
|
||||
self.real_model = None
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.model is other.model
|
||||
@ -360,31 +340,58 @@ class LoadedModel:
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024)
|
||||
|
||||
def unload_model_clones(model):
|
||||
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
||||
to_unload = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if model.is_clone(current_loaded_models[i].model):
|
||||
to_unload = [i] + to_unload
|
||||
|
||||
if len(to_unload) == 0:
|
||||
return True
|
||||
|
||||
same_weights = 0
|
||||
for i in to_unload:
|
||||
print("unload clone", i)
|
||||
current_loaded_models.pop(i).model_unload()
|
||||
if model.clone_has_same_weights(current_loaded_models[i].model):
|
||||
same_weights += 1
|
||||
|
||||
if same_weights == len(to_unload):
|
||||
unload_weight = False
|
||||
else:
|
||||
unload_weight = True
|
||||
|
||||
if not force_unload:
|
||||
if unload_weights_only and unload_weight == False:
|
||||
return None
|
||||
|
||||
for i in to_unload:
|
||||
logging.debug("unload clone {} {}".format(i, unload_weight))
|
||||
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
|
||||
|
||||
return unload_weight
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
unloaded_model = False
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
if get_free_memory(device) > memory_required:
|
||||
break
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model.device == device:
|
||||
if shift_model not in keep_loaded:
|
||||
m = current_loaded_models.pop(i)
|
||||
m.model_unload()
|
||||
del m
|
||||
unloaded_model = True
|
||||
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
if unloaded_model:
|
||||
for x in sorted(can_unload):
|
||||
i = x[-1]
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
if get_free_memory(device) > memory_required:
|
||||
break
|
||||
current_loaded_models[i].model_unload()
|
||||
unloaded_model.append(i)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
current_loaded_models.pop(i)
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
soft_empty_cache()
|
||||
else:
|
||||
if vram_state != VRAMState.HIGH_VRAM:
|
||||
@ -392,24 +399,37 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
if mem_free_torch > mem_free_total * 0.25:
|
||||
soft_empty_cache()
|
||||
|
||||
def load_models_gpu(models, memory_required=0):
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
||||
global vram_state
|
||||
|
||||
inference_memory = minimum_inference_memory()
|
||||
extra_mem = max(inference_memory, memory_required)
|
||||
|
||||
models = set(models)
|
||||
|
||||
models_to_load = []
|
||||
models_already_loaded = []
|
||||
for x in models:
|
||||
loaded_model = LoadedModel(x)
|
||||
loaded = None
|
||||
|
||||
if loaded_model in current_loaded_models:
|
||||
index = current_loaded_models.index(loaded_model)
|
||||
current_loaded_models.insert(0, current_loaded_models.pop(index))
|
||||
models_already_loaded.append(loaded_model)
|
||||
else:
|
||||
try:
|
||||
loaded_model_index = current_loaded_models.index(loaded_model)
|
||||
except:
|
||||
loaded_model_index = None
|
||||
|
||||
if loaded_model_index is not None:
|
||||
loaded = current_loaded_models[loaded_model_index]
|
||||
if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
|
||||
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
|
||||
loaded = None
|
||||
else:
|
||||
loaded.currently_used = True
|
||||
models_already_loaded.append(loaded)
|
||||
|
||||
if loaded is None:
|
||||
if hasattr(x, "model"):
|
||||
print(f"Requested to load {x.model.__class__.__name__}")
|
||||
logging.info(f"Requested to load {x.model.__class__.__name__}")
|
||||
models_to_load.append(loaded_model)
|
||||
|
||||
if len(models_to_load) == 0:
|
||||
@ -419,17 +439,22 @@ def load_models_gpu(models, memory_required=0):
|
||||
free_memory(extra_mem, d, models_already_loaded)
|
||||
return
|
||||
|
||||
print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
|
||||
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
unload_model_clones(loaded_model.model)
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
|
||||
if weights_unloaded is not None:
|
||||
loaded_model.weights_loaded = not weights_unloaded
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
model = loaded_model.model
|
||||
torch_dev = model.load_device
|
||||
@ -442,15 +467,13 @@ def load_models_gpu(models, memory_required=0):
|
||||
model_size = loaded_model.model_memory_required(torch_dev)
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
||||
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
else:
|
||||
if model_size <= (current_free_mem - inference_memory): #only switch to lowvram if really necessary
|
||||
lowvram_model_memory = 0
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
lowvram_model_memory = 64 * 1024 * 1024
|
||||
|
||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
|
||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
return
|
||||
|
||||
@ -458,11 +481,25 @@ def load_models_gpu(models, memory_required=0):
|
||||
def load_model_gpu(model):
|
||||
return load_models_gpu([model])
|
||||
|
||||
def cleanup_models():
|
||||
def loaded_models(only_currently_used=False):
|
||||
output = []
|
||||
for m in current_loaded_models:
|
||||
if only_currently_used:
|
||||
if not m.currently_used:
|
||||
continue
|
||||
|
||||
output.append(m.model)
|
||||
return output
|
||||
|
||||
def cleanup_models(keep_clone_weights_loaded=False):
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if sys.getrefcount(current_loaded_models[i].model) <= 2:
|
||||
to_delete = [i] + to_delete
|
||||
if not keep_clone_weights_loaded:
|
||||
to_delete = [i] + to_delete
|
||||
#TODO: find a less fragile way to do this.
|
||||
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
|
||||
to_delete = [i] + to_delete
|
||||
|
||||
for i in to_delete:
|
||||
x = current_loaded_models.pop(i)
|
||||
@ -506,7 +543,7 @@ def unet_inital_load_device(parameters, dtype):
|
||||
else:
|
||||
return cpu_dev
|
||||
|
||||
def unet_dtype(device=None, model_params=0):
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
if args.fp16_unet:
|
||||
@ -516,20 +553,31 @@ def unet_dtype(device=None, model_params=0):
|
||||
if args.fp8_e5m2_unet:
|
||||
return torch.float8_e5m2
|
||||
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||
return torch.float16
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
if should_use_bf16(device, model_params=model_params, manual_cast=True):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
return torch.float32
|
||||
|
||||
# None means no manual cast
|
||||
def unet_manual_cast(weight_dtype, inference_device):
|
||||
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if weight_dtype == torch.float32:
|
||||
return None
|
||||
|
||||
fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
|
||||
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
|
||||
if fp16_supported and weight_dtype == torch.float16:
|
||||
return None
|
||||
|
||||
if fp16_supported:
|
||||
bf16_supported = should_use_bf16(inference_device)
|
||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
||||
return None
|
||||
|
||||
if fp16_supported and torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
|
||||
elif bf16_supported and torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
@ -543,8 +591,6 @@ def text_encoder_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
|
||||
if is_intel_xpu():
|
||||
return torch.device("cpu")
|
||||
if should_use_fp16(prioritize_performance=False):
|
||||
return get_torch_device()
|
||||
else:
|
||||
@ -585,9 +631,22 @@ def vae_offload_device():
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def vae_dtype():
|
||||
global VAE_DTYPE
|
||||
return VAE_DTYPE
|
||||
def vae_dtype(device=None, allowed_dtypes=[]):
|
||||
global VAE_DTYPES
|
||||
if args.fp16_vae:
|
||||
return torch.float16
|
||||
elif args.bf16_vae:
|
||||
return torch.bfloat16
|
||||
elif args.fp32_vae:
|
||||
return torch.float32
|
||||
|
||||
for d in allowed_dtypes:
|
||||
if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
|
||||
return d
|
||||
if d in VAE_DTYPES:
|
||||
return d
|
||||
|
||||
return VAE_DTYPES[0]
|
||||
|
||||
def get_autocast_device(dev):
|
||||
if hasattr(dev, 'type'):
|
||||
@ -605,11 +664,47 @@ def supports_dtype(device, dtype): #TODO
|
||||
return True
|
||||
return False
|
||||
|
||||
def supports_cast(device, dtype): #TODO
|
||||
if dtype == torch.float32:
|
||||
return True
|
||||
if dtype == torch.float16:
|
||||
return True
|
||||
if is_device_mps(device):
|
||||
return False
|
||||
if directml_enabled: #TODO: test this
|
||||
return False
|
||||
if dtype == torch.bfloat16:
|
||||
return True
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return True
|
||||
if dtype == torch.float8_e5m2:
|
||||
return True
|
||||
return False
|
||||
|
||||
def device_supports_non_blocking(device):
|
||||
if is_device_mps(device):
|
||||
return False #pytorch bug? mps doesn't support non blocking
|
||||
if is_intel_xpu():
|
||||
return False
|
||||
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
|
||||
return False
|
||||
if directml_enabled:
|
||||
return False
|
||||
return True
|
||||
|
||||
def device_should_use_non_blocking(device):
|
||||
if not device_supports_non_blocking(device):
|
||||
return False
|
||||
return False
|
||||
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
|
||||
|
||||
def force_channels_last():
|
||||
if args.force_channels_last:
|
||||
return True
|
||||
|
||||
#TODO
|
||||
return False
|
||||
|
||||
def cast_to_device(tensor, device, dtype, copy=False):
|
||||
device_supports_cast = False
|
||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
||||
@ -620,7 +715,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
elif is_intel_xpu():
|
||||
device_supports_cast = True
|
||||
|
||||
non_blocking = device_supports_non_blocking(device)
|
||||
non_blocking = device_should_use_non_blocking(device)
|
||||
|
||||
if device_supports_cast:
|
||||
if copy:
|
||||
@ -661,8 +756,22 @@ def pytorch_attention_flash_attention():
|
||||
#TODO: more reliable way of checking for flash attention?
|
||||
if is_nvidia(): #pytorch flash attention only works on Nvidia
|
||||
return True
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
return False
|
||||
|
||||
def force_upcast_attention_dtype():
|
||||
upcast = args.force_upcast_attention
|
||||
try:
|
||||
if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
|
||||
upcast = True
|
||||
except:
|
||||
pass
|
||||
if upcast:
|
||||
return torch.float32
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_free_memory(dev=None, torch_free_too=False):
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
@ -684,10 +793,10 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
elif is_intel_xpu():
|
||||
stats = torch.xpu.memory_stats(dev)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_allocated = stats['allocated_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated
|
||||
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
|
||||
mem_free_total = mem_free_xpu + mem_free_torch
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
@ -709,17 +818,20 @@ def mps_mode():
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.MPS
|
||||
|
||||
def is_device_cpu(device):
|
||||
def is_device_type(device, type):
|
||||
if hasattr(device, 'type'):
|
||||
if (device.type == 'cpu'):
|
||||
if (device.type == type):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_device_cpu(device):
|
||||
return is_device_type(device, 'cpu')
|
||||
|
||||
def is_device_mps(device):
|
||||
if hasattr(device, 'type'):
|
||||
if (device.type == 'mps'):
|
||||
return True
|
||||
return False
|
||||
return is_device_type(device, 'mps')
|
||||
|
||||
def is_device_cuda(device):
|
||||
return is_device_type(device, 'cuda')
|
||||
|
||||
|
||||
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
|
||||
@ -732,9 +844,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if FORCE_FP16:
|
||||
return True
|
||||
|
||||
if device is not None: #TODO
|
||||
if device is not None:
|
||||
if is_device_mps(device):
|
||||
return False
|
||||
return True
|
||||
|
||||
if FORCE_FP32:
|
||||
return False
|
||||
@ -742,8 +854,11 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if directml_enabled:
|
||||
return False
|
||||
|
||||
if cpu_mode() or mps_mode():
|
||||
return False #TODO ?
|
||||
if mps_mode():
|
||||
return True
|
||||
|
||||
if cpu_mode():
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
@ -762,7 +877,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
#FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
|
||||
#when the model doesn't actually fit on the card
|
||||
#TODO: actually test if GP106 and others have the same type of behavior
|
||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"]
|
||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
||||
for x in nvidia_10_series:
|
||||
if x in props.name.lower():
|
||||
fp16_works = True
|
||||
@ -783,6 +898,43 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
|
||||
return True
|
||||
|
||||
def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
|
||||
if device is not None:
|
||||
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
|
||||
return False
|
||||
|
||||
if device is not None: #TODO not sure about mps bf16 support
|
||||
if is_device_mps(device):
|
||||
return False
|
||||
|
||||
if FORCE_FP32:
|
||||
return False
|
||||
|
||||
if directml_enabled:
|
||||
return False
|
||||
|
||||
if cpu_mode() or mps_mode():
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
|
||||
if device is None:
|
||||
device = torch.device("cuda")
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 8:
|
||||
return True
|
||||
|
||||
bf16_works = torch.cuda.is_bf16_supported()
|
||||
|
||||
if bf16_works or manual_cast:
|
||||
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
|
||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
@ -799,6 +951,7 @@ def unload_all_models():
|
||||
|
||||
|
||||
def resolve_lowvram_weight(weight, model, key): #TODO: remove
|
||||
print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
|
||||
return weight
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
|
@ -1,9 +1,61 @@
|
||||
import torch
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy.types import UnetWrapperFunction
|
||||
|
||||
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
|
||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
|
||||
lora_diff *= alpha
|
||||
weight_calc = weight + lora_diff.type(weight.dtype)
|
||||
weight_norm = (
|
||||
weight_calc.transpose(0, 1)
|
||||
.reshape(weight_calc.shape[1], -1)
|
||||
.norm(dim=1, keepdim=True)
|
||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||
if strength != 1.0:
|
||||
weight_calc -= weight
|
||||
weight += strength * (weight_calc)
|
||||
else:
|
||||
weight[:] = weight_calc
|
||||
return weight
|
||||
|
||||
|
||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
to = model_options["transformer_options"].copy()
|
||||
|
||||
if "patches_replace" not in to:
|
||||
to["patches_replace"] = {}
|
||||
else:
|
||||
to["patches_replace"] = to["patches_replace"].copy()
|
||||
|
||||
if name not in to["patches_replace"]:
|
||||
to["patches_replace"][name] = {}
|
||||
else:
|
||||
to["patches_replace"][name] = to["patches_replace"][name].copy()
|
||||
|
||||
if transformer_index is not None:
|
||||
block = (block_name, number, transformer_index)
|
||||
else:
|
||||
block = (block_name, number)
|
||||
to["patches_replace"][name][block] = patch
|
||||
model_options["transformer_options"] = to
|
||||
return model_options
|
||||
|
||||
def set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=False):
|
||||
model_options["sampler_post_cfg_function"] = model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
|
||||
if disable_cfg1_optimization:
|
||||
model_options["disable_cfg1_optimization"] = True
|
||||
return model_options
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||
@ -23,13 +75,14 @@ class ModelPatcher:
|
||||
self.current_device = current_device
|
||||
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
self.model_lowvram = False
|
||||
self.lowvram_patch_counter = 0
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
model_sd = self.model.state_dict()
|
||||
self.size = comfy.model_management.module_size(self.model)
|
||||
self.model_keys = set(model_sd.keys())
|
||||
return self.size
|
||||
|
||||
def clone(self):
|
||||
@ -37,10 +90,12 @@ class ModelPatcher:
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
n.patches_uuid = self.patches_uuid
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.model_keys = self.model_keys
|
||||
n.backup = self.backup
|
||||
n.object_patches_backup = self.object_patches_backup
|
||||
return n
|
||||
|
||||
def is_clone(self, other):
|
||||
@ -48,6 +103,19 @@ class ModelPatcher:
|
||||
return True
|
||||
return False
|
||||
|
||||
def clone_has_same_weights(self, clone):
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
|
||||
if len(self.patches) == 0 and len(clone.patches) == 0:
|
||||
return True
|
||||
|
||||
if self.patches_uuid == clone.patches_uuid:
|
||||
if len(self.patches) != len(clone.patches):
|
||||
logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
|
||||
else:
|
||||
return True
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
return self.model.memory_required(input_shape=input_shape)
|
||||
|
||||
@ -60,13 +128,14 @@ class ModelPatcher:
|
||||
self.model_options["disable_cfg1_optimization"] = True
|
||||
|
||||
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
||||
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
|
||||
if disable_cfg1_optimization:
|
||||
self.model_options["disable_cfg1_optimization"] = True
|
||||
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
||||
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
|
||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||
|
||||
def set_model_denoise_mask_function(self, denoise_mask_function):
|
||||
self.model_options["denoise_mask_function"] = denoise_mask_function
|
||||
|
||||
def set_model_patch(self, patch, name):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" not in to:
|
||||
@ -74,16 +143,7 @@ class ModelPatcher:
|
||||
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
||||
|
||||
def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches_replace" not in to:
|
||||
to["patches_replace"] = {}
|
||||
if name not in to["patches_replace"]:
|
||||
to["patches_replace"][name] = {}
|
||||
if transformer_index is not None:
|
||||
block = (block_name, number, transformer_index)
|
||||
else:
|
||||
block = (block_name, number)
|
||||
to["patches_replace"][name][block] = patch
|
||||
self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index)
|
||||
|
||||
def set_model_attn1_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn1_patch")
|
||||
@ -115,6 +175,15 @@ class ModelPatcher:
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
def get_model_object(self, name):
|
||||
if name in self.object_patches:
|
||||
return self.object_patches[name]
|
||||
else:
|
||||
if name in self.object_patches_backup:
|
||||
return self.object_patches_backup[name]
|
||||
else:
|
||||
return comfy.utils.get_attr(self.model, name)
|
||||
|
||||
def model_patches_to(self, device):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
@ -142,13 +211,25 @@ class ModelPatcher:
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
p = set()
|
||||
model_sd = self.model.state_dict()
|
||||
for k in patches:
|
||||
if k in self.model_keys:
|
||||
p.add(k)
|
||||
current_patches = self.patches.get(k, [])
|
||||
current_patches.append((strength_patch, patches[k], strength_model))
|
||||
self.patches[k] = current_patches
|
||||
offset = None
|
||||
function = None
|
||||
if isinstance(k, str):
|
||||
key = k
|
||||
else:
|
||||
offset = k[1]
|
||||
key = k[0]
|
||||
if len(k) > 2:
|
||||
function = k[2]
|
||||
|
||||
if key in model_sd:
|
||||
p.add(k)
|
||||
current_patches = self.patches.get(key, [])
|
||||
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
|
||||
self.patches[key] = current_patches
|
||||
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
return list(p)
|
||||
|
||||
def get_key_patches(self, filter_prefix=None):
|
||||
@ -174,37 +255,41 @@ class ModelPatcher:
|
||||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None):
|
||||
if key not in self.patches:
|
||||
return
|
||||
|
||||
weight = comfy.utils.get_attr(self.model, key)
|
||||
|
||||
inplace_update = self.weight_inplace_update
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
||||
|
||||
if device_to is not None:
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
if inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
comfy.utils.set_attr_param(self.model, key, out_weight)
|
||||
|
||||
def patch_model(self, device_to=None, patch_weights=True):
|
||||
for k in self.object_patches:
|
||||
old = getattr(self.model, k)
|
||||
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
|
||||
if k not in self.object_patches_backup:
|
||||
self.object_patches_backup[k] = old
|
||||
setattr(self.model, k, self.object_patches[k])
|
||||
|
||||
if patch_weights:
|
||||
model_sd = self.model_state_dict()
|
||||
for key in self.patches:
|
||||
if key not in model_sd:
|
||||
print("could not patch. key doesn't exist in model:", key)
|
||||
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
|
||||
continue
|
||||
|
||||
weight = model_sd[key]
|
||||
|
||||
inplace_update = self.weight_inplace_update
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
||||
|
||||
if device_to is not None:
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
if inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
comfy.utils.set_attr(self.model, key, out_weight)
|
||||
del temp_weight
|
||||
self.patch_weight_to_device(key, device_to)
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
@ -212,11 +297,71 @@ class ModelPatcher:
|
||||
|
||||
return self.model
|
||||
|
||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||
self.patch_model(device_to, patch_weights=False)
|
||||
|
||||
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
||||
class LowVramPatch:
|
||||
def __init__(self, key, model_patcher):
|
||||
self.key = key
|
||||
self.model_patcher = model_patcher
|
||||
def __call__(self, weight):
|
||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
||||
|
||||
mem_counter = 0
|
||||
patch_counter = 0
|
||||
for n, m in self.model.named_modules():
|
||||
lowvram_weight = False
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
lowvram_weight = True
|
||||
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
|
||||
if lowvram_weight:
|
||||
if weight_key in self.patches:
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(weight_key)
|
||||
else:
|
||||
m.weight_function = LowVramPatch(weight_key, self)
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(bias_key)
|
||||
else:
|
||||
m.bias_function = LowVramPatch(bias_key, self)
|
||||
patch_counter += 1
|
||||
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
else:
|
||||
if hasattr(m, "weight"):
|
||||
self.patch_weight_to_device(weight_key, device_to)
|
||||
self.patch_weight_to_device(bias_key, device_to)
|
||||
m.to(device_to)
|
||||
mem_counter += comfy.model_management.module_size(m)
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
|
||||
self.model_lowvram = True
|
||||
self.lowvram_patch_counter = patch_counter
|
||||
return self.model
|
||||
|
||||
def calculate_weight(self, patches, weight, key):
|
||||
for p in patches:
|
||||
alpha = p[0]
|
||||
strength = p[0]
|
||||
v = p[1]
|
||||
strength_model = p[2]
|
||||
offset = p[3]
|
||||
function = p[4]
|
||||
if function is None:
|
||||
function = lambda a: a
|
||||
|
||||
old_weight = None
|
||||
if offset is not None:
|
||||
old_weight = weight
|
||||
weight = weight.narrow(offset[0], offset[1], offset[2])
|
||||
|
||||
if strength_model != 1.0:
|
||||
weight *= strength_model
|
||||
@ -232,25 +377,33 @@ class ModelPatcher:
|
||||
|
||||
if patch_type == "diff":
|
||||
w1 = v[0]
|
||||
if alpha != 0.0:
|
||||
if strength != 0.0:
|
||||
if w1.shape != weight.shape:
|
||||
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
else:
|
||||
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||
weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
|
||||
elif patch_type == "lora": #lora/locon
|
||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
||||
dora_scale = v[4]
|
||||
if v[2] is not None:
|
||||
alpha *= v[2] / mat2.shape[0]
|
||||
alpha = v[2] / mat2.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
if v[3] is not None:
|
||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||
try:
|
||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "lokr":
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
@ -259,6 +412,7 @@ class ModelPatcher:
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
dora_scale = v[8]
|
||||
dim = None
|
||||
|
||||
if w1 is None:
|
||||
@ -284,19 +438,29 @@ class ModelPatcher:
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
if v[2] is not None and dim is not None:
|
||||
alpha *= v[2] / dim
|
||||
alpha = v[2] / dim
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
try:
|
||||
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "loha":
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
if v[2] is not None:
|
||||
alpha *= v[2] / w1b.shape[0]
|
||||
alpha = v[2] / w1b.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
dora_scale = v[7]
|
||||
if v[5] is not None: #cp decomposition
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
@ -316,42 +480,72 @@ class ModelPatcher:
|
||||
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
|
||||
|
||||
try:
|
||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "glora":
|
||||
if v[4] is not None:
|
||||
alpha *= v[4] / v[0].shape[0]
|
||||
alpha = v[4] / v[0].shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
dora_scale = v[5]
|
||||
|
||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
|
||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
|
||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
|
||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
||||
|
||||
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
|
||||
try:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
else:
|
||||
print("patch type not recognized", patch_type, key)
|
||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||
|
||||
if old_weight is not None:
|
||||
weight = old_weight
|
||||
|
||||
return weight
|
||||
|
||||
def unpatch_model(self, device_to=None):
|
||||
keys = list(self.backup.keys())
|
||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||
if unpatch_weights:
|
||||
if self.model_lowvram:
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, "prev_comfy_cast_weights"):
|
||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
||||
del m.prev_comfy_cast_weights
|
||||
m.weight_function = None
|
||||
m.bias_function = None
|
||||
|
||||
if self.weight_inplace_update:
|
||||
for k in keys:
|
||||
comfy.utils.copy_to_param(self.model, k, self.backup[k])
|
||||
else:
|
||||
for k in keys:
|
||||
comfy.utils.set_attr(self.model, k, self.backup[k])
|
||||
self.model_lowvram = False
|
||||
self.lowvram_patch_counter = 0
|
||||
|
||||
self.backup = {}
|
||||
keys = list(self.backup.keys())
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
if self.weight_inplace_update:
|
||||
for k in keys:
|
||||
comfy.utils.copy_to_param(self.model, k, self.backup[k])
|
||||
else:
|
||||
for k in keys:
|
||||
comfy.utils.set_attr_param(self.model, k, self.backup[k])
|
||||
|
||||
self.backup.clear()
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
|
||||
keys = list(self.object_patches_backup.keys())
|
||||
for k in keys:
|
||||
setattr(self.model, k, self.object_patches_backup[k])
|
||||
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
|
||||
|
||||
self.object_patches_backup = {}
|
||||
self.object_patches_backup.clear()
|
||||
|
@ -11,12 +11,41 @@ class EPS:
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
if max_denoise:
|
||||
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
|
||||
else:
|
||||
noise = noise * sigma
|
||||
|
||||
noise += latent_image
|
||||
return noise
|
||||
|
||||
def inverse_noise_scaling(self, sigma, latent):
|
||||
return latent
|
||||
|
||||
class V_PREDICTION(EPS):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
class EDM(V_PREDICTION):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
class CONST:
|
||||
def calculate_input(self, sigma, noise):
|
||||
return noise
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
return sigma * noise + (1.0 - sigma) * latent_image
|
||||
|
||||
def inverse_noise_scaling(self, sigma, latent):
|
||||
return latent / (1.0 - sigma)
|
||||
|
||||
class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
@ -88,12 +117,16 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
percent = 1.0 - percent
|
||||
return self.sigma(torch.tensor(percent * 999.0)).item()
|
||||
|
||||
class ModelSamplingDiscreteEDM(ModelSamplingDiscrete):
|
||||
def timestep(self, sigma):
|
||||
return 0.25 * sigma.log()
|
||||
|
||||
def sigma(self, timestep):
|
||||
return (timestep / 0.25).exp()
|
||||
|
||||
class ModelSamplingContinuousEDM(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__()
|
||||
self.sigma_data = 1.0
|
||||
|
||||
if model_config is not None:
|
||||
sampling_settings = model_config.sampling_settings
|
||||
else:
|
||||
@ -101,9 +134,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
|
||||
|
||||
sigma_min = sampling_settings.get("sigma_min", 0.002)
|
||||
sigma_max = sampling_settings.get("sigma_max", 120.0)
|
||||
self.set_sigma_range(sigma_min, sigma_max)
|
||||
sigma_data = sampling_settings.get("sigma_data", 1.0)
|
||||
self.set_parameters(sigma_min, sigma_max, sigma_data)
|
||||
|
||||
def set_sigma_range(self, sigma_min, sigma_max):
|
||||
def set_parameters(self, sigma_min, sigma_max, sigma_data):
|
||||
self.sigma_data = sigma_data
|
||||
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
|
||||
|
||||
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
|
||||
@ -132,3 +167,107 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
|
||||
|
||||
log_sigma_min = math.log(self.sigma_min)
|
||||
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
|
||||
|
||||
|
||||
class ModelSamplingContinuousV(ModelSamplingContinuousEDM):
|
||||
def timestep(self, sigma):
|
||||
return sigma.atan() / math.pi * 2
|
||||
|
||||
def sigma(self, timestep):
|
||||
return (timestep * math.pi / 2).tan()
|
||||
|
||||
|
||||
def time_snr_shift(alpha, t):
|
||||
if alpha == 1.0:
|
||||
return t
|
||||
return alpha * t / (1 + (alpha - 1) * t)
|
||||
|
||||
class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__()
|
||||
if model_config is not None:
|
||||
sampling_settings = model_config.sampling_settings
|
||||
else:
|
||||
sampling_settings = {}
|
||||
|
||||
self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000))
|
||||
|
||||
def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000):
|
||||
self.shift = shift
|
||||
self.multiplier = multiplier
|
||||
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier)
|
||||
self.register_buffer('sigmas', ts)
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
return sigma * self.multiplier
|
||||
|
||||
def sigma(self, timestep):
|
||||
return time_snr_shift(self.shift, timestep / self.multiplier)
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
return 1.0
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
return 1.0 - percent
|
||||
|
||||
class StableCascadeSampling(ModelSamplingDiscrete):
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__()
|
||||
|
||||
if model_config is not None:
|
||||
sampling_settings = model_config.sampling_settings
|
||||
else:
|
||||
sampling_settings = {}
|
||||
|
||||
self.set_parameters(sampling_settings.get("shift", 1.0))
|
||||
|
||||
def set_parameters(self, shift=1.0, cosine_s=8e-3):
|
||||
self.shift = shift
|
||||
self.cosine_s = torch.tensor(cosine_s)
|
||||
self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
|
||||
|
||||
#This part is just for compatibility with some schedulers in the codebase
|
||||
self.num_timesteps = 10000
|
||||
sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
|
||||
for x in range(self.num_timesteps):
|
||||
t = (x + 1) / self.num_timesteps
|
||||
sigmas[x] = self.sigma(t)
|
||||
|
||||
self.set_sigmas(sigmas)
|
||||
|
||||
def sigma(self, timestep):
|
||||
alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod)
|
||||
|
||||
if self.shift != 1.0:
|
||||
var = alpha_cumprod
|
||||
logSNR = (var/(1-var)).log()
|
||||
logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift))
|
||||
alpha_cumprod = logSNR.sigmoid()
|
||||
|
||||
alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999)
|
||||
return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
|
||||
|
||||
def timestep(self, sigma):
|
||||
var = 1 / ((sigma * sigma) + 1)
|
||||
var = var.clamp(0, 1.0)
|
||||
s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
|
||||
t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
|
||||
return t
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
return 999999999.9
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
|
||||
percent = 1.0 - percent
|
||||
return self.sigma(torch.tensor(percent))
|
||||
|
114
comfy/ops.py
114
comfy/ops.py
@ -1,18 +1,43 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
|
||||
def cast_bias_weight(s, input):
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
||||
non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
if s.bias_function is not None:
|
||||
bias = s.bias_function(bias)
|
||||
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
if s.weight_function is not None:
|
||||
weight = s.weight_function(weight)
|
||||
return weight, bias
|
||||
|
||||
class CastWeightBiasOp:
|
||||
comfy_cast_weights = False
|
||||
weight_function = None
|
||||
bias_function = None
|
||||
|
||||
class disable_weight_init:
|
||||
class Linear(torch.nn.Linear):
|
||||
comfy_cast_weights = False
|
||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
@ -26,8 +51,7 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
comfy_cast_weights = False
|
||||
class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
@ -41,8 +65,7 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
comfy_cast_weights = False
|
||||
class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
@ -56,8 +79,21 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class GroupNorm(torch.nn.GroupNorm):
|
||||
comfy_cast_weights = False
|
||||
class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
@ -72,13 +108,16 @@ class disable_weight_init:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
comfy_cast_weights = False
|
||||
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
if self.weight is not None:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
else:
|
||||
weight = None
|
||||
bias = None
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@ -87,6 +126,48 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input, output_size=None):
|
||||
num_spatial_dims = 2
|
||||
output_padding = self._output_padding(
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose2d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input, output_size=None):
|
||||
num_spatial_dims = 1
|
||||
output_padding = self._output_padding(
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose1d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def conv_nd(s, dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
@ -101,6 +182,9 @@ class manual_cast(disable_weight_init):
|
||||
class Linear(disable_weight_init.Linear):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class Conv1d(disable_weight_init.Conv1d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class Conv2d(disable_weight_init.Conv2d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
@ -112,3 +196,9 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
class LayerNorm(disable_weight_init.LayerNorm):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
|
||||
comfy_cast_weights = True
|
||||
|
22
comfy/sa_t5.py
Normal file
22
comfy/sa_t5.py
Normal file
@ -0,0 +1,22 @@
|
||||
from comfy import sd1_clip
|
||||
from transformers import T5TokenizerFast
|
||||
import comfy.t5
|
||||
import os
|
||||
|
||||
class T5BaseModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
|
||||
class T5BaseTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
|
||||
|
||||
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="t5base", tokenizer=T5BaseTokenizer)
|
||||
|
||||
class SAT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs)
|
@ -1,10 +1,9 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
import comfy.conds
|
||||
import comfy.utils
|
||||
import math
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
@ -25,94 +24,27 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
noises = torch.cat(noises, axis=0)
|
||||
return noises
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
"""ensures noise mask is of proper dimensions"""
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
|
||||
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
|
||||
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
return noise_mask
|
||||
|
||||
def get_models_from_cond(cond, model_type):
|
||||
models = []
|
||||
for c in cond:
|
||||
if model_type in c:
|
||||
models += [c[model_type]]
|
||||
return models
|
||||
|
||||
def convert_cond(cond):
|
||||
out = []
|
||||
for c in cond:
|
||||
temp = c[1].copy()
|
||||
model_conds = temp.get("model_conds", {})
|
||||
if c[0] is not None:
|
||||
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
|
||||
temp["cross_attn"] = c[0]
|
||||
temp["model_conds"] = model_conds
|
||||
out.append(temp)
|
||||
return out
|
||||
|
||||
def get_additional_models(positive, negative, dtype):
|
||||
"""loads additional models in positive and negative conditioning"""
|
||||
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
|
||||
|
||||
inference_memory = 0
|
||||
control_models = []
|
||||
for m in control_nets:
|
||||
control_models += m.get_models()
|
||||
inference_memory += m.inference_memory_requirements(dtype)
|
||||
|
||||
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
|
||||
gligen = [x[1] for x in gligen]
|
||||
models = control_models + gligen
|
||||
return models, inference_memory
|
||||
|
||||
def cleanup_additional_models(models):
|
||||
"""cleanup additional models that were loaded"""
|
||||
for m in models:
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
def fix_empty_latent_channels(model, latent_image):
|
||||
latent_channels = model.get_model_object("latent_format").latent_channels #Resize the empty latent image so it has the right number of channels
|
||||
if latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_channels, dim=1)
|
||||
return latent_image
|
||||
|
||||
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
||||
device = model.load_device
|
||||
positive = convert_cond(positive)
|
||||
negative = convert_cond(negative)
|
||||
|
||||
if noise_mask is not None:
|
||||
noise_mask = prepare_mask(noise_mask, noise_shape, device)
|
||||
|
||||
real_model = None
|
||||
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
|
||||
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, positive, negative, noise_mask, models
|
||||
logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
|
||||
return model, positive, negative, noise_mask, []
|
||||
|
||||
def cleanup_additional_models(models):
|
||||
logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed")
|
||||
|
||||
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
|
||||
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
noise = noise.to(model.load_device)
|
||||
latent_image = latent_image.to(model.load_device)
|
||||
|
||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
|
||||
cleanup_additional_models(models)
|
||||
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
|
||||
return samples
|
||||
|
||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
|
||||
noise = noise.to(model.load_device)
|
||||
latent_image = latent_image.to(model.load_device)
|
||||
sigmas = sigmas.to(model.load_device)
|
||||
|
||||
samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
cleanup_additional_models(models)
|
||||
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
|
||||
return samples
|
||||
|
||||
|
76
comfy/sampler_helpers.py
Normal file
76
comfy/sampler_helpers.py
Normal file
@ -0,0 +1,76 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
"""ensures noise mask is of proper dimensions"""
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
|
||||
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
|
||||
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
return noise_mask
|
||||
|
||||
def get_models_from_cond(cond, model_type):
|
||||
models = []
|
||||
for c in cond:
|
||||
if model_type in c:
|
||||
models += [c[model_type]]
|
||||
return models
|
||||
|
||||
def convert_cond(cond):
|
||||
out = []
|
||||
for c in cond:
|
||||
temp = c[1].copy()
|
||||
model_conds = temp.get("model_conds", {})
|
||||
if c[0] is not None:
|
||||
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
|
||||
temp["cross_attn"] = c[0]
|
||||
temp["model_conds"] = model_conds
|
||||
out.append(temp)
|
||||
return out
|
||||
|
||||
def get_additional_models(conds, dtype):
|
||||
"""loads additional models in conditioning"""
|
||||
cnets = []
|
||||
gligen = []
|
||||
|
||||
for k in conds:
|
||||
cnets += get_models_from_cond(conds[k], "control")
|
||||
gligen += get_models_from_cond(conds[k], "gligen")
|
||||
|
||||
control_nets = set(cnets)
|
||||
|
||||
inference_memory = 0
|
||||
control_models = []
|
||||
for m in control_nets:
|
||||
control_models += m.get_models()
|
||||
inference_memory += m.inference_memory_requirements(dtype)
|
||||
|
||||
gligen = [x[1] for x in gligen]
|
||||
models = control_models + gligen
|
||||
return models, inference_memory
|
||||
|
||||
def cleanup_additional_models(models):
|
||||
"""cleanup additional models that were loaded"""
|
||||
for m in models:
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
|
||||
def prepare_sampling(model, noise_shape, conds):
|
||||
device = model.load_device
|
||||
real_model = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
def cleanup_models(conds, models):
|
||||
cleanup_additional_models(models)
|
||||
|
||||
control_cleanup = []
|
||||
for k in conds:
|
||||
control_cleanup += get_models_from_cond(conds[k], "control")
|
||||
|
||||
cleanup_additional_models(set(control_cleanup))
|
@ -4,9 +4,12 @@ import torch
|
||||
import collections
|
||||
from comfy import model_management
|
||||
import math
|
||||
import logging
|
||||
import comfy.sampler_helpers
|
||||
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
dims = tuple(x_in.shape[2:])
|
||||
area = None
|
||||
strength = 1.0
|
||||
|
||||
if 'timestep_start' in conds:
|
||||
@ -18,11 +21,16 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
if timestep_in[0] < timestep_end:
|
||||
return None
|
||||
if 'area' in conds:
|
||||
area = conds['area']
|
||||
area = list(conds['area'])
|
||||
if 'strength' in conds:
|
||||
strength = conds['strength']
|
||||
|
||||
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
input_x = x_in
|
||||
if area is not None:
|
||||
for i in range(len(dims)):
|
||||
area[i] = min(input_x.shape[i + 2] - area[len(dims) + i], area[i])
|
||||
input_x = input_x.narrow(i + 2, area[len(dims) + i], area[i])
|
||||
|
||||
if 'mask' in conds:
|
||||
# Scale the mask to the size of the input
|
||||
# The mask should have been resized as we began the sampling process
|
||||
@ -30,28 +38,30 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
if "mask_strength" in conds:
|
||||
mask_strength = conds["mask_strength"]
|
||||
mask = conds['mask']
|
||||
assert(mask.shape[1] == x_in.shape[2])
|
||||
assert(mask.shape[2] == x_in.shape[3])
|
||||
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
||||
assert(mask.shape[1:] == x_in.shape[2:])
|
||||
|
||||
mask = mask[:input_x.shape[0]]
|
||||
if area is not None:
|
||||
for i in range(len(dims)):
|
||||
mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
|
||||
|
||||
mask = mask * mask_strength
|
||||
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
||||
else:
|
||||
mask = torch.ones_like(input_x)
|
||||
mult = mask * strength
|
||||
|
||||
if 'mask' not in conds:
|
||||
if 'mask' not in conds and area is not None:
|
||||
rr = 8
|
||||
if area[2] != 0:
|
||||
for t in range(rr):
|
||||
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
|
||||
if (area[0] + area[2]) < x_in.shape[2]:
|
||||
for t in range(rr):
|
||||
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
|
||||
if area[3] != 0:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
|
||||
if (area[1] + area[3]) < x_in.shape[3]:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||
for i in range(len(dims)):
|
||||
if area[len(dims) + i] != 0:
|
||||
for t in range(rr):
|
||||
m = mult.narrow(i + 2, t, 1)
|
||||
m *= ((1.0/rr) * (t + 1))
|
||||
if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
|
||||
for t in range(rr):
|
||||
m = mult.narrow(i + 2, area[i] - 1 - t, 1)
|
||||
m *= ((1.0/rr) * (t + 1))
|
||||
|
||||
conditioning = {}
|
||||
model_conds = conds["model_conds"]
|
||||
@ -126,30 +136,23 @@ def cond_cat(c_list):
|
||||
|
||||
return out
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
out_cond = torch.zeros_like(x_in)
|
||||
out_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
out_uncond = torch.zeros_like(x_in)
|
||||
out_uncond_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
COND = 0
|
||||
UNCOND = 1
|
||||
|
||||
def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
to_run = []
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, COND)]
|
||||
if uncond is not None:
|
||||
for x in uncond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
for i in range(len(conds)):
|
||||
out_conds.append(torch.zeros_like(x_in))
|
||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||
|
||||
to_run += [(p, UNCOND)]
|
||||
cond = conds[i]
|
||||
if cond is not None:
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, i)]
|
||||
|
||||
while len(to_run) > 0:
|
||||
first = to_run[0]
|
||||
@ -208,6 +211,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
transformer_options["patches"] = cur_patches
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
@ -220,71 +224,77 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
del input_x
|
||||
|
||||
for o in range(batch_chunks):
|
||||
if cond_or_uncond[o] == COND:
|
||||
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||
del mult
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
|
||||
out_cond /= out_count
|
||||
del out_count
|
||||
out_uncond /= out_uncond_count
|
||||
del out_uncond_count
|
||||
return out_cond, out_uncond
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
|
||||
return out_conds
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||
|
||||
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
|
||||
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
|
||||
cfg_result = x - model_options["sampler_cfg_function"](args)
|
||||
else:
|
||||
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||
|
||||
for fn in model_options.get("sampler_post_cfg_function", []):
|
||||
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
||||
"sigma": timestep, "model_options": model_options, "input": x}
|
||||
cfg_result = fn(args)
|
||||
|
||||
return cfg_result
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns denoised
|
||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
|
||||
uncond_ = None
|
||||
else:
|
||||
uncond_ = uncond
|
||||
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
|
||||
uncond_ = None
|
||||
else:
|
||||
uncond_ = uncond
|
||||
|
||||
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
|
||||
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
|
||||
cfg_result = x - model_options["sampler_cfg_function"](args)
|
||||
else:
|
||||
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||
conds = [cond, uncond_]
|
||||
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
||||
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
|
||||
|
||||
for fn in model_options.get("sampler_post_cfg_function", []):
|
||||
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
||||
"sigma": timestep, "model_options": model_options, "input": x}
|
||||
cfg_result = fn(args)
|
||||
|
||||
return cfg_result
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
class KSamplerX0Inpaint:
|
||||
def __init__(self, model, sigmas):
|
||||
self.inner_model = model
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
||||
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||
return out
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
|
||||
class KSamplerX0Inpaint(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
|
||||
self.sigmas = sigmas
|
||||
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
|
||||
if denoise_mask is not None:
|
||||
if "denoise_mask_function" in model_options:
|
||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
|
||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
|
||||
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
|
||||
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
|
||||
if denoise_mask is not None:
|
||||
out = out * denoise_mask + self.latent_image * latent_mask
|
||||
return out
|
||||
|
||||
def simple_scheduler(model, steps):
|
||||
s = model.model_sampling
|
||||
def simple_scheduler(model_sampling, steps):
|
||||
s = model_sampling
|
||||
sigs = []
|
||||
ss = len(s.sigmas) / steps
|
||||
for x in range(steps):
|
||||
@ -292,10 +302,10 @@ def simple_scheduler(model, steps):
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def ddim_scheduler(model, steps):
|
||||
s = model.model_sampling
|
||||
def ddim_scheduler(model_sampling, steps):
|
||||
s = model_sampling
|
||||
sigs = []
|
||||
ss = len(s.sigmas) // steps
|
||||
ss = max(len(s.sigmas) // steps, 1)
|
||||
x = 1
|
||||
while x < len(s.sigmas):
|
||||
sigs += [float(s.sigmas[x])]
|
||||
@ -304,8 +314,8 @@ def ddim_scheduler(model, steps):
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def normal_scheduler(model, steps, sgm=False, floor=False):
|
||||
s = model.model_sampling
|
||||
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
|
||||
s = model_sampling
|
||||
start = s.timestep(s.sigma_max)
|
||||
end = s.timestep(s.sigma_min)
|
||||
|
||||
@ -344,7 +354,7 @@ def get_mask_aabb(masks):
|
||||
|
||||
return bounding_boxes, is_empty
|
||||
|
||||
def resolve_areas_and_cond_masks(conditions, h, w, device):
|
||||
def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
|
||||
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
|
||||
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
|
||||
for i in range(len(conditions)):
|
||||
@ -353,7 +363,14 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
||||
area = c['area']
|
||||
if area[0] == "percentage":
|
||||
modified = c.copy()
|
||||
area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w))
|
||||
a = area[1:]
|
||||
a_len = len(a) // 2
|
||||
area = ()
|
||||
for d in range(len(dims)):
|
||||
area += (max(1, round(a[d] * dims[d])),)
|
||||
for d in range(len(dims)):
|
||||
area += (round(a[d + a_len] * dims[d]),)
|
||||
|
||||
modified['area'] = area
|
||||
c = modified
|
||||
conditions[i] = c
|
||||
@ -362,12 +379,12 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
||||
mask = c['mask']
|
||||
mask = mask.to(device=device)
|
||||
modified = c.copy()
|
||||
if len(mask.shape) == 2:
|
||||
if len(mask.shape) == len(dims):
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1] != h or mask.shape[2] != w:
|
||||
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
|
||||
if mask.shape[1:] != dims:
|
||||
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
|
||||
|
||||
if modified.get("set_area_to_bounds", False):
|
||||
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
|
||||
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
|
||||
boxes, is_empty = get_mask_aabb(bounds)
|
||||
if is_empty[0]:
|
||||
@ -384,7 +401,11 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
||||
modified['mask'] = mask
|
||||
conditions[i] = modified
|
||||
|
||||
def create_cond_with_same_area_if_none(conds, c):
|
||||
def resolve_areas_and_cond_masks(conditions, h, w, device):
|
||||
logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
|
||||
return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)
|
||||
|
||||
def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2
|
||||
if 'area' not in c:
|
||||
return
|
||||
|
||||
@ -488,7 +509,10 @@ def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwar
|
||||
params = x.copy()
|
||||
params["device"] = device
|
||||
params["noise"] = noise
|
||||
params["width"] = params.get("width", noise.shape[3] * 8)
|
||||
default_width = None
|
||||
if len(noise.shape) >= 4: #TODO: 8 multiple should be set by the model
|
||||
default_width = noise.shape[3] * 8
|
||||
params["width"] = params.get("width", default_width)
|
||||
params["height"] = params.get("height", noise.shape[2] * 8)
|
||||
params["prompt_type"] = params.get("prompt_type", prompt_type)
|
||||
for k in kwargs:
|
||||
@ -513,17 +537,10 @@ class Sampler:
|
||||
sigma = float(sigmas[0])
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
class UNIPC(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||
|
||||
class UNIPCBH2(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||
@ -533,7 +550,7 @@ class KSAMPLER(Sampler):
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
model_k = KSamplerX0Inpaint(model_wrap)
|
||||
model_k = KSamplerX0Inpaint(model_wrap, sigmas)
|
||||
model_k.latent_image = latent_image
|
||||
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
|
||||
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
|
||||
@ -541,26 +558,24 @@ class KSAMPLER(Sampler):
|
||||
else:
|
||||
model_k.noise = noise
|
||||
|
||||
if self.max_denoise(model_wrap, sigmas):
|
||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
else:
|
||||
noise = noise * sigmas[0]
|
||||
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
|
||||
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
|
||||
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
|
||||
return samples
|
||||
|
||||
|
||||
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||
if sampler_name == "dpm_fast":
|
||||
def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
|
||||
if len(sigmas) <= 1:
|
||||
return noise
|
||||
|
||||
sigma_min = sigmas[-1]
|
||||
if sigma_min == 0:
|
||||
sigma_min = sigmas[-2]
|
||||
@ -568,81 +583,145 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||
return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
|
||||
sampler_function = dpm_fast_function
|
||||
elif sampler_name == "dpm_adaptive":
|
||||
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
|
||||
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
|
||||
if len(sigmas) <= 1:
|
||||
return noise
|
||||
|
||||
sigma_min = sigmas[-1]
|
||||
if sigma_min == 0:
|
||||
sigma_min = sigmas[-2]
|
||||
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
|
||||
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options)
|
||||
sampler_function = dpm_adaptive_function
|
||||
else:
|
||||
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
|
||||
|
||||
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
||||
|
||||
def wrap_model(model):
|
||||
model_denoise = CFGNoisePredictor(model)
|
||||
return model_denoise
|
||||
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
positive = positive[:]
|
||||
negative = negative[:]
|
||||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
|
||||
for k in conds:
|
||||
conds[k] = conds[k][:]
|
||||
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
|
||||
|
||||
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
|
||||
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
|
||||
|
||||
model_wrap = wrap_model(model)
|
||||
|
||||
calculate_start_end_timesteps(model, negative)
|
||||
calculate_start_end_timesteps(model, positive)
|
||||
|
||||
if latent_image is not None:
|
||||
latent_image = model.process_latent_in(latent_image)
|
||||
for k in conds:
|
||||
calculate_start_end_timesteps(model, conds[k])
|
||||
|
||||
if hasattr(model, 'extra_conds'):
|
||||
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
for k in conds:
|
||||
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for c in positive:
|
||||
create_cond_with_same_area_if_none(negative, c)
|
||||
for c in negative:
|
||||
create_cond_with_same_area_if_none(positive, c)
|
||||
for k in conds:
|
||||
for c in conds[k]:
|
||||
for kk in conds:
|
||||
if k != kk:
|
||||
create_cond_with_same_area_if_none(conds[kk], c)
|
||||
|
||||
pre_run_control(model, negative + positive)
|
||||
for k in conds:
|
||||
pre_run_control(model, conds[k])
|
||||
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
if "positive" in conds:
|
||||
positive = conds["positive"]
|
||||
for k in conds:
|
||||
if k != "positive":
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||
return conds
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher):
|
||||
self.model_patcher = model_patcher
|
||||
self.model_options = model_patcher.model_options
|
||||
self.original_conds = {}
|
||||
self.cfg = 1.0
|
||||
|
||||
def set_conds(self, positive, negative):
|
||||
self.inner_set_conds({"positive": positive, "negative": negative})
|
||||
|
||||
def set_cfg(self, cfg):
|
||||
self.cfg = cfg
|
||||
|
||||
def inner_set_conds(self, conds):
|
||||
for k in conds:
|
||||
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.predict_noise(*args, **kwargs)
|
||||
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
|
||||
|
||||
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
|
||||
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
|
||||
latent_image = self.inner_model.process_latent_in(latent_image)
|
||||
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
||||
|
||||
extra_args = {"model_options": self.model_options, "seed":seed}
|
||||
|
||||
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
if sigmas.shape[-1] == 0:
|
||||
return latent_image
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
if denoise_mask is not None:
|
||||
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
|
||||
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
sigmas = sigmas.to(device)
|
||||
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
del self.conds
|
||||
del self.loaded_models
|
||||
return output
|
||||
|
||||
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
cfg_guider = CFGGuider(model)
|
||||
cfg_guider.set_conds(positive, negative)
|
||||
cfg_guider.set_cfg(cfg)
|
||||
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
|
||||
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
||||
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
if scheduler_name == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||
elif scheduler_name == "exponential":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||
elif scheduler_name == "normal":
|
||||
sigmas = normal_scheduler(model, steps)
|
||||
sigmas = normal_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "simple":
|
||||
sigmas = simple_scheduler(model, steps)
|
||||
sigmas = simple_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "ddim_uniform":
|
||||
sigmas = ddim_scheduler(model, steps)
|
||||
sigmas = ddim_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "sgm_uniform":
|
||||
sigmas = normal_scheduler(model, steps, sgm=True)
|
||||
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
||||
else:
|
||||
print("error invalid scheduler", scheduler_name)
|
||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
||||
return sigmas
|
||||
|
||||
def sampler_object(name):
|
||||
if name == "uni_pc":
|
||||
sampler = UNIPC()
|
||||
sampler = KSAMPLER(uni_pc.sample_unipc)
|
||||
elif name == "uni_pc_bh2":
|
||||
sampler = UNIPCBH2()
|
||||
sampler = KSAMPLER(uni_pc.sample_unipc_bh2)
|
||||
elif name == "ddim":
|
||||
sampler = ksampler("euler", inpaint_options={"random": True})
|
||||
else:
|
||||
@ -652,6 +731,7 @@ def sampler_object(name):
|
||||
class KSampler:
|
||||
SCHEDULERS = SCHEDULER_NAMES
|
||||
SAMPLERS = SAMPLER_NAMES
|
||||
DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2'))
|
||||
|
||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
||||
self.model = model
|
||||
@ -670,11 +750,11 @@ class KSampler:
|
||||
sigmas = None
|
||||
|
||||
discard_penultimate_sigma = False
|
||||
if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']:
|
||||
if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS:
|
||||
steps += 1
|
||||
discard_penultimate_sigma = True
|
||||
|
||||
sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps)
|
||||
sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps)
|
||||
|
||||
if discard_penultimate_sigma:
|
||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||
@ -685,9 +765,12 @@ class KSampler:
|
||||
if denoise is None or denoise > 0.9999:
|
||||
self.sigmas = self.calculate_sigmas(steps).to(self.device)
|
||||
else:
|
||||
new_steps = int(steps/denoise)
|
||||
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
||||
self.sigmas = sigmas[-(steps + 1):]
|
||||
if denoise <= 0.0:
|
||||
self.sigmas = torch.FloatTensor([])
|
||||
else:
|
||||
new_steps = int(steps/denoise)
|
||||
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
||||
self.sigmas = sigmas[-(steps + 1):]
|
||||
|
||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||
if sigmas is None:
|
||||
|
384
comfy/sd.py
384
comfy/sd.py
@ -1,7 +1,12 @@
|
||||
import torch
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
from comfy import model_management
|
||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
from .ldm.cascade.stage_a import StageA
|
||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||
import yaml
|
||||
|
||||
import comfy.utils
|
||||
@ -9,12 +14,13 @@ import comfy.utils
|
||||
from . import clip_vision
|
||||
from . import gligen
|
||||
from . import diffusers_convert
|
||||
from . import model_base
|
||||
from . import model_detection
|
||||
|
||||
from . import sd1_clip
|
||||
from . import sd2_clip
|
||||
from . import sdxl_clip
|
||||
from . import sd3_clip
|
||||
from . import sa_t5
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -33,7 +39,7 @@ def load_model_weights(model, sd):
|
||||
w = sd.pop(x)
|
||||
del w
|
||||
if len(m) > 0:
|
||||
print("missing", m)
|
||||
logging.warning("missing {}".format(m))
|
||||
return model
|
||||
|
||||
def load_clip_weights(model, sd):
|
||||
@ -48,7 +54,7 @@ def load_clip_weights(model, sd):
|
||||
if ids.dtype == torch.float32:
|
||||
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
||||
|
||||
sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
|
||||
sd = comfy.utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.")
|
||||
return load_model_weights(model, sd)
|
||||
|
||||
|
||||
@ -77,7 +83,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
k1 = set(k1)
|
||||
for x in loaded:
|
||||
if (x not in k) and (x not in k1):
|
||||
print("NOT LOADED", x)
|
||||
logging.warning("NOT LOADED {}".format(x))
|
||||
|
||||
return (new_modelpatcher, new_clip)
|
||||
|
||||
@ -93,13 +99,19 @@ class CLIP:
|
||||
load_device = model_management.text_encoder_device()
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
params['device'] = offload_device
|
||||
params['dtype'] = model_management.text_encoder_dtype(load_device)
|
||||
dtype = model_management.text_encoder_dtype(load_device)
|
||||
params['dtype'] = dtype
|
||||
|
||||
self.cond_stage_model = clip(**(params))
|
||||
|
||||
for dt in self.cond_stage_model.dtypes:
|
||||
if not model_management.supports_cast(load_device, dt):
|
||||
load_device = offload_device
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
self.layer_idx = None
|
||||
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
|
||||
|
||||
def clone(self):
|
||||
n = CLIP(no_init=True)
|
||||
@ -119,10 +131,13 @@ class CLIP:
|
||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
||||
|
||||
def encode_from_tokens(self, tokens, return_pooled=False):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
if self.layer_idx is not None:
|
||||
self.cond_stage_model.clip_layer(self.layer_idx)
|
||||
else:
|
||||
self.cond_stage_model.reset_clip_layer()
|
||||
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
|
||||
|
||||
if return_pooled == "unprojected":
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
|
||||
@ -134,8 +149,11 @@ class CLIP:
|
||||
tokens = self.tokenize(text)
|
||||
return self.encode_from_tokens(tokens)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
def load_sd(self, sd, full_model=False):
|
||||
if full_model:
|
||||
return self.cond_stage_model.load_state_dict(sd, strict=False)
|
||||
else:
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
|
||||
def get_sd(self):
|
||||
return self.cond_stage_model.state_dict()
|
||||
@ -155,7 +173,12 @@ class VAE:
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
||||
self.downscale_ratio = 8
|
||||
self.upscale_ratio = 8
|
||||
self.latent_channels = 4
|
||||
self.output_channels = 3
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
if config is None:
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
@ -167,38 +190,99 @@ class VAE:
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
||||
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
||||
elif "taesd_decoder.1.weight" in sd:
|
||||
self.first_stage_model = comfy.taesd.taesd.TAESD()
|
||||
else:
|
||||
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
|
||||
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
|
||||
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
||||
self.first_stage_model = StageA()
|
||||
self.downscale_ratio = 4
|
||||
self.upscale_ratio = 4
|
||||
#TODO
|
||||
#self.memory_used_encode
|
||||
#self.memory_used_decode
|
||||
self.process_input = lambda image: image
|
||||
self.process_output = lambda image: image
|
||||
elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
new_sd = {}
|
||||
for k in sd:
|
||||
new_sd["encoder.{}".format(k)] = sd[k]
|
||||
sd = new_sd
|
||||
elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.latent_channels = 16
|
||||
new_sd = {}
|
||||
for k in sd:
|
||||
new_sd["previewer.{}".format(k)] = sd[k]
|
||||
sd = new_sd
|
||||
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
|
||||
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
|
||||
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
|
||||
ddconfig['ch_mult'] = [1, 2, 4]
|
||||
self.downscale_ratio = 4
|
||||
self.upscale_ratio = 4
|
||||
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
if 'quant_conv.weight' in sd:
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
||||
else:
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
|
||||
elif "decoder.layers.1.layers.0.beta" in sd:
|
||||
self.first_stage_model = AudioOobleckVAE()
|
||||
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 64
|
||||
self.output_channels = 2
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
return
|
||||
else:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
|
||||
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0:
|
||||
print("Missing VAE keys", m)
|
||||
logging.warning("Missing VAE keys {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
print("Leftover VAE keys", u)
|
||||
logging.debug("Leftover VAE keys {}".format(u))
|
||||
|
||||
if device is None:
|
||||
device = model_management.vae_device()
|
||||
self.device = device
|
||||
offload_device = model_management.vae_offload_device()
|
||||
if dtype is None:
|
||||
dtype = model_management.vae_dtype()
|
||||
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
|
||||
self.vae_dtype = dtype
|
||||
self.first_stage_model.to(self.vae_dtype)
|
||||
self.output_device = model_management.intermediate_device()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
|
||||
dims = pixels.shape[1:-1]
|
||||
for d in range(len(dims)):
|
||||
x = (dims[d] // self.downscale_ratio) * self.downscale_ratio
|
||||
x_offset = (dims[d] % self.downscale_ratio) // 2
|
||||
if x != dims[d]:
|
||||
pixels = pixels.narrow(d + 1, x_offset, x)
|
||||
return pixels
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
@ -206,27 +290,35 @@ class VAE:
|
||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
|
||||
output = torch.clamp((
|
||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
|
||||
/ 3.0) / 2.0, min=0.0, max=1.0)
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
output = self.process_output(
|
||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
|
||||
/ 3.0)
|
||||
return output
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples /= 3.0
|
||||
return samples
|
||||
|
||||
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in):
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
@ -235,13 +327,16 @@ class VAE:
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
|
||||
pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
if len(samples_in.shape) == 3:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
else:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
@ -252,6 +347,7 @@ class VAE:
|
||||
return output.movedim(1,-1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
@ -259,18 +355,22 @@ class VAE:
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
|
||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
|
||||
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
|
||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
if len(pixel_samples.shape) == 3:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
|
||||
@ -297,8 +397,13 @@ def load_style_model(ckpt_path):
|
||||
model.load_state_dict(model_data)
|
||||
return StyleModel(model)
|
||||
|
||||
class CLIPType(Enum):
|
||||
STABLE_DIFFUSION = 1
|
||||
STABLE_CASCADE = 2
|
||||
SD3 = 3
|
||||
STABLE_AUDIO = 4
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None):
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||
@ -308,32 +413,55 @@ def load_clip(ckpt_paths, embedding_directory=None):
|
||||
|
||||
for i in range(len(clip_data)):
|
||||
if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
|
||||
clip_data[i] = comfy.utils.transformers_convert(clip_data[i], "", "text_model.", 32)
|
||||
clip_data[i] = comfy.utils.clip_text_transformers_convert(clip_data[i], "", "")
|
||||
else:
|
||||
if "text_projection" in clip_data[i]:
|
||||
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
|
||||
|
||||
clip_target = EmptyClass()
|
||||
clip_target.params = {}
|
||||
if len(clip_data) == 1:
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
if clip_type == CLIPType.STABLE_CASCADE:
|
||||
clip_target.clip = sdxl_clip.StableCascadeClipModel
|
||||
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
||||
clip_target.clip = sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
|
||||
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
dtype_t5 = weight.dtype
|
||||
if weight.shape[-1] == 4096:
|
||||
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
|
||||
clip_target.clip = sa_t5.SAT5Model
|
||||
clip_target.tokenizer = sa_t5.SAT5Tokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif len(clip_data) == 2:
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif len(clip_data) == 3:
|
||||
clip_target.clip = sd3_clip.SD3ClipModel
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
print("clip missing:", m)
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
print("clip unexpected:", u)
|
||||
logging.debug("clip unexpected: {}".format(u))
|
||||
return clip
|
||||
|
||||
def load_gligen(ckpt_path):
|
||||
@ -344,6 +472,8 @@ def load_gligen(ckpt_path):
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||
|
||||
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
|
||||
logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
|
||||
model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
|
||||
#TODO: this function is a mess and should be removed eventually
|
||||
if config is None:
|
||||
with open(config_path, 'r') as stream:
|
||||
@ -351,81 +481,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
model_config_params = config['model']['params']
|
||||
clip_config = model_config_params['cond_stage_config']
|
||||
scale_factor = model_config_params['scale_factor']
|
||||
vae_config = model_config_params['first_stage_config']
|
||||
|
||||
fp16 = False
|
||||
if "unet_config" in model_config_params:
|
||||
if "params" in model_config_params["unet_config"]:
|
||||
unet_config = model_config_params["unet_config"]["params"]
|
||||
if "use_fp16" in unet_config:
|
||||
fp16 = unet_config.pop("use_fp16")
|
||||
if fp16:
|
||||
unet_config["dtype"] = torch.float16
|
||||
|
||||
noise_aug_config = None
|
||||
if "noise_aug_config" in model_config_params:
|
||||
noise_aug_config = model_config_params["noise_aug_config"]
|
||||
|
||||
model_type = model_base.ModelType.EPS
|
||||
|
||||
if "parameterization" in model_config_params:
|
||||
if model_config_params["parameterization"] == "v":
|
||||
model_type = model_base.ModelType.V_PREDICTION
|
||||
m = model.clone()
|
||||
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.model_sampling.V_PREDICTION):
|
||||
pass
|
||||
m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config))
|
||||
model = m
|
||||
|
||||
clip = None
|
||||
vae = None
|
||||
layer_idx = clip_config.get("params", {}).get("layer_idx", None)
|
||||
if layer_idx is not None:
|
||||
clip.clip_layer(layer_idx)
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
|
||||
if state_dict is None:
|
||||
state_dict = comfy.utils.load_torch_file(ckpt_path)
|
||||
|
||||
class EmptyClass:
|
||||
pass
|
||||
|
||||
model_config = comfy.supported_models_base.BASE({})
|
||||
|
||||
from . import latent_formats
|
||||
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
|
||||
model_config.unet_config = model_detection.convert_config(unet_config)
|
||||
|
||||
if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
|
||||
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
|
||||
else:
|
||||
model = model_base.BaseModel(model_config, model_type=model_type)
|
||||
|
||||
if config['model']["target"].endswith("LatentInpaintDiffusion"):
|
||||
model.set_inpaint()
|
||||
|
||||
if fp16:
|
||||
model = model.half()
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(state_dict, "model.diffusion_model.")
|
||||
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
|
||||
vae = VAE(sd=vae_sd, config=vae_config)
|
||||
|
||||
if output_clip:
|
||||
w = WeightsLoader()
|
||||
clip_target = EmptyClass()
|
||||
clip_target.params = clip_config.get("params", {})
|
||||
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
|
||||
clip_target.clip = sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model.clip_h
|
||||
elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model.clip_l
|
||||
load_clip_weights(w, state_dict)
|
||||
|
||||
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
||||
return (model, clip, vae)
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
||||
sd = comfy.utils.load_torch_file(ckpt_path)
|
||||
@ -437,16 +506,14 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
model_patcher = None
|
||||
clip_target = None
|
||||
|
||||
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||
load_device = model_management.get_torch_device()
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
|
||||
model_config.set_manual_cast(manual_cast_dtype)
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
@ -458,8 +525,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if output_model:
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
|
||||
model.load_model_weights(sd, "model.diffusion_model.")
|
||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||
model.load_model_weights(sd, diffusion_model_prefix)
|
||||
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||
@ -467,41 +534,65 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
vae = VAE(sd=vae_sd)
|
||||
|
||||
if output_clip:
|
||||
w = WeightsLoader()
|
||||
clip_target = model_config.clip_target()
|
||||
clip_target = model_config.clip_target(state_dict=sd)
|
||||
if clip_target is not None:
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model
|
||||
sd = model_config.process_clip_state_dict(sd)
|
||||
load_model_weights(w, sd)
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
if len(m_filter) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
else:
|
||||
logging.debug("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected {}:".format(u))
|
||||
else:
|
||||
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||
|
||||
left_over = sd.keys()
|
||||
if len(left_over) > 0:
|
||||
print("left over keys:", left_over)
|
||||
logging.debug("left over keys: {}".format(left_over))
|
||||
|
||||
if output_model:
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
print("loaded straight to GPU")
|
||||
logging.info("loaded straight to GPU")
|
||||
model_management.load_model_gpu(model_patcher)
|
||||
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_unet_state_dict(sd): #load unet in diffusers format
|
||||
def load_unet_state_dict(sd): #load unet in diffusers or regular format
|
||||
|
||||
#Allow loading unets from checkpoint files
|
||||
checkpoint = False
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
|
||||
if len(temp_sd) > 0:
|
||||
sd = temp_sd
|
||||
checkpoint = True
|
||||
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
load_device = model_management.get_torch_device()
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
if "input_blocks.0.0.weight" in sd: #ldm
|
||||
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
||||
if checkpoint or "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
if model_config is None:
|
||||
return None
|
||||
new_sd = sd
|
||||
|
||||
elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
|
||||
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
||||
if new_sd is None:
|
||||
return None
|
||||
model_config = model_detection.model_config_from_unet(new_sd, "")
|
||||
if model_config is None:
|
||||
return None
|
||||
else: #diffusers
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd)
|
||||
if model_config is None:
|
||||
return None
|
||||
|
||||
@ -512,33 +603,44 @@ def load_unet_state_dict(sd): #load unet in diffusers format
|
||||
if k in sd:
|
||||
new_sd[diffusers_keys[k]] = sd.pop(k)
|
||||
else:
|
||||
print(diffusers_keys[k], k)
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model_config.set_manual_cast(manual_cast_dtype)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
left_over = sd.keys()
|
||||
if len(left_over) > 0:
|
||||
print("left over keys in unet:", left_over)
|
||||
logging.info("left over keys in unet: {}".format(left_over))
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
def load_unet(unet_path):
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = load_unet_state_dict(sd)
|
||||
if model is None:
|
||||
print("ERROR UNSUPPORTED UNET", unet_path)
|
||||
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return model
|
||||
|
||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
|
||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
||||
clip_sd = None
|
||||
load_models = [model]
|
||||
if clip is not None:
|
||||
load_models.append(clip.load_model())
|
||||
clip_sd = clip.get_sd()
|
||||
|
||||
model_management.load_models_gpu(load_models)
|
||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
||||
for k in extra_keys:
|
||||
sd[k] = extra_keys[k]
|
||||
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
if not t.is_contiguous():
|
||||
sd[k] = t.contiguous()
|
||||
|
||||
comfy.utils.save_torch_file(sd, output_path, metadata=metadata)
|
||||
|
@ -8,6 +8,8 @@ import zipfile
|
||||
from . import model_management
|
||||
import comfy.clip_model
|
||||
import json
|
||||
import logging
|
||||
import numbers
|
||||
|
||||
def gen_empty_tokens(special_tokens, length):
|
||||
start_token = special_tokens.get("start", None)
|
||||
@ -67,7 +69,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
]
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||
return_projected_pooled=True): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
|
||||
@ -86,16 +89,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.layer = layer
|
||||
self.layer_idx = None
|
||||
self.special_tokens = special_tokens
|
||||
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
|
||||
|
||||
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||
self.enable_attention_masks = False
|
||||
self.enable_attention_masks = enable_attention_masks
|
||||
self.zero_out_masked = zero_out_masked
|
||||
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
self.return_projected_pooled = return_projected_pooled
|
||||
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert abs(layer_idx) < self.num_layers
|
||||
self.clip_layer(layer_idx)
|
||||
self.layer_default = (self.layer, self.layer_idx)
|
||||
self.set_clip_options({"layer": layer_idx})
|
||||
self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
@ -103,16 +109,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
if abs(layer_idx) > self.num_layers:
|
||||
def set_clip_options(self, options):
|
||||
layer_idx = options.get("layer", self.layer_idx)
|
||||
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
|
||||
if layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
else:
|
||||
self.layer = "hidden"
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def reset_clip_layer(self):
|
||||
self.layer = self.layer_default[0]
|
||||
self.layer_idx = self.layer_default[1]
|
||||
def reset_clip_options(self):
|
||||
self.layer = self.options_default[0]
|
||||
self.layer_idx = self.options_default[1]
|
||||
self.return_projected_pooled = self.options_default[2]
|
||||
|
||||
def set_up_textual_embeddings(self, tokens, current_embeds):
|
||||
out_tokens = []
|
||||
@ -122,17 +131,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
for x in tokens:
|
||||
tokens_temp = []
|
||||
for y in x:
|
||||
if isinstance(y, int):
|
||||
if isinstance(y, numbers.Integral):
|
||||
if y == token_dict_size: #EOS token
|
||||
y = -1
|
||||
tokens_temp += [y]
|
||||
tokens_temp += [int(y)]
|
||||
else:
|
||||
if y.shape[0] == current_embeds.weight.shape[1]:
|
||||
embedding_weights += [y]
|
||||
tokens_temp += [next_new_token]
|
||||
next_new_token += 1
|
||||
else:
|
||||
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
|
||||
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
|
||||
while len(tokens_temp) < len(x):
|
||||
tokens_temp += [self.special_tokens["pad"]]
|
||||
out_tokens += [tokens_temp]
|
||||
@ -160,40 +169,43 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
tokens = torch.LongTensor(tokens).to(device)
|
||||
|
||||
attention_mask = None
|
||||
if self.enable_attention_masks:
|
||||
if self.enable_attention_masks or self.zero_out_masked:
|
||||
attention_mask = torch.zeros_like(tokens)
|
||||
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
|
||||
end_token = self.special_tokens.get("end", -1)
|
||||
for x in range(attention_mask.shape[0]):
|
||||
for y in range(attention_mask.shape[1]):
|
||||
attention_mask[x, y] = 1
|
||||
if tokens[x, y] == max_token:
|
||||
if tokens[x, y] == end_token:
|
||||
break
|
||||
|
||||
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
||||
attention_mask_model = None
|
||||
if self.enable_attention_masks:
|
||||
attention_mask_model = attention_mask
|
||||
|
||||
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
||||
self.transformer.set_input_embeddings(backup_embeds)
|
||||
|
||||
if self.layer == "last":
|
||||
z = outputs[0]
|
||||
z = outputs[0].float()
|
||||
else:
|
||||
z = outputs[1]
|
||||
z = outputs[1].float()
|
||||
|
||||
if outputs[2] is not None:
|
||||
pooled_output = outputs[2].float()
|
||||
else:
|
||||
pooled_output = None
|
||||
if self.zero_out_masked:
|
||||
z *= attention_mask.unsqueeze(-1).float()
|
||||
|
||||
if self.text_projection is not None and pooled_output is not None:
|
||||
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
|
||||
return z.float(), pooled_output
|
||||
pooled_output = None
|
||||
if len(outputs) >= 3:
|
||||
if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
|
||||
pooled_output = outputs[3].float()
|
||||
elif outputs[2] is not None:
|
||||
pooled_output = outputs[2].float()
|
||||
|
||||
return z, pooled_output
|
||||
|
||||
def encode(self, tokens):
|
||||
return self(tokens)
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_projection" in sd:
|
||||
self.text_projection[:] = sd.pop("text_projection")
|
||||
if "text_projection.weight" in sd:
|
||||
self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1)
|
||||
return self.transformer.load_state_dict(sd, strict=False)
|
||||
|
||||
def parse_parentheses(string):
|
||||
@ -328,9 +340,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
else:
|
||||
embed = torch.load(embed_path, map_location="cpu")
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
print()
|
||||
print("error loading embedding, skipping loading:", embedding_name)
|
||||
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
|
||||
return None
|
||||
|
||||
if embed_out is None:
|
||||
@ -354,11 +364,12 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
|
||||
self.max_length = max_length
|
||||
self.min_length = min_length
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
if has_start_token:
|
||||
@ -369,6 +380,14 @@ class SDTokenizer:
|
||||
self.tokens_start = 0
|
||||
self.start_token = None
|
||||
self.end_token = empty[0]
|
||||
|
||||
if pad_token is not None:
|
||||
self.pad_token = pad_token
|
||||
elif pad_with_end:
|
||||
self.pad_token = self.end_token
|
||||
else:
|
||||
self.pad_token = 0
|
||||
|
||||
self.pad_with_end = pad_with_end
|
||||
self.pad_to_max_length = pad_to_max_length
|
||||
|
||||
@ -401,10 +420,6 @@ class SDTokenizer:
|
||||
Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
|
||||
Returned list has the dimensions NxM where M is the input size of CLIP
|
||||
'''
|
||||
if self.pad_with_end:
|
||||
pad_token = self.end_token
|
||||
else:
|
||||
pad_token = 0
|
||||
|
||||
text = escape_important(text)
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
@ -420,7 +435,7 @@ class SDTokenizer:
|
||||
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
||||
embed, leftover = self._try_get_embedding(embedding_name)
|
||||
if embed is None:
|
||||
print(f"warning, embedding:{embedding_name} does not exist, ignoring")
|
||||
logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
|
||||
else:
|
||||
if len(embed.shape) == 1:
|
||||
tokens.append([(embed, weight)])
|
||||
@ -456,7 +471,7 @@ class SDTokenizer:
|
||||
else:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
|
||||
#start new batch
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
@ -469,7 +484,9 @@ class SDTokenizer:
|
||||
#fill last batch
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
if self.min_length is not None and len(batch) < self.min_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
||||
@ -497,17 +514,27 @@ class SD1Tokenizer:
|
||||
|
||||
|
||||
class SD1ClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
|
||||
super().__init__()
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
|
||||
if name is not None:
|
||||
self.clip_name = name
|
||||
self.clip = "{}".format(self.clip_name)
|
||||
else:
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
getattr(self, self.clip).clip_layer(layer_idx)
|
||||
self.dtypes = set()
|
||||
if dtype is not None:
|
||||
self.dtypes.add(dtype)
|
||||
|
||||
def reset_clip_layer(self):
|
||||
getattr(self, self.clip).reset_clip_layer()
|
||||
def set_clip_options(self, options):
|
||||
getattr(self, self.clip).set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
getattr(self, self.clip).reset_clip_options()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs = token_weight_pairs[self.clip_name]
|
||||
|
@ -1,5 +1,4 @@
|
||||
from comfy import sd1_clip
|
||||
import torch
|
||||
import os
|
||||
|
||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
|
150
comfy/sd3_clip.py
Normal file
150
comfy/sd3_clip.py
Normal file
@ -0,0 +1,150 @@
|
||||
from comfy import sd1_clip
|
||||
from comfy import sdxl_clip
|
||||
from transformers import T5TokenizerFast
|
||||
import comfy.t5
|
||||
import torch
|
||||
import os
|
||||
import comfy.model_management
|
||||
import logging
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||
|
||||
class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
class SDT5XXLModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, **kwargs)
|
||||
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self, embedding_directory=None):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
out = {}
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.clip_g.untokenize(token_weight_pair)
|
||||
|
||||
class SD3ClipModel(torch.nn.Module):
|
||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
self.dtypes = set()
|
||||
if clip_l:
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
|
||||
self.dtypes.add(dtype)
|
||||
else:
|
||||
self.clip_l = None
|
||||
|
||||
if clip_g:
|
||||
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
|
||||
self.dtypes.add(dtype)
|
||||
else:
|
||||
self.clip_g = None
|
||||
|
||||
if t5:
|
||||
if dtype_t5 is None:
|
||||
dtype_t5 = dtype
|
||||
elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
|
||||
dtype_t5 = dtype
|
||||
|
||||
if not comfy.model_management.supports_cast(device, dtype_t5):
|
||||
dtype_t5 = dtype
|
||||
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
||||
self.dtypes.add(dtype_t5)
|
||||
else:
|
||||
self.t5xxl = None
|
||||
|
||||
logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))
|
||||
|
||||
def set_clip_options(self, options):
|
||||
if self.clip_l is not None:
|
||||
self.clip_l.set_clip_options(options)
|
||||
if self.clip_g is not None:
|
||||
self.clip_g.set_clip_options(options)
|
||||
if self.t5xxl is not None:
|
||||
self.t5xxl.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
if self.clip_l is not None:
|
||||
self.clip_l.reset_clip_options()
|
||||
if self.clip_g is not None:
|
||||
self.clip_g.reset_clip_options()
|
||||
if self.t5xxl is not None:
|
||||
self.t5xxl.reset_clip_options()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_l = token_weight_pairs["l"]
|
||||
token_weight_pairs_g = token_weight_pairs["g"]
|
||||
token_weight_pars_t5 = token_weight_pairs["t5xxl"]
|
||||
lg_out = None
|
||||
pooled = None
|
||||
out = None
|
||||
|
||||
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
||||
if self.clip_l is not None:
|
||||
lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
else:
|
||||
l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
|
||||
|
||||
if self.clip_g is not None:
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||
if lg_out is not None:
|
||||
lg_out = torch.cat([lg_out, g_out], dim=-1)
|
||||
else:
|
||||
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
||||
else:
|
||||
g_out = None
|
||||
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
|
||||
|
||||
if lg_out is not None:
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
out = lg_out
|
||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
if self.t5xxl is not None:
|
||||
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
|
||||
if lg_out is not None:
|
||||
out = torch.cat([lg_out, t5_out], dim=-2)
|
||||
else:
|
||||
out = t5_out
|
||||
|
||||
if out is None:
|
||||
out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device())
|
||||
|
||||
if pooled is None:
|
||||
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
|
||||
|
||||
return out, pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
return self.clip_g.load_sd(sd)
|
||||
elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||
return self.clip_l.load_sd(sd)
|
||||
else:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
|
||||
class SD3ClipModel_(SD3ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
|
||||
return SD3ClipModel_
|
@ -39,14 +39,15 @@ class SDXLClipModel(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||
self.dtypes = set([dtype])
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
self.clip_l.clip_layer(layer_idx)
|
||||
self.clip_g.clip_layer(layer_idx)
|
||||
def set_clip_options(self, options):
|
||||
self.clip_l.set_clip_options(options)
|
||||
self.clip_g.set_clip_options(options)
|
||||
|
||||
def reset_clip_layer(self):
|
||||
self.clip_g.reset_clip_layer()
|
||||
self.clip_l.reset_clip_layer()
|
||||
def reset_clip_options(self):
|
||||
self.clip_g.reset_clip_options()
|
||||
self.clip_l.reset_clip_options()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_g = token_weight_pairs["g"]
|
||||
@ -64,3 +65,25 @@ class SDXLClipModel(torch.nn.Module):
|
||||
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
|
||||
|
||||
|
||||
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
|
||||
|
||||
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
|
||||
|
||||
class StableCascadeClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
|
||||
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
|
||||
|
@ -5,6 +5,8 @@ from . import utils
|
||||
from . import sd1_clip
|
||||
from . import sd2_clip
|
||||
from . import sdxl_clip
|
||||
from . import sd3_clip
|
||||
from . import sa_t5
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -40,15 +42,20 @@ class SD15(supported_models_base.BASE):
|
||||
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
||||
|
||||
replace_prefix = {}
|
||||
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
replace_prefix["cond_stage_model."] = "clip_l."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
|
||||
for p in pop_keys:
|
||||
if p in state_dict:
|
||||
state_dict.pop(p)
|
||||
|
||||
replace_prefix = {"clip_l.": "cond_stage_model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def clip_target(self):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
|
||||
|
||||
class SD20(supported_models_base.BASE):
|
||||
@ -60,22 +67,28 @@ class SD20(supported_models_base.BASE):
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"num_heads": -1,
|
||||
"num_head_channels": 64,
|
||||
"attn_precision": torch.float32,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
||||
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
|
||||
out = state_dict[k]
|
||||
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
||||
out = state_dict.get(k, None)
|
||||
if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
||||
return model_base.ModelType.V_PREDICTION
|
||||
return model_base.ModelType.EPS
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
replace_prefix = {}
|
||||
replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
|
||||
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
|
||||
replace_prefix["cond_stage_model.model."] = "clip_h."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
@ -85,7 +98,7 @@ class SD20(supported_models_base.BASE):
|
||||
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
||||
return state_dict
|
||||
|
||||
def clip_target(self):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
|
||||
|
||||
class SD21UnclipL(SD20):
|
||||
@ -131,11 +144,10 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
keys_to_replace = {}
|
||||
replace_prefix = {}
|
||||
replace_prefix["conditioner.embedders.0.model."] = "clip_g."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
|
||||
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
|
||||
keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
|
||||
|
||||
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
|
||||
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
return state_dict
|
||||
|
||||
@ -148,7 +160,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
return state_dict_g
|
||||
|
||||
def clip_target(self):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
|
||||
|
||||
class SDXL(supported_models_base.BASE):
|
||||
@ -164,7 +176,18 @@ class SDXL(supported_models_base.BASE):
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if "v_pred" in state_dict:
|
||||
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
|
||||
self.latent_format = latent_formats.SDXL_Playground_2_5()
|
||||
self.sampling_settings["sigma_data"] = 0.5
|
||||
self.sampling_settings["sigma_max"] = 80.0
|
||||
self.sampling_settings["sigma_min"] = 0.002
|
||||
return model_base.ModelType.EDM
|
||||
elif "edm_vpred.sigma_max" in state_dict:
|
||||
self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
|
||||
if "edm_vpred.sigma_min" in state_dict:
|
||||
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
|
||||
return model_base.ModelType.V_PREDICTION_EDM
|
||||
elif "v_pred" in state_dict:
|
||||
return model_base.ModelType.V_PREDICTION
|
||||
else:
|
||||
return model_base.ModelType.EPS
|
||||
@ -179,32 +202,34 @@ class SDXL(supported_models_base.BASE):
|
||||
keys_to_replace = {}
|
||||
replace_prefix = {}
|
||||
|
||||
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
|
||||
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
|
||||
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
|
||||
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
|
||||
replace_prefix["conditioner.embedders.1.model."] = "clip_g."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
keys_to_replace = {}
|
||||
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
||||
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
|
||||
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
|
||||
for k in state_dict:
|
||||
if k.startswith("clip_l"):
|
||||
state_dict_g[k] = state_dict[k]
|
||||
|
||||
state_dict_g["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1))
|
||||
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
|
||||
for p in pop_keys:
|
||||
if p in state_dict_g:
|
||||
state_dict_g.pop(p)
|
||||
|
||||
replace_prefix["clip_g"] = "conditioner.embedders.1.model"
|
||||
replace_prefix["clip_l"] = "conditioner.embedders.0"
|
||||
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
return state_dict_g
|
||||
|
||||
def clip_target(self):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
|
||||
|
||||
class SSD1B(SDXL):
|
||||
@ -227,6 +252,26 @@ class Segmind_Vega(SDXL):
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
class KOALA_700M(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 2, 5],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
class KOALA_1B(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 2, 6],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
class SVD_img2vid(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
@ -239,6 +284,12 @@ class SVD_img2vid(supported_models_base.BASE):
|
||||
"use_temporal_resblock": True
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"num_heads": -1,
|
||||
"num_head_channels": 64,
|
||||
"attn_precision": torch.float32,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
@ -249,9 +300,44 @@ class SVD_img2vid(supported_models_base.BASE):
|
||||
out = model_base.SVD_img2vid(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self):
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class SV3D_u(SVD_img2vid):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"in_channels": 8,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
"context_dim": 1024,
|
||||
"adm_in_channels": 256,
|
||||
"use_temporal_attention": True,
|
||||
"use_temporal_resblock": True
|
||||
}
|
||||
|
||||
vae_key_prefix = ["conditioner.embedders.1.encoder."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SV3D_u(self, device=device)
|
||||
return out
|
||||
|
||||
class SV3D_p(SV3D_u):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"in_channels": 8,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
"context_dim": 1024,
|
||||
"adm_in_channels": 1280,
|
||||
"use_temporal_attention": True,
|
||||
"use_temporal_resblock": True
|
||||
}
|
||||
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SV3D_p(self, device=device)
|
||||
return out
|
||||
|
||||
class Stable_Zero123(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"context_dim": 768,
|
||||
@ -267,6 +353,11 @@ class Stable_Zero123(supported_models_base.BASE):
|
||||
"num_head_channels": -1,
|
||||
}
|
||||
|
||||
required_keys = {
|
||||
"cc_projection.weight": None,
|
||||
"cc_projection.bias": None,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "cond_stage_model.model.visual."
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
@ -275,7 +366,7 @@ class Stable_Zero123(supported_models_base.BASE):
|
||||
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
|
||||
return out
|
||||
|
||||
def clip_target(self):
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class SD_X4Upscaler(SD20):
|
||||
@ -306,5 +397,166 @@ class SD_X4Upscaler(SD20):
|
||||
out = model_base.SD_X4Upscaler(self, device=device)
|
||||
return out
|
||||
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
|
||||
class Stable_Cascade_C(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"stable_cascade_stage": 'c',
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
latent_format = latent_formats.SC_Prior
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 2.0,
|
||||
}
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoder."]
|
||||
clip_vision_prefix = "clip_l_vision."
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
key_list = list(state_dict.keys())
|
||||
for y in ["weight", "bias"]:
|
||||
suffix = "in_proj_{}".format(y)
|
||||
keys = filter(lambda a: a.endswith(suffix), key_list)
|
||||
for k_from in keys:
|
||||
weights = state_dict.pop(k_from)
|
||||
prefix = k_from[:-(len(suffix) + 1)]
|
||||
shape_from = weights.shape[0] // 3
|
||||
for x in range(3):
|
||||
p = ["to_q", "to_k", "to_v"]
|
||||
k_to = "{}.{}.{}".format(prefix, p[x], y)
|
||||
state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
|
||||
if "clip_g.text_projection" in state_dict:
|
||||
state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1)
|
||||
return state_dict
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.StableCascade_C(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
|
||||
|
||||
class Stable_Cascade_B(Stable_Cascade_C):
|
||||
unet_config = {
|
||||
"stable_cascade_stage": 'b',
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
latent_format = latent_formats.SC_B
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
clip_vision_prefix = None
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.StableCascade_B(self, device=device)
|
||||
return out
|
||||
|
||||
class SD15_instructpix2pix(SD15):
|
||||
unet_config = {
|
||||
"context_dim": 768,
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": False,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
"in_channels": 8,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SD15_instructpix2pix(self, device=device)
|
||||
|
||||
class SDXL_instructpix2pix(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 0, 2, 2, 10, 10],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
"in_channels": 8,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
|
||||
class SD3(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"in_channels": 16,
|
||||
"pos_embed_scaling_factor": None,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.SD3
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SD3(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
clip_l = False
|
||||
clip_g = False
|
||||
t5 = False
|
||||
dtype_t5 = None
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||
clip_l = True
|
||||
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||
clip_g = True
|
||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||
if t5_key in state_dict:
|
||||
t5 = True
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
|
||||
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
|
||||
|
||||
class StableAudio(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"audio_model": "dit1.0",
|
||||
}
|
||||
|
||||
sampling_settings = {"sigma_max": 500.0, "sigma_min": 0.03}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.StableAudio1
|
||||
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
vae_key_prefix = ["pretransform.model."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
seconds_start_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_start.": ""}, filter_keys=True)
|
||||
seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True)
|
||||
return model_base.StableAudio1(self, seconds_start_embedder_weights=seconds_start_sd, seconds_total_embedder_weights=seconds_total_sd, device=device)
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
for k in list(state_dict.keys()):
|
||||
if k.endswith(".cross_attend_norm.beta") or k.endswith(".ff_norm.beta") or k.endswith(".pre_norm.beta"): #These weights are all zero
|
||||
state_dict.pop(k)
|
||||
return state_dict
|
||||
|
||||
def process_unet_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "model.model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
|
||||
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
@ -16,20 +16,28 @@ class BASE:
|
||||
"num_head_channels": 64,
|
||||
}
|
||||
|
||||
required_keys = {}
|
||||
|
||||
clip_prefix = []
|
||||
clip_vision_prefix = None
|
||||
noise_aug_config = None
|
||||
sampling_settings = {}
|
||||
latent_format = latent_formats.LatentFormat
|
||||
vae_key_prefix = ["first_stage_model."]
|
||||
text_encoder_key_prefix = ["cond_stage_model."]
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
manual_cast_dtype = None
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config):
|
||||
def matches(s, unet_config, state_dict=None):
|
||||
for k in s.unet_config:
|
||||
if s.unet_config[k] != unet_config[k]:
|
||||
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
||||
return False
|
||||
if state_dict is not None:
|
||||
for k in s.required_keys:
|
||||
if k not in state_dict:
|
||||
return False
|
||||
return True
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
@ -39,7 +47,8 @@ class BASE:
|
||||
return self.unet_config["in_channels"] > 4
|
||||
|
||||
def __init__(self, unet_config):
|
||||
self.unet_config = unet_config
|
||||
self.unet_config = unet_config.copy()
|
||||
self.sampling_settings = self.sampling_settings.copy()
|
||||
self.latent_format = self.latent_format()
|
||||
for x in self.unet_extra_config:
|
||||
self.unet_config[x] = self.unet_extra_config[x]
|
||||
@ -54,6 +63,7 @@ class BASE:
|
||||
return out
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
|
||||
return state_dict
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
@ -63,7 +73,7 @@ class BASE:
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "cond_stage_model."}
|
||||
replace_prefix = {"": self.text_encoder_key_prefix[0]}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_clip_vision_state_dict_for_saving(self, state_dict):
|
||||
@ -77,8 +87,9 @@ class BASE:
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_vae_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "first_stage_model."}
|
||||
replace_prefix = {"": self.vae_key_prefix[0]}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def set_manual_cast(self, manual_cast_dtype):
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
||||
self.unet_config['dtype'] = dtype
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
|
@ -153,7 +153,13 @@ class Adapter(nn.Module):
|
||||
features.append(None)
|
||||
features.append(x)
|
||||
|
||||
return features
|
||||
features = features[::-1]
|
||||
|
||||
if self.xl:
|
||||
return {"input": features[1:], "middle": features[:1]}
|
||||
else:
|
||||
return {"input": features}
|
||||
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
@ -290,4 +296,4 @@ class Adapter_light(nn.Module):
|
||||
features.append(None)
|
||||
features.append(x)
|
||||
|
||||
return features
|
||||
return {"input": features[::-1]}
|
||||
|
238
comfy/t5.py
Normal file
238
comfy/t5.py
Normal file
@ -0,0 +1,238 @@
|
||||
import torch
|
||||
import math
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
class T5LayerNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x):
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight.to(device=x.device, dtype=x.dtype) * x
|
||||
|
||||
activations = {
|
||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||
"relu": torch.nn.functional.relu,
|
||||
}
|
||||
|
||||
class T5DenseActDense(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
self.act = activations[ff_activation]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.act(self.wi(x))
|
||||
# x = self.dropout(x)
|
||||
x = self.wo(x)
|
||||
return x
|
||||
|
||||
class T5DenseGatedActDense(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
self.act = activations[ff_activation]
|
||||
|
||||
def forward(self, x):
|
||||
hidden_gelu = self.act(self.wi_0(x))
|
||||
hidden_linear = self.wi_1(x)
|
||||
x = hidden_gelu * hidden_linear
|
||||
# x = self.dropout(x)
|
||||
x = self.wo(x)
|
||||
return x
|
||||
|
||||
class T5LayerFF(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations):
|
||||
super().__init__()
|
||||
if gated_act:
|
||||
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
|
||||
else:
|
||||
self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
|
||||
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
def forward(self, x):
|
||||
forwarded_states = self.layer_norm(x)
|
||||
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||
# x = x + self.dropout(forwarded_states)
|
||||
x += forwarded_states
|
||||
return x
|
||||
|
||||
class T5Attention(torch.nn.Module):
|
||||
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
# Mesh TensorFlow initialization to avoid scaling before softmax
|
||||
self.q = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.k = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.v = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.o = operations.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.relative_attention_bias = None
|
||||
if relative_attention_bias:
|
||||
self.relative_attention_num_buckets = 32
|
||||
self.relative_attention_max_distance = 128
|
||||
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
||||
"""
|
||||
Adapted from Mesh Tensorflow:
|
||||
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
||||
|
||||
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
||||
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
||||
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
||||
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
||||
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
||||
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
||||
|
||||
Args:
|
||||
relative_position: an int32 Tensor
|
||||
bidirectional: a boolean - whether the attention is bidirectional
|
||||
num_buckets: an integer
|
||||
max_distance: an integer
|
||||
|
||||
Returns:
|
||||
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
||||
"""
|
||||
relative_buckets = 0
|
||||
if bidirectional:
|
||||
num_buckets //= 2
|
||||
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
|
||||
relative_position = torch.abs(relative_position)
|
||||
else:
|
||||
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
|
||||
# now relative_position is in the range [0, inf)
|
||||
|
||||
# half of the buckets are for exact increments in positions
|
||||
max_exact = num_buckets // 2
|
||||
is_small = relative_position < max_exact
|
||||
|
||||
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||
relative_position_if_large = max_exact + (
|
||||
torch.log(relative_position.float() / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).to(torch.long)
|
||||
relative_position_if_large = torch.min(
|
||||
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
|
||||
)
|
||||
|
||||
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
|
||||
return relative_buckets
|
||||
|
||||
def compute_bias(self, query_length, key_length, device):
|
||||
"""Compute binned relative position bias"""
|
||||
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
|
||||
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
|
||||
relative_position = memory_position - context_position # shape (query_length, key_length)
|
||||
relative_position_bucket = self._relative_position_bucket(
|
||||
relative_position, # shape (query_length, key_length)
|
||||
bidirectional=True,
|
||||
num_buckets=self.relative_attention_num_buckets,
|
||||
max_distance=self.relative_attention_max_distance,
|
||||
)
|
||||
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
||||
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
||||
return values
|
||||
|
||||
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
|
||||
q = self.q(x)
|
||||
k = self.k(x)
|
||||
v = self.v(x)
|
||||
if self.relative_attention_bias is not None:
|
||||
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
|
||||
|
||||
if past_bias is not None:
|
||||
if mask is not None:
|
||||
mask = mask + past_bias
|
||||
else:
|
||||
mask = past_bias
|
||||
|
||||
out = optimized_attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
|
||||
return self.o(out), past_bias
|
||||
|
||||
class T5LayerSelfAttention(torch.nn.Module):
|
||||
def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device, operations)
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
|
||||
normed_hidden_states = self.layer_norm(x)
|
||||
output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias, optimized_attention=optimized_attention)
|
||||
# x = x + self.dropout(attention_output)
|
||||
x += output
|
||||
return x, past_bias
|
||||
|
||||
class T5Block(torch.nn.Module):
|
||||
def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList()
|
||||
self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations))
|
||||
self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations))
|
||||
|
||||
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
|
||||
x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention)
|
||||
x = self.layer[-1](x)
|
||||
return x, past_bias
|
||||
|
||||
class T5Stack(torch.nn.Module):
|
||||
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
self.block = torch.nn.ModuleList(
|
||||
[T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=((not relative_attention) or (i == 0)), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
|
||||
)
|
||||
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
|
||||
intermediate = None
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
|
||||
past_bias = None
|
||||
for i, l in enumerate(self.block):
|
||||
x, past_bias = l(x, mask, past_bias, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
x = self.final_layer_norm(x)
|
||||
if intermediate is not None and final_layer_norm_intermediate:
|
||||
intermediate = self.final_layer_norm(intermediate)
|
||||
return x, intermediate
|
||||
|
||||
class T5(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_layers"]
|
||||
model_dim = config_dict["d_model"]
|
||||
|
||||
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] == "t5", dtype, device, operations)
|
||||
self.dtype = dtype
|
||||
self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.shared = embeddings
|
||||
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
x = self.shared(input_ids)
|
||||
return self.encoder(x, *args, **kwargs)
|
22
comfy/t5_config_base.json
Normal file
22
comfy/t5_config_base.json
Normal file
@ -0,0 +1,22 @@
|
||||
{
|
||||
"d_ff": 3072,
|
||||
"d_kv": 64,
|
||||
"d_model": 768,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"dense_act_fn": "relu",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": false,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 12,
|
||||
"num_heads": 12,
|
||||
"num_layers": 12,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"vocab_size": 32128
|
||||
}
|
22
comfy/t5_config_xxl.json
Normal file
22
comfy/t5_config_xxl.json
Normal file
@ -0,0 +1,22 @@
|
||||
{
|
||||
"d_ff": 10240,
|
||||
"d_kv": 64,
|
||||
"d_model": 4096,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"dense_act_fn": "gelu_pytorch_tanh",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 24,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"vocab_size": 32128
|
||||
}
|
125
comfy/t5_tokenizer/special_tokens_map.json
Normal file
125
comfy/t5_tokenizer/special_tokens_map.json
Normal file
@ -0,0 +1,125 @@
|
||||
{
|
||||
"additional_special_tokens": [
|
||||
"<extra_id_0>",
|
||||
"<extra_id_1>",
|
||||
"<extra_id_2>",
|
||||
"<extra_id_3>",
|
||||
"<extra_id_4>",
|
||||
"<extra_id_5>",
|
||||
"<extra_id_6>",
|
||||
"<extra_id_7>",
|
||||
"<extra_id_8>",
|
||||
"<extra_id_9>",
|
||||
"<extra_id_10>",
|
||||
"<extra_id_11>",
|
||||
"<extra_id_12>",
|
||||
"<extra_id_13>",
|
||||
"<extra_id_14>",
|
||||
"<extra_id_15>",
|
||||
"<extra_id_16>",
|
||||
"<extra_id_17>",
|
||||
"<extra_id_18>",
|
||||
"<extra_id_19>",
|
||||
"<extra_id_20>",
|
||||
"<extra_id_21>",
|
||||
"<extra_id_22>",
|
||||
"<extra_id_23>",
|
||||
"<extra_id_24>",
|
||||
"<extra_id_25>",
|
||||
"<extra_id_26>",
|
||||
"<extra_id_27>",
|
||||
"<extra_id_28>",
|
||||
"<extra_id_29>",
|
||||
"<extra_id_30>",
|
||||
"<extra_id_31>",
|
||||
"<extra_id_32>",
|
||||
"<extra_id_33>",
|
||||
"<extra_id_34>",
|
||||
"<extra_id_35>",
|
||||
"<extra_id_36>",
|
||||
"<extra_id_37>",
|
||||
"<extra_id_38>",
|
||||
"<extra_id_39>",
|
||||
"<extra_id_40>",
|
||||
"<extra_id_41>",
|
||||
"<extra_id_42>",
|
||||
"<extra_id_43>",
|
||||
"<extra_id_44>",
|
||||
"<extra_id_45>",
|
||||
"<extra_id_46>",
|
||||
"<extra_id_47>",
|
||||
"<extra_id_48>",
|
||||
"<extra_id_49>",
|
||||
"<extra_id_50>",
|
||||
"<extra_id_51>",
|
||||
"<extra_id_52>",
|
||||
"<extra_id_53>",
|
||||
"<extra_id_54>",
|
||||
"<extra_id_55>",
|
||||
"<extra_id_56>",
|
||||
"<extra_id_57>",
|
||||
"<extra_id_58>",
|
||||
"<extra_id_59>",
|
||||
"<extra_id_60>",
|
||||
"<extra_id_61>",
|
||||
"<extra_id_62>",
|
||||
"<extra_id_63>",
|
||||
"<extra_id_64>",
|
||||
"<extra_id_65>",
|
||||
"<extra_id_66>",
|
||||
"<extra_id_67>",
|
||||
"<extra_id_68>",
|
||||
"<extra_id_69>",
|
||||
"<extra_id_70>",
|
||||
"<extra_id_71>",
|
||||
"<extra_id_72>",
|
||||
"<extra_id_73>",
|
||||
"<extra_id_74>",
|
||||
"<extra_id_75>",
|
||||
"<extra_id_76>",
|
||||
"<extra_id_77>",
|
||||
"<extra_id_78>",
|
||||
"<extra_id_79>",
|
||||
"<extra_id_80>",
|
||||
"<extra_id_81>",
|
||||
"<extra_id_82>",
|
||||
"<extra_id_83>",
|
||||
"<extra_id_84>",
|
||||
"<extra_id_85>",
|
||||
"<extra_id_86>",
|
||||
"<extra_id_87>",
|
||||
"<extra_id_88>",
|
||||
"<extra_id_89>",
|
||||
"<extra_id_90>",
|
||||
"<extra_id_91>",
|
||||
"<extra_id_92>",
|
||||
"<extra_id_93>",
|
||||
"<extra_id_94>",
|
||||
"<extra_id_95>",
|
||||
"<extra_id_96>",
|
||||
"<extra_id_97>",
|
||||
"<extra_id_98>",
|
||||
"<extra_id_99>"
|
||||
],
|
||||
"eos_token": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
129428
comfy/t5_tokenizer/tokenizer.json
Normal file
129428
comfy/t5_tokenizer/tokenizer.json
Normal file
File diff suppressed because one or more lines are too long
939
comfy/t5_tokenizer/tokenizer_config.json
Normal file
939
comfy/t5_tokenizer/tokenizer_config.json
Normal file
@ -0,0 +1,939 @@
|
||||
{
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"1": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"2": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32000": {
|
||||
"content": "<extra_id_99>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32001": {
|
||||
"content": "<extra_id_98>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32002": {
|
||||
"content": "<extra_id_97>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32003": {
|
||||
"content": "<extra_id_96>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32004": {
|
||||
"content": "<extra_id_95>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32005": {
|
||||
"content": "<extra_id_94>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32006": {
|
||||
"content": "<extra_id_93>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32007": {
|
||||
"content": "<extra_id_92>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32008": {
|
||||
"content": "<extra_id_91>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32009": {
|
||||
"content": "<extra_id_90>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32010": {
|
||||
"content": "<extra_id_89>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32011": {
|
||||
"content": "<extra_id_88>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32012": {
|
||||
"content": "<extra_id_87>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32013": {
|
||||
"content": "<extra_id_86>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32014": {
|
||||
"content": "<extra_id_85>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32015": {
|
||||
"content": "<extra_id_84>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32016": {
|
||||
"content": "<extra_id_83>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32017": {
|
||||
"content": "<extra_id_82>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32018": {
|
||||
"content": "<extra_id_81>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32019": {
|
||||
"content": "<extra_id_80>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32020": {
|
||||
"content": "<extra_id_79>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32021": {
|
||||
"content": "<extra_id_78>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32022": {
|
||||
"content": "<extra_id_77>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32023": {
|
||||
"content": "<extra_id_76>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32024": {
|
||||
"content": "<extra_id_75>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32025": {
|
||||
"content": "<extra_id_74>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32026": {
|
||||
"content": "<extra_id_73>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32027": {
|
||||
"content": "<extra_id_72>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32028": {
|
||||
"content": "<extra_id_71>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32029": {
|
||||
"content": "<extra_id_70>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32030": {
|
||||
"content": "<extra_id_69>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32031": {
|
||||
"content": "<extra_id_68>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32032": {
|
||||
"content": "<extra_id_67>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32033": {
|
||||
"content": "<extra_id_66>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32034": {
|
||||
"content": "<extra_id_65>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32035": {
|
||||
"content": "<extra_id_64>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32036": {
|
||||
"content": "<extra_id_63>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32037": {
|
||||
"content": "<extra_id_62>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32038": {
|
||||
"content": "<extra_id_61>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32039": {
|
||||
"content": "<extra_id_60>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32040": {
|
||||
"content": "<extra_id_59>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32041": {
|
||||
"content": "<extra_id_58>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32042": {
|
||||
"content": "<extra_id_57>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32043": {
|
||||
"content": "<extra_id_56>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32044": {
|
||||
"content": "<extra_id_55>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32045": {
|
||||
"content": "<extra_id_54>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32046": {
|
||||
"content": "<extra_id_53>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32047": {
|
||||
"content": "<extra_id_52>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32048": {
|
||||
"content": "<extra_id_51>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32049": {
|
||||
"content": "<extra_id_50>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32050": {
|
||||
"content": "<extra_id_49>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32051": {
|
||||
"content": "<extra_id_48>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32052": {
|
||||
"content": "<extra_id_47>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32053": {
|
||||
"content": "<extra_id_46>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32054": {
|
||||
"content": "<extra_id_45>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32055": {
|
||||
"content": "<extra_id_44>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32056": {
|
||||
"content": "<extra_id_43>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32057": {
|
||||
"content": "<extra_id_42>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32058": {
|
||||
"content": "<extra_id_41>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32059": {
|
||||
"content": "<extra_id_40>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32060": {
|
||||
"content": "<extra_id_39>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32061": {
|
||||
"content": "<extra_id_38>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32062": {
|
||||
"content": "<extra_id_37>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32063": {
|
||||
"content": "<extra_id_36>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32064": {
|
||||
"content": "<extra_id_35>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32065": {
|
||||
"content": "<extra_id_34>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32066": {
|
||||
"content": "<extra_id_33>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32067": {
|
||||
"content": "<extra_id_32>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32068": {
|
||||
"content": "<extra_id_31>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32069": {
|
||||
"content": "<extra_id_30>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32070": {
|
||||
"content": "<extra_id_29>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32071": {
|
||||
"content": "<extra_id_28>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32072": {
|
||||
"content": "<extra_id_27>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32073": {
|
||||
"content": "<extra_id_26>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32074": {
|
||||
"content": "<extra_id_25>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32075": {
|
||||
"content": "<extra_id_24>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32076": {
|
||||
"content": "<extra_id_23>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32077": {
|
||||
"content": "<extra_id_22>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32078": {
|
||||
"content": "<extra_id_21>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32079": {
|
||||
"content": "<extra_id_20>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32080": {
|
||||
"content": "<extra_id_19>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32081": {
|
||||
"content": "<extra_id_18>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32082": {
|
||||
"content": "<extra_id_17>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32083": {
|
||||
"content": "<extra_id_16>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32084": {
|
||||
"content": "<extra_id_15>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32085": {
|
||||
"content": "<extra_id_14>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32086": {
|
||||
"content": "<extra_id_13>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32087": {
|
||||
"content": "<extra_id_12>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32088": {
|
||||
"content": "<extra_id_11>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32089": {
|
||||
"content": "<extra_id_10>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32090": {
|
||||
"content": "<extra_id_9>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32091": {
|
||||
"content": "<extra_id_8>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32092": {
|
||||
"content": "<extra_id_7>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32093": {
|
||||
"content": "<extra_id_6>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32094": {
|
||||
"content": "<extra_id_5>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32095": {
|
||||
"content": "<extra_id_4>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32096": {
|
||||
"content": "<extra_id_3>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32097": {
|
||||
"content": "<extra_id_2>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32098": {
|
||||
"content": "<extra_id_1>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32099": {
|
||||
"content": "<extra_id_0>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
"<extra_id_0>",
|
||||
"<extra_id_1>",
|
||||
"<extra_id_2>",
|
||||
"<extra_id_3>",
|
||||
"<extra_id_4>",
|
||||
"<extra_id_5>",
|
||||
"<extra_id_6>",
|
||||
"<extra_id_7>",
|
||||
"<extra_id_8>",
|
||||
"<extra_id_9>",
|
||||
"<extra_id_10>",
|
||||
"<extra_id_11>",
|
||||
"<extra_id_12>",
|
||||
"<extra_id_13>",
|
||||
"<extra_id_14>",
|
||||
"<extra_id_15>",
|
||||
"<extra_id_16>",
|
||||
"<extra_id_17>",
|
||||
"<extra_id_18>",
|
||||
"<extra_id_19>",
|
||||
"<extra_id_20>",
|
||||
"<extra_id_21>",
|
||||
"<extra_id_22>",
|
||||
"<extra_id_23>",
|
||||
"<extra_id_24>",
|
||||
"<extra_id_25>",
|
||||
"<extra_id_26>",
|
||||
"<extra_id_27>",
|
||||
"<extra_id_28>",
|
||||
"<extra_id_29>",
|
||||
"<extra_id_30>",
|
||||
"<extra_id_31>",
|
||||
"<extra_id_32>",
|
||||
"<extra_id_33>",
|
||||
"<extra_id_34>",
|
||||
"<extra_id_35>",
|
||||
"<extra_id_36>",
|
||||
"<extra_id_37>",
|
||||
"<extra_id_38>",
|
||||
"<extra_id_39>",
|
||||
"<extra_id_40>",
|
||||
"<extra_id_41>",
|
||||
"<extra_id_42>",
|
||||
"<extra_id_43>",
|
||||
"<extra_id_44>",
|
||||
"<extra_id_45>",
|
||||
"<extra_id_46>",
|
||||
"<extra_id_47>",
|
||||
"<extra_id_48>",
|
||||
"<extra_id_49>",
|
||||
"<extra_id_50>",
|
||||
"<extra_id_51>",
|
||||
"<extra_id_52>",
|
||||
"<extra_id_53>",
|
||||
"<extra_id_54>",
|
||||
"<extra_id_55>",
|
||||
"<extra_id_56>",
|
||||
"<extra_id_57>",
|
||||
"<extra_id_58>",
|
||||
"<extra_id_59>",
|
||||
"<extra_id_60>",
|
||||
"<extra_id_61>",
|
||||
"<extra_id_62>",
|
||||
"<extra_id_63>",
|
||||
"<extra_id_64>",
|
||||
"<extra_id_65>",
|
||||
"<extra_id_66>",
|
||||
"<extra_id_67>",
|
||||
"<extra_id_68>",
|
||||
"<extra_id_69>",
|
||||
"<extra_id_70>",
|
||||
"<extra_id_71>",
|
||||
"<extra_id_72>",
|
||||
"<extra_id_73>",
|
||||
"<extra_id_74>",
|
||||
"<extra_id_75>",
|
||||
"<extra_id_76>",
|
||||
"<extra_id_77>",
|
||||
"<extra_id_78>",
|
||||
"<extra_id_79>",
|
||||
"<extra_id_80>",
|
||||
"<extra_id_81>",
|
||||
"<extra_id_82>",
|
||||
"<extra_id_83>",
|
||||
"<extra_id_84>",
|
||||
"<extra_id_85>",
|
||||
"<extra_id_86>",
|
||||
"<extra_id_87>",
|
||||
"<extra_id_88>",
|
||||
"<extra_id_89>",
|
||||
"<extra_id_90>",
|
||||
"<extra_id_91>",
|
||||
"<extra_id_92>",
|
||||
"<extra_id_93>",
|
||||
"<extra_id_94>",
|
||||
"<extra_id_95>",
|
||||
"<extra_id_96>",
|
||||
"<extra_id_97>",
|
||||
"<extra_id_98>",
|
||||
"<extra_id_99>"
|
||||
],
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"eos_token": "</s>",
|
||||
"extra_ids": 100,
|
||||
"legacy": false,
|
||||
"model_max_length": 512,
|
||||
"pad_token": "<pad>",
|
||||
"sp_model_kwargs": {},
|
||||
"tokenizer_class": "T5Tokenizer",
|
||||
"unk_token": "<unk>"
|
||||
}
|
@ -25,18 +25,19 @@ class Block(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.fuse(self.conv(x) + self.skip(x))
|
||||
|
||||
def Encoder():
|
||||
def Encoder(latent_channels=4):
|
||||
return nn.Sequential(
|
||||
conv(3, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 4),
|
||||
conv(64, latent_channels),
|
||||
)
|
||||
|
||||
def Decoder():
|
||||
|
||||
def Decoder(latent_channels=4):
|
||||
return nn.Sequential(
|
||||
Clamp(), conv(4, 64), nn.ReLU(),
|
||||
Clamp(), conv(latent_channels, 64), nn.ReLU(),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
@ -47,12 +48,13 @@ class TAESD(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
def __init__(self, encoder_path=None, decoder_path=None):
|
||||
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.taesd_encoder = Encoder()
|
||||
self.taesd_decoder = Decoder()
|
||||
self.taesd_encoder = Encoder(latent_channels=latent_channels)
|
||||
self.taesd_decoder = Decoder(latent_channels=latent_channels)
|
||||
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
||||
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
|
||||
if encoder_path is not None:
|
||||
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
||||
if decoder_path is not None:
|
||||
@ -69,9 +71,9 @@ class TAESD(nn.Module):
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||
|
||||
def decode(self, x):
|
||||
x_sample = self.taesd_decoder(x * self.vae_scale)
|
||||
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
|
||||
x_sample = x_sample.sub(0.5).mul(2)
|
||||
return x_sample
|
||||
|
||||
def encode(self, x):
|
||||
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale
|
||||
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
|
||||
|
32
comfy/types.py
Normal file
32
comfy/types.py
Normal file
@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from typing import Callable, Protocol, TypedDict, Optional, List
|
||||
|
||||
|
||||
class UnetApplyFunction(Protocol):
|
||||
"""Function signature protocol on comfy.model_base.BaseModel.apply_model"""
|
||||
|
||||
def __call__(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
|
||||
class UnetApplyConds(TypedDict):
|
||||
"""Optional conditions for unet apply function."""
|
||||
|
||||
c_concat: Optional[torch.Tensor]
|
||||
c_crossattn: Optional[torch.Tensor]
|
||||
control: Optional[torch.Tensor]
|
||||
transformer_options: Optional[dict]
|
||||
|
||||
|
||||
class UnetParams(TypedDict):
|
||||
# Tensor of shape [B, C, H, W]
|
||||
input: torch.Tensor
|
||||
# Tensor of shape [B]
|
||||
timestep: torch.Tensor
|
||||
c: UnetApplyConds
|
||||
# List of [0, 1], [0], [1], ...
|
||||
# 0 means conditional, 1 means conditional unconditional
|
||||
cond_or_uncond: List[int]
|
||||
|
||||
|
||||
UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
|
181
comfy/utils.py
181
comfy/utils.py
@ -5,6 +5,8 @@ import comfy.checkpoint_pickle
|
||||
import safetensors.torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import logging
|
||||
import itertools
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||
if device is None:
|
||||
@ -14,14 +16,14 @@ def load_torch_file(ckpt, safe_load=False, device=None):
|
||||
else:
|
||||
if safe_load:
|
||||
if not 'weights_only' in torch.load.__code__.co_varnames:
|
||||
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
|
||||
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
|
||||
safe_load = False
|
||||
if safe_load:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
logging.debug(f"Global Step: {pl_sd['global_step']}")
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
else:
|
||||
@ -98,8 +100,22 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
|
||||
k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
|
||||
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||
|
||||
return sd
|
||||
|
||||
def clip_text_transformers_convert(sd, prefix_from, prefix_to):
|
||||
sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)
|
||||
|
||||
tp = "{}text_projection.weight".format(prefix_from)
|
||||
if tp in sd:
|
||||
sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp)
|
||||
|
||||
tp = "{}text_projection".format(prefix_from)
|
||||
if tp in sd:
|
||||
sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous()
|
||||
return sd
|
||||
|
||||
|
||||
UNET_MAP_ATTENTIONS = {
|
||||
"proj_in.weight",
|
||||
"proj_in.bias",
|
||||
@ -169,6 +185,8 @@ UNET_MAP_BASIC = {
|
||||
}
|
||||
|
||||
def unet_to_diffusers(unet_config):
|
||||
if "num_res_blocks" not in unet_config:
|
||||
return {}
|
||||
num_res_blocks = unet_config["num_res_blocks"]
|
||||
channel_mult = unet_config["channel_mult"]
|
||||
transformer_depth = unet_config["transformer_depth"][:]
|
||||
@ -232,11 +250,93 @@ def unet_to_diffusers(unet_config):
|
||||
|
||||
return diffusers_unet_map
|
||||
|
||||
def repeat_to_batch_size(tensor, batch_size):
|
||||
if tensor.shape[0] > batch_size:
|
||||
return tensor[:batch_size]
|
||||
elif tensor.shape[0] < batch_size:
|
||||
return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
MMDIT_MAP_BASIC = {
|
||||
("context_embedder.bias", "context_embedder.bias"),
|
||||
("context_embedder.weight", "context_embedder.weight"),
|
||||
("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
|
||||
("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
|
||||
("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
|
||||
("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
|
||||
("x_embedder.proj.bias", "pos_embed.proj.bias"),
|
||||
("x_embedder.proj.weight", "pos_embed.proj.weight"),
|
||||
("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"),
|
||||
("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"),
|
||||
("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
|
||||
("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
|
||||
("pos_embed", "pos_embed.pos_embed"),
|
||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||
("final_layer.linear.bias", "proj_out.bias"),
|
||||
("final_layer.linear.weight", "proj_out.weight"),
|
||||
}
|
||||
|
||||
MMDIT_MAP_BLOCK = {
|
||||
("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"),
|
||||
("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"),
|
||||
("context_block.attn.proj.bias", "attn.to_add_out.bias"),
|
||||
("context_block.attn.proj.weight", "attn.to_add_out.weight"),
|
||||
("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"),
|
||||
("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
|
||||
("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
|
||||
("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
|
||||
("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
|
||||
("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
|
||||
("x_block.attn.proj.bias", "attn.to_out.0.bias"),
|
||||
("x_block.attn.proj.weight", "attn.to_out.0.weight"),
|
||||
("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
|
||||
("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
|
||||
("x_block.mlp.fc2.bias", "ff.net.2.bias"),
|
||||
("x_block.mlp.fc2.weight", "ff.net.2.weight"),
|
||||
}
|
||||
|
||||
def mmdit_to_diffusers(mmdit_config, output_prefix=""):
|
||||
key_map = {}
|
||||
|
||||
depth = mmdit_config.get("depth", 0)
|
||||
num_blocks = mmdit_config.get("num_blocks", depth)
|
||||
for i in range(num_blocks):
|
||||
block_from = "transformer_blocks.{}".format(i)
|
||||
block_to = "{}joint_blocks.{}".format(output_prefix, i)
|
||||
|
||||
offset = depth * 64
|
||||
|
||||
for end in ("weight", "bias"):
|
||||
k = "{}.attn.".format(block_from)
|
||||
qkv = "{}.x_block.attn.qkv.{}".format(block_to, end)
|
||||
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
|
||||
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
|
||||
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
|
||||
|
||||
qkv = "{}.context_block.attn.qkv.{}".format(block_to, end)
|
||||
key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset))
|
||||
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
|
||||
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
|
||||
|
||||
for k in MMDIT_MAP_BLOCK:
|
||||
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
|
||||
|
||||
map_basic = MMDIT_MAP_BASIC.copy()
|
||||
map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift))
|
||||
map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift))
|
||||
|
||||
for k in map_basic:
|
||||
if len(k) > 2:
|
||||
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
|
||||
else:
|
||||
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
||||
|
||||
return key_map
|
||||
|
||||
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
||||
if tensor.shape[dim] > batch_size:
|
||||
return tensor.narrow(dim, 0, batch_size)
|
||||
elif tensor.shape[dim] < batch_size:
|
||||
return tensor.repeat(dim * [1] + [math.ceil(batch_size / tensor.shape[dim])] + [1] * (len(tensor.shape) - 1 - dim)).narrow(dim, 0, batch_size)
|
||||
return tensor
|
||||
|
||||
def resize_to_batch_size(tensor, batch_size):
|
||||
@ -278,8 +378,11 @@ def set_attr(obj, attr, value):
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1])
|
||||
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
|
||||
del prev
|
||||
setattr(obj, attrs[-1], value)
|
||||
return prev
|
||||
|
||||
def set_attr_param(obj, attr, value):
|
||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||
|
||||
def copy_to_param(obj, attr, value):
|
||||
# inplace update tensor instead of replacing it
|
||||
@ -405,34 +508,52 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
||||
|
||||
@torch.inference_mode()
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
|
||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
dims = len(tile)
|
||||
output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
|
||||
|
||||
for b in range(samples.shape[0]):
|
||||
s = samples[b:b+1]
|
||||
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
|
||||
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
|
||||
for y in range(0, s.shape[2], tile_y - overlap):
|
||||
for x in range(0, s.shape[3], tile_x - overlap):
|
||||
x = max(0, min(s.shape[-1] - overlap, x))
|
||||
y = max(0, min(s.shape[-2] - overlap, y))
|
||||
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
|
||||
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
|
||||
ps = function(s_in).to(output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
feather = round(overlap * upscale_amount)
|
||||
for t in range(feather):
|
||||
mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
|
||||
mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
|
||||
mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
|
||||
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
|
||||
out[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += ps * mask
|
||||
out_div[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += mask
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
|
||||
s_in = s
|
||||
upscaled = []
|
||||
|
||||
for d in range(dims):
|
||||
pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
|
||||
l = min(tile[d], s.shape[d + 2] - pos)
|
||||
s_in = s_in.narrow(d + 2, pos, l)
|
||||
upscaled.append(round(pos * upscale_amount))
|
||||
ps = function(s_in).to(output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
feather = round(overlap * upscale_amount)
|
||||
for t in range(feather):
|
||||
for d in range(2, dims + 2):
|
||||
m = mask.narrow(d, t, 1)
|
||||
m *= ((1.0/feather) * (t + 1))
|
||||
m = mask.narrow(d, mask.shape[d] -1 -t, 1)
|
||||
m *= ((1.0/feather) * (t + 1))
|
||||
|
||||
o = out
|
||||
o_d = out_div
|
||||
for d in range(dims):
|
||||
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||
|
||||
o += ps * mask
|
||||
o_d += mask
|
||||
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
|
||||
output[b:b+1] = out/out_div
|
||||
return output
|
||||
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar)
|
||||
|
||||
PROGRESS_BAR_ENABLED = True
|
||||
def set_progress_bar_enabled(enabled):
|
||||
global PROGRESS_BAR_ENABLED
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 Xiangyu Chen
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
@ -1,29 +0,0 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2021, Xintao Wang
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2022 Kai Zhang (cskaizhang@gmail.com, https://cszn.github.io/). All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2018-2022 BasicSR Authors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -1,121 +0,0 @@
|
||||
Creative Commons Legal Code
|
||||
|
||||
CC0 1.0 Universal
|
||||
|
||||
CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
|
||||
LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
|
||||
ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
|
||||
INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
|
||||
REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
|
||||
PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
|
||||
THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
|
||||
HEREUNDER.
|
||||
|
||||
Statement of Purpose
|
||||
|
||||
The laws of most jurisdictions throughout the world automatically confer
|
||||
exclusive Copyright and Related Rights (defined below) upon the creator
|
||||
and subsequent owner(s) (each and all, an "owner") of an original work of
|
||||
authorship and/or a database (each, a "Work").
|
||||
|
||||
Certain owners wish to permanently relinquish those rights to a Work for
|
||||
the purpose of contributing to a commons of creative, cultural and
|
||||
scientific works ("Commons") that the public can reliably and without fear
|
||||
of later claims of infringement build upon, modify, incorporate in other
|
||||
works, reuse and redistribute as freely as possible in any form whatsoever
|
||||
and for any purposes, including without limitation commercial purposes.
|
||||
These owners may contribute to the Commons to promote the ideal of a free
|
||||
culture and the further production of creative, cultural and scientific
|
||||
works, or to gain reputation or greater distribution for their Work in
|
||||
part through the use and efforts of others.
|
||||
|
||||
For these and/or other purposes and motivations, and without any
|
||||
expectation of additional consideration or compensation, the person
|
||||
associating CC0 with a Work (the "Affirmer"), to the extent that he or she
|
||||
is an owner of Copyright and Related Rights in the Work, voluntarily
|
||||
elects to apply CC0 to the Work and publicly distribute the Work under its
|
||||
terms, with knowledge of his or her Copyright and Related Rights in the
|
||||
Work and the meaning and intended legal effect of CC0 on those rights.
|
||||
|
||||
1. Copyright and Related Rights. A Work made available under CC0 may be
|
||||
protected by copyright and related or neighboring rights ("Copyright and
|
||||
Related Rights"). Copyright and Related Rights include, but are not
|
||||
limited to, the following:
|
||||
|
||||
i. the right to reproduce, adapt, distribute, perform, display,
|
||||
communicate, and translate a Work;
|
||||
ii. moral rights retained by the original author(s) and/or performer(s);
|
||||
iii. publicity and privacy rights pertaining to a person's image or
|
||||
likeness depicted in a Work;
|
||||
iv. rights protecting against unfair competition in regards to a Work,
|
||||
subject to the limitations in paragraph 4(a), below;
|
||||
v. rights protecting the extraction, dissemination, use and reuse of data
|
||||
in a Work;
|
||||
vi. database rights (such as those arising under Directive 96/9/EC of the
|
||||
European Parliament and of the Council of 11 March 1996 on the legal
|
||||
protection of databases, and under any national implementation
|
||||
thereof, including any amended or successor version of such
|
||||
directive); and
|
||||
vii. other similar, equivalent or corresponding rights throughout the
|
||||
world based on applicable law or treaty, and any national
|
||||
implementations thereof.
|
||||
|
||||
2. Waiver. To the greatest extent permitted by, but not in contravention
|
||||
of, applicable law, Affirmer hereby overtly, fully, permanently,
|
||||
irrevocably and unconditionally waives, abandons, and surrenders all of
|
||||
Affirmer's Copyright and Related Rights and associated claims and causes
|
||||
of action, whether now known or unknown (including existing as well as
|
||||
future claims and causes of action), in the Work (i) in all territories
|
||||
worldwide, (ii) for the maximum duration provided by applicable law or
|
||||
treaty (including future time extensions), (iii) in any current or future
|
||||
medium and for any number of copies, and (iv) for any purpose whatsoever,
|
||||
including without limitation commercial, advertising or promotional
|
||||
purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
|
||||
member of the public at large and to the detriment of Affirmer's heirs and
|
||||
successors, fully intending that such Waiver shall not be subject to
|
||||
revocation, rescission, cancellation, termination, or any other legal or
|
||||
equitable action to disrupt the quiet enjoyment of the Work by the public
|
||||
as contemplated by Affirmer's express Statement of Purpose.
|
||||
|
||||
3. Public License Fallback. Should any part of the Waiver for any reason
|
||||
be judged legally invalid or ineffective under applicable law, then the
|
||||
Waiver shall be preserved to the maximum extent permitted taking into
|
||||
account Affirmer's express Statement of Purpose. In addition, to the
|
||||
extent the Waiver is so judged Affirmer hereby grants to each affected
|
||||
person a royalty-free, non transferable, non sublicensable, non exclusive,
|
||||
irrevocable and unconditional license to exercise Affirmer's Copyright and
|
||||
Related Rights in the Work (i) in all territories worldwide, (ii) for the
|
||||
maximum duration provided by applicable law or treaty (including future
|
||||
time extensions), (iii) in any current or future medium and for any number
|
||||
of copies, and (iv) for any purpose whatsoever, including without
|
||||
limitation commercial, advertising or promotional purposes (the
|
||||
"License"). The License shall be deemed effective as of the date CC0 was
|
||||
applied by Affirmer to the Work. Should any part of the License for any
|
||||
reason be judged legally invalid or ineffective under applicable law, such
|
||||
partial invalidity or ineffectiveness shall not invalidate the remainder
|
||||
of the License, and in such case Affirmer hereby affirms that he or she
|
||||
will not (i) exercise any of his or her remaining Copyright and Related
|
||||
Rights in the Work or (ii) assert any associated claims and causes of
|
||||
action with respect to the Work, in either case contrary to Affirmer's
|
||||
express Statement of Purpose.
|
||||
|
||||
4. Limitations and Disclaimers.
|
||||
|
||||
a. No trademark or patent rights held by Affirmer are waived, abandoned,
|
||||
surrendered, licensed or otherwise affected by this document.
|
||||
b. Affirmer offers the Work as-is and makes no representations or
|
||||
warranties of any kind concerning the Work, express, implied,
|
||||
statutory or otherwise, including without limitation warranties of
|
||||
title, merchantability, fitness for a particular purpose, non
|
||||
infringement, or the absence of latent or other defects, accuracy, or
|
||||
the present or absence of errors, whether or not discoverable, all to
|
||||
the greatest extent permissible under applicable law.
|
||||
c. Affirmer disclaims responsibility for clearing rights of other persons
|
||||
that may apply to the Work or any use thereof, including without
|
||||
limitation any person's Copyright and Related Rights in the Work.
|
||||
Further, Affirmer disclaims responsibility for obtaining any necessary
|
||||
consents, permissions or other rights required for any use of the
|
||||
Work.
|
||||
d. Affirmer understands and acknowledges that Creative Commons is not a
|
||||
party to this document and has no duty or obligation with respect to
|
||||
this CC0 or use of the Work.
|
@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [2021] [SwinIR Authors]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [2021] [SwinIR Authors]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [2021] Samsung Research
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -1,694 +0,0 @@
|
||||
# pylint: skip-file
|
||||
"""
|
||||
Model adapted from advimman's lama project: https://github.com/advimman/lama
|
||||
"""
|
||||
|
||||
# Fast Fourier Convolution NeurIPS 2020
|
||||
# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
|
||||
# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision.transforms.functional import InterpolationMode, rotate
|
||||
|
||||
|
||||
class LearnableSpatialTransformWrapper(nn.Module):
|
||||
def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
|
||||
super().__init__()
|
||||
self.impl = impl
|
||||
self.angle = torch.rand(1) * angle_init_range
|
||||
if train_angle:
|
||||
self.angle = nn.Parameter(self.angle, requires_grad=True)
|
||||
self.pad_coef = pad_coef
|
||||
|
||||
def forward(self, x):
|
||||
if torch.is_tensor(x):
|
||||
return self.inverse_transform(self.impl(self.transform(x)), x)
|
||||
elif isinstance(x, tuple):
|
||||
x_trans = tuple(self.transform(elem) for elem in x)
|
||||
y_trans = self.impl(x_trans)
|
||||
return tuple(
|
||||
self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected input type {type(x)}")
|
||||
|
||||
def transform(self, x):
|
||||
height, width = x.shape[2:]
|
||||
pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
|
||||
x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode="reflect")
|
||||
x_padded_rotated = rotate(
|
||||
x_padded, self.angle.to(x_padded), InterpolationMode.BILINEAR, fill=0
|
||||
)
|
||||
|
||||
return x_padded_rotated
|
||||
|
||||
def inverse_transform(self, y_padded_rotated, orig_x):
|
||||
height, width = orig_x.shape[2:]
|
||||
pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
|
||||
|
||||
y_padded = rotate(
|
||||
y_padded_rotated,
|
||||
-self.angle.to(y_padded_rotated),
|
||||
InterpolationMode.BILINEAR,
|
||||
fill=0,
|
||||
)
|
||||
y_height, y_width = y_padded.shape[2:]
|
||||
y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
|
||||
return y
|
||||
|
||||
|
||||
class SELayer(nn.Module):
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(SELayer, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel, bias=False),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, _, _ = x.size()
|
||||
y = self.avg_pool(x).view(b, c)
|
||||
y = self.fc(y).view(b, c, 1, 1)
|
||||
res = x * y.expand_as(x)
|
||||
return res
|
||||
|
||||
|
||||
class FourierUnit(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
groups=1,
|
||||
spatial_scale_factor=None,
|
||||
spatial_scale_mode="bilinear",
|
||||
spectral_pos_encoding=False,
|
||||
use_se=False,
|
||||
se_kwargs=None,
|
||||
ffc3d=False,
|
||||
fft_norm="ortho",
|
||||
):
|
||||
# bn_layer not used
|
||||
super(FourierUnit, self).__init__()
|
||||
self.groups = groups
|
||||
|
||||
self.conv_layer = torch.nn.Conv2d(
|
||||
in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
|
||||
out_channels=out_channels * 2,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=self.groups,
|
||||
bias=False,
|
||||
)
|
||||
self.bn = torch.nn.BatchNorm2d(out_channels * 2)
|
||||
self.relu = torch.nn.ReLU(inplace=True)
|
||||
|
||||
# squeeze and excitation block
|
||||
self.use_se = use_se
|
||||
if use_se:
|
||||
if se_kwargs is None:
|
||||
se_kwargs = {}
|
||||
self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
|
||||
|
||||
self.spatial_scale_factor = spatial_scale_factor
|
||||
self.spatial_scale_mode = spatial_scale_mode
|
||||
self.spectral_pos_encoding = spectral_pos_encoding
|
||||
self.ffc3d = ffc3d
|
||||
self.fft_norm = fft_norm
|
||||
|
||||
def forward(self, x):
|
||||
half_check = False
|
||||
if x.type() == "torch.cuda.HalfTensor":
|
||||
# half only works on gpu anyway
|
||||
half_check = True
|
||||
|
||||
batch = x.shape[0]
|
||||
|
||||
if self.spatial_scale_factor is not None:
|
||||
orig_size = x.shape[-2:]
|
||||
x = F.interpolate(
|
||||
x,
|
||||
scale_factor=self.spatial_scale_factor,
|
||||
mode=self.spatial_scale_mode,
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
# (batch, c, h, w/2+1, 2)
|
||||
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
|
||||
if half_check == True:
|
||||
ffted = torch.fft.rfftn(
|
||||
x.float(), dim=fft_dim, norm=self.fft_norm
|
||||
) # .type(torch.cuda.HalfTensor)
|
||||
else:
|
||||
ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
|
||||
|
||||
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
|
||||
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
|
||||
ffted = ffted.view(
|
||||
(
|
||||
batch,
|
||||
-1,
|
||||
)
|
||||
+ ffted.size()[3:]
|
||||
)
|
||||
|
||||
if self.spectral_pos_encoding:
|
||||
height, width = ffted.shape[-2:]
|
||||
coords_vert = (
|
||||
torch.linspace(0, 1, height)[None, None, :, None]
|
||||
.expand(batch, 1, height, width)
|
||||
.to(ffted)
|
||||
)
|
||||
coords_hor = (
|
||||
torch.linspace(0, 1, width)[None, None, None, :]
|
||||
.expand(batch, 1, height, width)
|
||||
.to(ffted)
|
||||
)
|
||||
ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
|
||||
|
||||
if self.use_se:
|
||||
ffted = self.se(ffted)
|
||||
|
||||
if half_check == True:
|
||||
ffted = self.conv_layer(ffted.half()) # (batch, c*2, h, w/2+1)
|
||||
else:
|
||||
ffted = self.conv_layer(
|
||||
ffted
|
||||
) # .type(torch.cuda.FloatTensor) # (batch, c*2, h, w/2+1)
|
||||
|
||||
ffted = self.relu(self.bn(ffted))
|
||||
# forcing to be always float
|
||||
ffted = ffted.float()
|
||||
|
||||
ffted = (
|
||||
ffted.view(
|
||||
(
|
||||
batch,
|
||||
-1,
|
||||
2,
|
||||
)
|
||||
+ ffted.size()[2:]
|
||||
)
|
||||
.permute(0, 1, 3, 4, 2)
|
||||
.contiguous()
|
||||
) # (batch,c, t, h, w/2+1, 2)
|
||||
|
||||
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
|
||||
|
||||
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
|
||||
output = torch.fft.irfftn(
|
||||
ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm
|
||||
)
|
||||
|
||||
if half_check == True:
|
||||
output = output.half()
|
||||
|
||||
if self.spatial_scale_factor is not None:
|
||||
output = F.interpolate(
|
||||
output,
|
||||
size=orig_size,
|
||||
mode=self.spatial_scale_mode,
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class SpectralTransform(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride=1,
|
||||
groups=1,
|
||||
enable_lfu=True,
|
||||
separable_fu=False,
|
||||
**fu_kwargs,
|
||||
):
|
||||
# bn_layer not used
|
||||
super(SpectralTransform, self).__init__()
|
||||
self.enable_lfu = enable_lfu
|
||||
if stride == 2:
|
||||
self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
|
||||
else:
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
self.stride = stride
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False
|
||||
),
|
||||
nn.BatchNorm2d(out_channels // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
fu_class = FourierUnit
|
||||
self.fu = fu_class(out_channels // 2, out_channels // 2, groups, **fu_kwargs)
|
||||
if self.enable_lfu:
|
||||
self.lfu = fu_class(out_channels // 2, out_channels // 2, groups)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.downsample(x)
|
||||
x = self.conv1(x)
|
||||
output = self.fu(x)
|
||||
|
||||
if self.enable_lfu:
|
||||
_, c, h, _ = x.shape
|
||||
split_no = 2
|
||||
split_s = h // split_no
|
||||
xs = torch.cat(
|
||||
torch.split(x[:, : c // 4], split_s, dim=-2), dim=1
|
||||
).contiguous()
|
||||
xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous()
|
||||
xs = self.lfu(xs)
|
||||
xs = xs.repeat(1, 1, split_no, split_no).contiguous()
|
||||
else:
|
||||
xs = 0
|
||||
|
||||
output = self.conv2(x + output + xs)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FFC(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
ratio_gin,
|
||||
ratio_gout,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=False,
|
||||
enable_lfu=True,
|
||||
padding_type="reflect",
|
||||
gated=False,
|
||||
**spectral_kwargs,
|
||||
):
|
||||
super(FFC, self).__init__()
|
||||
|
||||
assert stride == 1 or stride == 2, "Stride should be 1 or 2."
|
||||
self.stride = stride
|
||||
|
||||
in_cg = int(in_channels * ratio_gin)
|
||||
in_cl = in_channels - in_cg
|
||||
out_cg = int(out_channels * ratio_gout)
|
||||
out_cl = out_channels - out_cg
|
||||
# groups_g = 1 if groups == 1 else int(groups * ratio_gout)
|
||||
# groups_l = 1 if groups == 1 else groups - groups_g
|
||||
|
||||
self.ratio_gin = ratio_gin
|
||||
self.ratio_gout = ratio_gout
|
||||
self.global_in_num = in_cg
|
||||
|
||||
module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
|
||||
self.convl2l = module(
|
||||
in_cl,
|
||||
out_cl,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
padding_mode=padding_type,
|
||||
)
|
||||
module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
|
||||
self.convl2g = module(
|
||||
in_cl,
|
||||
out_cg,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
padding_mode=padding_type,
|
||||
)
|
||||
module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
|
||||
self.convg2l = module(
|
||||
in_cg,
|
||||
out_cl,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
padding_mode=padding_type,
|
||||
)
|
||||
module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
|
||||
self.convg2g = module(
|
||||
in_cg,
|
||||
out_cg,
|
||||
stride,
|
||||
1 if groups == 1 else groups // 2,
|
||||
enable_lfu,
|
||||
**spectral_kwargs,
|
||||
)
|
||||
|
||||
self.gated = gated
|
||||
module = (
|
||||
nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
|
||||
)
|
||||
self.gate = module(in_channels, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x_l, x_g = x if type(x) is tuple else (x, 0)
|
||||
out_xl, out_xg = 0, 0
|
||||
|
||||
if self.gated:
|
||||
total_input_parts = [x_l]
|
||||
if torch.is_tensor(x_g):
|
||||
total_input_parts.append(x_g)
|
||||
total_input = torch.cat(total_input_parts, dim=1)
|
||||
|
||||
gates = torch.sigmoid(self.gate(total_input))
|
||||
g2l_gate, l2g_gate = gates.chunk(2, dim=1)
|
||||
else:
|
||||
g2l_gate, l2g_gate = 1, 1
|
||||
|
||||
if self.ratio_gout != 1:
|
||||
out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
|
||||
if self.ratio_gout != 0:
|
||||
out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
|
||||
|
||||
return out_xl, out_xg
|
||||
|
||||
|
||||
class FFC_BN_ACT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
ratio_gin,
|
||||
ratio_gout,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=False,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
activation_layer=nn.Identity,
|
||||
padding_type="reflect",
|
||||
enable_lfu=True,
|
||||
**kwargs,
|
||||
):
|
||||
super(FFC_BN_ACT, self).__init__()
|
||||
self.ffc = FFC(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
ratio_gin,
|
||||
ratio_gout,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
enable_lfu,
|
||||
padding_type=padding_type,
|
||||
**kwargs,
|
||||
)
|
||||
lnorm = nn.Identity if ratio_gout == 1 else norm_layer
|
||||
gnorm = nn.Identity if ratio_gout == 0 else norm_layer
|
||||
global_channels = int(out_channels * ratio_gout)
|
||||
self.bn_l = lnorm(out_channels - global_channels)
|
||||
self.bn_g = gnorm(global_channels)
|
||||
|
||||
lact = nn.Identity if ratio_gout == 1 else activation_layer
|
||||
gact = nn.Identity if ratio_gout == 0 else activation_layer
|
||||
self.act_l = lact(inplace=True)
|
||||
self.act_g = gact(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x_l, x_g = self.ffc(x)
|
||||
x_l = self.act_l(self.bn_l(x_l))
|
||||
x_g = self.act_g(self.bn_g(x_g))
|
||||
return x_l, x_g
|
||||
|
||||
|
||||
class FFCResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
padding_type,
|
||||
norm_layer,
|
||||
activation_layer=nn.ReLU,
|
||||
dilation=1,
|
||||
spatial_transform_kwargs=None,
|
||||
inline=False,
|
||||
**conv_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv1 = FFC_BN_ACT(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=3,
|
||||
padding=dilation,
|
||||
dilation=dilation,
|
||||
norm_layer=norm_layer,
|
||||
activation_layer=activation_layer,
|
||||
padding_type=padding_type,
|
||||
**conv_kwargs,
|
||||
)
|
||||
self.conv2 = FFC_BN_ACT(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=3,
|
||||
padding=dilation,
|
||||
dilation=dilation,
|
||||
norm_layer=norm_layer,
|
||||
activation_layer=activation_layer,
|
||||
padding_type=padding_type,
|
||||
**conv_kwargs,
|
||||
)
|
||||
if spatial_transform_kwargs is not None:
|
||||
self.conv1 = LearnableSpatialTransformWrapper(
|
||||
self.conv1, **spatial_transform_kwargs
|
||||
)
|
||||
self.conv2 = LearnableSpatialTransformWrapper(
|
||||
self.conv2, **spatial_transform_kwargs
|
||||
)
|
||||
self.inline = inline
|
||||
|
||||
def forward(self, x):
|
||||
if self.inline:
|
||||
x_l, x_g = (
|
||||
x[:, : -self.conv1.ffc.global_in_num],
|
||||
x[:, -self.conv1.ffc.global_in_num :],
|
||||
)
|
||||
else:
|
||||
x_l, x_g = x if type(x) is tuple else (x, 0)
|
||||
|
||||
id_l, id_g = x_l, x_g
|
||||
|
||||
x_l, x_g = self.conv1((x_l, x_g))
|
||||
x_l, x_g = self.conv2((x_l, x_g))
|
||||
|
||||
x_l, x_g = id_l + x_l, id_g + x_g
|
||||
out = x_l, x_g
|
||||
if self.inline:
|
||||
out = torch.cat(out, dim=1)
|
||||
return out
|
||||
|
||||
|
||||
class ConcatTupleLayer(nn.Module):
|
||||
def forward(self, x):
|
||||
assert isinstance(x, tuple)
|
||||
x_l, x_g = x
|
||||
assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
|
||||
if not torch.is_tensor(x_g):
|
||||
return x_l
|
||||
return torch.cat(x, dim=1)
|
||||
|
||||
|
||||
class FFCResNetGenerator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_nc,
|
||||
output_nc,
|
||||
ngf=64,
|
||||
n_downsampling=3,
|
||||
n_blocks=18,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
padding_type="reflect",
|
||||
activation_layer=nn.ReLU,
|
||||
up_norm_layer=nn.BatchNorm2d,
|
||||
up_activation=nn.ReLU(True),
|
||||
init_conv_kwargs={},
|
||||
downsample_conv_kwargs={},
|
||||
resnet_conv_kwargs={},
|
||||
spatial_transform_layers=None,
|
||||
spatial_transform_kwargs={},
|
||||
max_features=1024,
|
||||
out_ffc=False,
|
||||
out_ffc_kwargs={},
|
||||
):
|
||||
assert n_blocks >= 0
|
||||
super().__init__()
|
||||
"""
|
||||
init_conv_kwargs = {'ratio_gin': 0, 'ratio_gout': 0, 'enable_lfu': False}
|
||||
downsample_conv_kwargs = {'ratio_gin': '${generator.init_conv_kwargs.ratio_gout}', 'ratio_gout': '${generator.downsample_conv_kwargs.ratio_gin}', 'enable_lfu': False}
|
||||
resnet_conv_kwargs = {'ratio_gin': 0.75, 'ratio_gout': '${generator.resnet_conv_kwargs.ratio_gin}', 'enable_lfu': False}
|
||||
spatial_transform_kwargs = {}
|
||||
out_ffc_kwargs = {}
|
||||
"""
|
||||
"""
|
||||
print(input_nc, output_nc, ngf, n_downsampling, n_blocks, norm_layer,
|
||||
padding_type, activation_layer,
|
||||
up_norm_layer, up_activation,
|
||||
spatial_transform_layers,
|
||||
add_out_act, max_features, out_ffc, file=sys.stderr)
|
||||
|
||||
4 3 64 3 18 <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
|
||||
reflect <class 'torch.nn.modules.activation.ReLU'>
|
||||
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
|
||||
ReLU(inplace=True)
|
||||
None sigmoid 1024 False
|
||||
"""
|
||||
init_conv_kwargs = {"ratio_gin": 0, "ratio_gout": 0, "enable_lfu": False}
|
||||
downsample_conv_kwargs = {"ratio_gin": 0, "ratio_gout": 0, "enable_lfu": False}
|
||||
resnet_conv_kwargs = {
|
||||
"ratio_gin": 0.75,
|
||||
"ratio_gout": 0.75,
|
||||
"enable_lfu": False,
|
||||
}
|
||||
spatial_transform_kwargs = {}
|
||||
out_ffc_kwargs = {}
|
||||
|
||||
model = [
|
||||
nn.ReflectionPad2d(3),
|
||||
FFC_BN_ACT(
|
||||
input_nc,
|
||||
ngf,
|
||||
kernel_size=7,
|
||||
padding=0,
|
||||
norm_layer=norm_layer,
|
||||
activation_layer=activation_layer,
|
||||
**init_conv_kwargs,
|
||||
),
|
||||
]
|
||||
|
||||
### downsample
|
||||
for i in range(n_downsampling):
|
||||
mult = 2**i
|
||||
if i == n_downsampling - 1:
|
||||
cur_conv_kwargs = dict(downsample_conv_kwargs)
|
||||
cur_conv_kwargs["ratio_gout"] = resnet_conv_kwargs.get("ratio_gin", 0)
|
||||
else:
|
||||
cur_conv_kwargs = downsample_conv_kwargs
|
||||
model += [
|
||||
FFC_BN_ACT(
|
||||
min(max_features, ngf * mult),
|
||||
min(max_features, ngf * mult * 2),
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_layer=norm_layer,
|
||||
activation_layer=activation_layer,
|
||||
**cur_conv_kwargs,
|
||||
)
|
||||
]
|
||||
|
||||
mult = 2**n_downsampling
|
||||
feats_num_bottleneck = min(max_features, ngf * mult)
|
||||
|
||||
### resnet blocks
|
||||
for i in range(n_blocks):
|
||||
cur_resblock = FFCResnetBlock(
|
||||
feats_num_bottleneck,
|
||||
padding_type=padding_type,
|
||||
activation_layer=activation_layer,
|
||||
norm_layer=norm_layer,
|
||||
**resnet_conv_kwargs,
|
||||
)
|
||||
if spatial_transform_layers is not None and i in spatial_transform_layers:
|
||||
cur_resblock = LearnableSpatialTransformWrapper(
|
||||
cur_resblock, **spatial_transform_kwargs
|
||||
)
|
||||
model += [cur_resblock]
|
||||
|
||||
model += [ConcatTupleLayer()]
|
||||
|
||||
### upsample
|
||||
for i in range(n_downsampling):
|
||||
mult = 2 ** (n_downsampling - i)
|
||||
model += [
|
||||
nn.ConvTranspose2d(
|
||||
min(max_features, ngf * mult),
|
||||
min(max_features, int(ngf * mult / 2)),
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
output_padding=1,
|
||||
),
|
||||
up_norm_layer(min(max_features, int(ngf * mult / 2))),
|
||||
up_activation,
|
||||
]
|
||||
|
||||
if out_ffc:
|
||||
model += [
|
||||
FFCResnetBlock(
|
||||
ngf,
|
||||
padding_type=padding_type,
|
||||
activation_layer=activation_layer,
|
||||
norm_layer=norm_layer,
|
||||
inline=True,
|
||||
**out_ffc_kwargs,
|
||||
)
|
||||
]
|
||||
|
||||
model += [
|
||||
nn.ReflectionPad2d(3),
|
||||
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
|
||||
]
|
||||
model.append(nn.Sigmoid())
|
||||
self.model = nn.Sequential(*model)
|
||||
|
||||
def forward(self, image, mask):
|
||||
return self.model(torch.cat([image, mask], dim=1))
|
||||
|
||||
|
||||
class LaMa(nn.Module):
|
||||
def __init__(self, state_dict) -> None:
|
||||
super(LaMa, self).__init__()
|
||||
self.model_arch = "LaMa"
|
||||
self.sub_type = "Inpaint"
|
||||
self.in_nc = 4
|
||||
self.out_nc = 3
|
||||
self.scale = 1
|
||||
|
||||
self.min_size = None
|
||||
self.pad_mod = 8
|
||||
self.pad_to_square = False
|
||||
|
||||
self.model = FFCResNetGenerator(self.in_nc, self.out_nc)
|
||||
self.state = {
|
||||
k.replace("generator.model", "model.model"): v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
self.supports_fp16 = False
|
||||
self.support_bf16 = True
|
||||
|
||||
self.load_state_dict(self.state, strict=False)
|
||||
|
||||
def forward(self, img, mask):
|
||||
masked_img = img * (1 - mask)
|
||||
inpainted_mask = mask * self.model.forward(masked_img, mask)
|
||||
result = inpainted_mask + (1 - mask) * img
|
||||
return result
|
@ -1,110 +0,0 @@
|
||||
import math
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class CA_layer(nn.Module):
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(CA_layer, self).__init__()
|
||||
# global average pooling
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False),
|
||||
# nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.fc(self.gap(x))
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
class Simple_CA_layer(nn.Module):
|
||||
def __init__(self, channel):
|
||||
super(Simple_CA_layer, self).__init__()
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Conv2d(
|
||||
in_channels=channel,
|
||||
out_channels=channel,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.fc(self.gap(x))
|
||||
|
||||
|
||||
class ECA_layer(nn.Module):
|
||||
"""Constructs a ECA module.
|
||||
Args:
|
||||
channel: Number of channels of the input feature map
|
||||
k_size: Adaptive selection of kernel size
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
super(ECA_layer, self).__init__()
|
||||
|
||||
b = 1
|
||||
gamma = 2
|
||||
k_size = int(abs(math.log(channel, 2) + b) / gamma)
|
||||
k_size = k_size if k_size % 2 else k_size + 1
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv = nn.Conv1d(
|
||||
1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
|
||||
)
|
||||
# self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
# x: input features with shape [b, c, h, w]
|
||||
# b, c, h, w = x.size()
|
||||
|
||||
# feature descriptor on the global spatial information
|
||||
y = self.avg_pool(x)
|
||||
|
||||
# Two different branches of ECA module
|
||||
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||
|
||||
# Multi-scale information fusion
|
||||
# y = self.sigmoid(y)
|
||||
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
class ECA_MaxPool_layer(nn.Module):
|
||||
"""Constructs a ECA module.
|
||||
Args:
|
||||
channel: Number of channels of the input feature map
|
||||
k_size: Adaptive selection of kernel size
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
super(ECA_MaxPool_layer, self).__init__()
|
||||
|
||||
b = 1
|
||||
gamma = 2
|
||||
k_size = int(abs(math.log(channel, 2) + b) / gamma)
|
||||
k_size = k_size if k_size % 2 else k_size + 1
|
||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
||||
self.conv = nn.Conv1d(
|
||||
1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
|
||||
)
|
||||
# self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
# x: input features with shape [b, c, h, w]
|
||||
# b, c, h, w = x.size()
|
||||
|
||||
# feature descriptor on the global spatial information
|
||||
y = self.max_pool(x)
|
||||
|
||||
# Two different branches of ECA module
|
||||
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||
|
||||
# Multi-scale information fusion
|
||||
# y = self.sigmoid(y)
|
||||
|
||||
return x * y.expand_as(x)
|
@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -1,577 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: OSA.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 23rd April 2023 3:07:42 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
from torch import einsum, nn
|
||||
|
||||
from .layernorm import LayerNorm2d
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
def cast_tuple(val, length=1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
|
||||
# helper classes
|
||||
|
||||
|
||||
class PreNormResidual(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x)) + x
|
||||
|
||||
|
||||
class Conv_PreNormResidual(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm2d(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x)) + x
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult=2, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Conv_FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult=2, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, inner_dim, 1, 1, 0),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(inner_dim, dim, 1, 1, 0),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Gated_Conv_FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult=1, bias=False, dropout=0.0):
|
||||
super().__init__()
|
||||
|
||||
hidden_features = int(dim * mult)
|
||||
|
||||
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
|
||||
|
||||
self.dwconv = nn.Conv2d(
|
||||
hidden_features * 2,
|
||||
hidden_features * 2,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=hidden_features * 2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.project_in(x)
|
||||
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
||||
x = F.gelu(x1) * x2
|
||||
x = self.project_out(x)
|
||||
return x
|
||||
|
||||
|
||||
# MBConv
|
||||
|
||||
|
||||
class SqueezeExcitation(nn.Module):
|
||||
def __init__(self, dim, shrinkage_rate=0.25):
|
||||
super().__init__()
|
||||
hidden_dim = int(dim * shrinkage_rate)
|
||||
|
||||
self.gate = nn.Sequential(
|
||||
Reduce("b c h w -> b c", "mean"),
|
||||
nn.Linear(dim, hidden_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_dim, dim, bias=False),
|
||||
nn.Sigmoid(),
|
||||
Rearrange("b c -> b c 1 1"),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.gate(x)
|
||||
|
||||
|
||||
class MBConvResidual(nn.Module):
|
||||
def __init__(self, fn, dropout=0.0):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.dropsample = Dropsample(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fn(x)
|
||||
out = self.dropsample(out)
|
||||
return out + x
|
||||
|
||||
|
||||
class Dropsample(nn.Module):
|
||||
def __init__(self, prob=0):
|
||||
super().__init__()
|
||||
self.prob = prob
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
|
||||
if self.prob == 0.0 or (not self.training):
|
||||
return x
|
||||
|
||||
keep_mask = (
|
||||
torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_()
|
||||
> self.prob
|
||||
)
|
||||
return x * keep_mask / (1 - self.prob)
|
||||
|
||||
|
||||
def MBConv(
|
||||
dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
|
||||
):
|
||||
hidden_dim = int(expansion_rate * dim_out)
|
||||
stride = 2 if downsample else 1
|
||||
|
||||
net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, hidden_dim, 1),
|
||||
# nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(
|
||||
hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
|
||||
),
|
||||
# nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
|
||||
nn.Conv2d(hidden_dim, dim_out, 1),
|
||||
# nn.BatchNorm2d(dim_out)
|
||||
)
|
||||
|
||||
if dim_in == dim_out and not downsample:
|
||||
net = MBConvResidual(net, dropout=dropout)
|
||||
|
||||
return net
|
||||
|
||||
|
||||
# attention related classes
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head=32,
|
||||
dropout=0.0,
|
||||
window_size=7,
|
||||
with_pe=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
dim % dim_head
|
||||
) == 0, "dimension should be divisible by dimension per head"
|
||||
|
||||
self.heads = dim // dim_head
|
||||
self.scale = dim_head**-0.5
|
||||
self.with_pe = with_pe
|
||||
|
||||
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
|
||||
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
# relative positional bias
|
||||
if self.with_pe:
|
||||
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
|
||||
|
||||
pos = torch.arange(window_size)
|
||||
grid = torch.stack(torch.meshgrid(pos, pos))
|
||||
grid = rearrange(grid, "c i j -> (i j) c")
|
||||
rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
|
||||
grid, "j ... -> 1 j ..."
|
||||
)
|
||||
rel_pos += window_size - 1
|
||||
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
|
||||
dim=-1
|
||||
)
|
||||
|
||||
self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
batch, height, width, window_height, window_width, _, device, h = (
|
||||
*x.shape,
|
||||
x.device,
|
||||
self.heads,
|
||||
)
|
||||
|
||||
# flatten
|
||||
|
||||
x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")
|
||||
|
||||
# project for queries, keys, values
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v))
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# sim
|
||||
|
||||
sim = einsum("b h i d, b h j d -> b h i j", q, k)
|
||||
|
||||
# add positional bias
|
||||
if self.with_pe:
|
||||
bias = self.rel_pos_bias(self.rel_pos_indices)
|
||||
sim = sim + rearrange(bias, "i j h -> h i j")
|
||||
|
||||
# attention
|
||||
|
||||
attn = self.attend(sim)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = rearrange(
|
||||
out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
|
||||
)
|
||||
|
||||
# combine heads out
|
||||
|
||||
out = self.to_out(out)
|
||||
return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
|
||||
|
||||
|
||||
class Block_Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head=32,
|
||||
bias=False,
|
||||
dropout=0.0,
|
||||
window_size=7,
|
||||
with_pe=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
dim % dim_head
|
||||
) == 0, "dimension should be divisible by dimension per head"
|
||||
|
||||
self.heads = dim // dim_head
|
||||
self.ps = window_size
|
||||
self.scale = dim_head**-0.5
|
||||
self.with_pe = with_pe
|
||||
|
||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||
self.qkv_dwconv = nn.Conv2d(
|
||||
dim * 3,
|
||||
dim * 3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=dim * 3,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
|
||||
|
||||
self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
# project for queries, keys, values
|
||||
b, c, h, w = x.shape
|
||||
|
||||
qkv = self.qkv_dwconv(self.qkv(x))
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(
|
||||
t,
|
||||
"b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
|
||||
h=self.heads,
|
||||
w1=self.ps,
|
||||
w2=self.ps,
|
||||
),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# sim
|
||||
|
||||
sim = einsum("b h i d, b h j d -> b h i j", q, k)
|
||||
|
||||
# attention
|
||||
attn = self.attend(sim)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
||||
|
||||
# merge heads
|
||||
out = rearrange(
|
||||
out,
|
||||
"(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
|
||||
x=h // self.ps,
|
||||
y=w // self.ps,
|
||||
head=self.heads,
|
||||
w1=self.ps,
|
||||
w2=self.ps,
|
||||
)
|
||||
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
|
||||
class Channel_Attention(nn.Module):
|
||||
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
|
||||
super(Channel_Attention, self).__init__()
|
||||
self.heads = heads
|
||||
|
||||
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
||||
|
||||
self.ps = window_size
|
||||
|
||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||
self.qkv_dwconv = nn.Conv2d(
|
||||
dim * 3,
|
||||
dim * 3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=dim * 3,
|
||||
bias=bias,
|
||||
)
|
||||
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
qkv = self.qkv_dwconv(self.qkv(x))
|
||||
qkv = qkv.chunk(3, dim=1)
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(
|
||||
t,
|
||||
"b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
),
|
||||
qkv,
|
||||
)
|
||||
|
||||
q = F.normalize(q, dim=-1)
|
||||
k = F.normalize(k, dim=-1)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
||||
attn = attn.softmax(dim=-1)
|
||||
out = attn @ v
|
||||
|
||||
out = rearrange(
|
||||
out,
|
||||
"b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
|
||||
h=h // self.ps,
|
||||
w=w // self.ps,
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
)
|
||||
|
||||
out = self.project_out(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Channel_Attention_grid(nn.Module):
|
||||
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
|
||||
super(Channel_Attention_grid, self).__init__()
|
||||
self.heads = heads
|
||||
|
||||
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
||||
|
||||
self.ps = window_size
|
||||
|
||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||
self.qkv_dwconv = nn.Conv2d(
|
||||
dim * 3,
|
||||
dim * 3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=dim * 3,
|
||||
bias=bias,
|
||||
)
|
||||
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
qkv = self.qkv_dwconv(self.qkv(x))
|
||||
qkv = qkv.chunk(3, dim=1)
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(
|
||||
t,
|
||||
"b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
),
|
||||
qkv,
|
||||
)
|
||||
|
||||
q = F.normalize(q, dim=-1)
|
||||
k = F.normalize(k, dim=-1)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
||||
attn = attn.softmax(dim=-1)
|
||||
out = attn @ v
|
||||
|
||||
out = rearrange(
|
||||
out,
|
||||
"b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
|
||||
h=h // self.ps,
|
||||
w=w // self.ps,
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
)
|
||||
|
||||
out = self.project_out(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class OSA_Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channel_num=64,
|
||||
bias=True,
|
||||
ffn_bias=True,
|
||||
window_size=8,
|
||||
with_pe=False,
|
||||
dropout=0.0,
|
||||
):
|
||||
super(OSA_Block, self).__init__()
|
||||
|
||||
w = window_size
|
||||
|
||||
self.layer = nn.Sequential(
|
||||
MBConv(
|
||||
channel_num,
|
||||
channel_num,
|
||||
downsample=False,
|
||||
expansion_rate=1,
|
||||
shrinkage_rate=0.25,
|
||||
),
|
||||
Rearrange(
|
||||
"b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w
|
||||
), # block-like attention
|
||||
PreNormResidual(
|
||||
channel_num,
|
||||
Attention(
|
||||
dim=channel_num,
|
||||
dim_head=channel_num // 4,
|
||||
dropout=dropout,
|
||||
window_size=window_size,
|
||||
with_pe=with_pe,
|
||||
),
|
||||
),
|
||||
Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
# channel-like attention
|
||||
Conv_PreNormResidual(
|
||||
channel_num,
|
||||
Channel_Attention(
|
||||
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
|
||||
),
|
||||
),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
Rearrange(
|
||||
"b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w
|
||||
), # grid-like attention
|
||||
PreNormResidual(
|
||||
channel_num,
|
||||
Attention(
|
||||
dim=channel_num,
|
||||
dim_head=channel_num // 4,
|
||||
dropout=dropout,
|
||||
window_size=window_size,
|
||||
with_pe=with_pe,
|
||||
),
|
||||
),
|
||||
Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
# channel-like attention
|
||||
Conv_PreNormResidual(
|
||||
channel_num,
|
||||
Channel_Attention_grid(
|
||||
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
|
||||
),
|
||||
),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.layer(x)
|
||||
return out
|
@ -1,60 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: OSAG.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 23rd April 2023 3:08:49 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .esa import ESA
|
||||
from .OSA import OSA_Block
|
||||
|
||||
|
||||
class OSAG(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channel_num=64,
|
||||
bias=True,
|
||||
block_num=4,
|
||||
ffn_bias=False,
|
||||
window_size=0,
|
||||
pe=False,
|
||||
):
|
||||
super(OSAG, self).__init__()
|
||||
|
||||
# print("window_size: %d" % (window_size))
|
||||
# print("with_pe", pe)
|
||||
# print("ffn_bias: %d" % (ffn_bias))
|
||||
|
||||
# block_script_name = kwargs.get("block_script_name", "OSA")
|
||||
# block_class_name = kwargs.get("block_class_name", "OSA_Block")
|
||||
|
||||
# script_name = "." + block_script_name
|
||||
# package = __import__(script_name, fromlist=True)
|
||||
block_class = OSA_Block # getattr(package, block_class_name)
|
||||
group_list = []
|
||||
for _ in range(block_num):
|
||||
temp_res = block_class(
|
||||
channel_num,
|
||||
bias,
|
||||
ffn_bias=ffn_bias,
|
||||
window_size=window_size,
|
||||
with_pe=pe,
|
||||
)
|
||||
group_list.append(temp_res)
|
||||
group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias))
|
||||
self.residual_layer = nn.Sequential(*group_list)
|
||||
esa_channel = max(channel_num // 4, 16)
|
||||
self.esa = ESA(esa_channel, channel_num)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.residual_layer(x)
|
||||
out = out + x
|
||||
return self.esa(out)
|
@ -1,143 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: OmniSR.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 23rd April 2023 3:06:36 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .OSAG import OSAG
|
||||
from .pixelshuffle import pixelshuffle_block
|
||||
|
||||
|
||||
class OmniSR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
state_dict,
|
||||
**kwargs,
|
||||
):
|
||||
super(OmniSR, self).__init__()
|
||||
self.state = state_dict
|
||||
|
||||
bias = True # Fine to assume this for now
|
||||
block_num = 1 # Fine to assume this for now
|
||||
ffn_bias = True
|
||||
pe = True
|
||||
|
||||
num_feat = state_dict["input.weight"].shape[0] or 64
|
||||
num_in_ch = state_dict["input.weight"].shape[1] or 3
|
||||
num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh
|
||||
|
||||
pixelshuffle_shape = state_dict["up.0.weight"].shape[0]
|
||||
up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)
|
||||
if up_scale - int(up_scale) > 0:
|
||||
print(
|
||||
"out_nc is probably different than in_nc, scale calculation might be wrong"
|
||||
)
|
||||
up_scale = int(up_scale)
|
||||
res_num = 0
|
||||
for key in state_dict.keys():
|
||||
if "residual_layer" in key:
|
||||
temp_res_num = int(key.split(".")[1])
|
||||
if temp_res_num > res_num:
|
||||
res_num = temp_res_num
|
||||
res_num = res_num + 1 # zero-indexed
|
||||
|
||||
residual_layer = []
|
||||
self.res_num = res_num
|
||||
|
||||
if (
|
||||
"residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight"
|
||||
in state_dict.keys()
|
||||
):
|
||||
rel_pos_bias_weight = state_dict[
|
||||
"residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight"
|
||||
].shape[0]
|
||||
self.window_size = int((math.sqrt(rel_pos_bias_weight) + 1) / 2)
|
||||
else:
|
||||
self.window_size = 8
|
||||
|
||||
self.up_scale = up_scale
|
||||
|
||||
for _ in range(res_num):
|
||||
temp_res = OSAG(
|
||||
channel_num=num_feat,
|
||||
bias=bias,
|
||||
block_num=block_num,
|
||||
ffn_bias=ffn_bias,
|
||||
window_size=self.window_size,
|
||||
pe=pe,
|
||||
)
|
||||
residual_layer.append(temp_res)
|
||||
self.residual_layer = nn.Sequential(*residual_layer)
|
||||
self.input = nn.Conv2d(
|
||||
in_channels=num_in_ch,
|
||||
out_channels=num_feat,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
self.output = nn.Conv2d(
|
||||
in_channels=num_feat,
|
||||
out_channels=num_feat,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias)
|
||||
|
||||
# self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias)
|
||||
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, nn.Conv2d):
|
||||
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
# m.weight.data.normal_(0, sqrt(2. / n))
|
||||
|
||||
# chaiNNer specific stuff
|
||||
self.model_arch = "OmniSR"
|
||||
self.sub_type = "SR"
|
||||
self.in_nc = num_in_ch
|
||||
self.out_nc = num_out_ch
|
||||
self.num_feat = num_feat
|
||||
self.scale = up_scale
|
||||
|
||||
self.supports_fp16 = True # TODO: Test this
|
||||
self.supports_bfp16 = True
|
||||
self.min_size_restriction = 16
|
||||
|
||||
self.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def check_image_size(self, x):
|
||||
_, _, h, w = x.size()
|
||||
# import pdb; pdb.set_trace()
|
||||
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
||||
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
||||
# x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
||||
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
H, W = x.shape[2:]
|
||||
x = self.check_image_size(x)
|
||||
|
||||
residual = self.input(x)
|
||||
out = self.residual_layer(residual)
|
||||
|
||||
# origin
|
||||
out = torch.add(self.output(out), residual)
|
||||
out = self.up(out)
|
||||
|
||||
out = out[:, :, : H * self.up_scale, : W * self.up_scale]
|
||||
return out
|
@ -1,294 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: esa.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 20th April 2023 9:28:06 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .layernorm import LayerNorm2d
|
||||
|
||||
|
||||
def moment(x, dim=(2, 3), k=2):
|
||||
assert len(x.size()) == 4
|
||||
mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1)
|
||||
mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim)
|
||||
return mk
|
||||
|
||||
|
||||
class ESA(nn.Module):
|
||||
"""
|
||||
Modification of Enhanced Spatial Attention (ESA), which is proposed by
|
||||
`Residual Feature Aggregation Network for Image Super-Resolution`
|
||||
Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes
|
||||
are deleted.
|
||||
"""
|
||||
|
||||
def __init__(self, esa_channels, n_feats, conv=nn.Conv2d):
|
||||
super(ESA, self).__init__()
|
||||
f = esa_channels
|
||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
|
||||
self.conv3 = conv(f, f, kernel_size=3, padding=1)
|
||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
c1_ = self.conv1(x)
|
||||
c1 = self.conv2(c1_)
|
||||
v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
|
||||
c3 = self.conv3(v_max)
|
||||
c3 = F.interpolate(
|
||||
c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False
|
||||
)
|
||||
cf = self.conv_f(c1_)
|
||||
c4 = self.conv4(c3 + cf)
|
||||
m = self.sigmoid(c4)
|
||||
return x * m
|
||||
|
||||
|
||||
class LK_ESA(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(LK_ESA, self).__init__()
|
||||
f = esa_channels
|
||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
|
||||
kernel_size = 17
|
||||
kernel_expand = kernel_expand
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.vec_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, kernel_size),
|
||||
padding=(0, padding),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.vec_conv3x1 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, 3),
|
||||
padding=(0, 1),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.hor_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(kernel_size, 1),
|
||||
padding=(padding, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.hor_conv1x3 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(3, 1),
|
||||
padding=(1, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
c1_ = self.conv1(x)
|
||||
|
||||
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
|
||||
res = self.hor_conv(res) + self.hor_conv1x3(res)
|
||||
|
||||
cf = self.conv_f(c1_)
|
||||
c4 = self.conv4(res + cf)
|
||||
m = self.sigmoid(c4)
|
||||
return x * m
|
||||
|
||||
|
||||
class LK_ESA_LN(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(LK_ESA_LN, self).__init__()
|
||||
f = esa_channels
|
||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
|
||||
kernel_size = 17
|
||||
kernel_expand = kernel_expand
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.norm = LayerNorm2d(n_feats)
|
||||
|
||||
self.vec_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, kernel_size),
|
||||
padding=(0, padding),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.vec_conv3x1 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, 3),
|
||||
padding=(0, 1),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.hor_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(kernel_size, 1),
|
||||
padding=(padding, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.hor_conv1x3 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(3, 1),
|
||||
padding=(1, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
c1_ = self.norm(x)
|
||||
c1_ = self.conv1(c1_)
|
||||
|
||||
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
|
||||
res = self.hor_conv(res) + self.hor_conv1x3(res)
|
||||
|
||||
cf = self.conv_f(c1_)
|
||||
c4 = self.conv4(res + cf)
|
||||
m = self.sigmoid(c4)
|
||||
return x * m
|
||||
|
||||
|
||||
class AdaGuidedFilter(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(AdaGuidedFilter, self).__init__()
|
||||
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Conv2d(
|
||||
in_channels=n_feats,
|
||||
out_channels=1,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.r = 5
|
||||
|
||||
def box_filter(self, x, r):
|
||||
channel = x.shape[1]
|
||||
kernel_size = 2 * r + 1
|
||||
weight = 1.0 / (kernel_size**2)
|
||||
box_kernel = weight * torch.ones(
|
||||
(channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device
|
||||
)
|
||||
output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel)
|
||||
return output
|
||||
|
||||
def forward(self, x):
|
||||
_, _, H, W = x.shape
|
||||
N = self.box_filter(
|
||||
torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r
|
||||
)
|
||||
|
||||
# epsilon = self.fc(self.gap(x))
|
||||
# epsilon = torch.pow(epsilon, 2)
|
||||
epsilon = 1e-2
|
||||
|
||||
mean_x = self.box_filter(x, self.r) / N
|
||||
var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x
|
||||
|
||||
A = var_x / (var_x + epsilon)
|
||||
b = (1 - A) * mean_x
|
||||
m = A * x + b
|
||||
|
||||
# mean_A = self.box_filter(A, self.r) / N
|
||||
# mean_b = self.box_filter(b, self.r) / N
|
||||
# m = mean_A * x + mean_b
|
||||
return x * m
|
||||
|
||||
|
||||
class AdaConvGuidedFilter(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(AdaConvGuidedFilter, self).__init__()
|
||||
f = esa_channels
|
||||
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
|
||||
kernel_size = 17
|
||||
kernel_expand = kernel_expand
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.vec_conv = nn.Conv2d(
|
||||
in_channels=f,
|
||||
out_channels=f,
|
||||
kernel_size=(1, kernel_size),
|
||||
padding=(0, padding),
|
||||
groups=f,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.hor_conv = nn.Conv2d(
|
||||
in_channels=f,
|
||||
out_channels=f,
|
||||
kernel_size=(kernel_size, 1),
|
||||
padding=(padding, 0),
|
||||
groups=f,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Conv2d(
|
||||
in_channels=f,
|
||||
out_channels=f,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.vec_conv(x)
|
||||
y = self.hor_conv(y)
|
||||
|
||||
sigma = torch.pow(y, 2)
|
||||
epsilon = self.fc(self.gap(y))
|
||||
|
||||
weight = sigma / (sigma + epsilon)
|
||||
|
||||
m = weight * x + (1 - weight)
|
||||
|
||||
return x * m
|
@ -1,70 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: layernorm.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 20th April 2023 9:28:20 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LayerNormFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, weight, bias, eps):
|
||||
ctx.eps = eps
|
||||
N, C, H, W = x.size()
|
||||
mu = x.mean(1, keepdim=True)
|
||||
var = (x - mu).pow(2).mean(1, keepdim=True)
|
||||
y = (x - mu) / (var + eps).sqrt()
|
||||
ctx.save_for_backward(y, var, weight)
|
||||
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
eps = ctx.eps
|
||||
|
||||
N, C, H, W = grad_output.size()
|
||||
y, var, weight = ctx.saved_variables
|
||||
g = grad_output * weight.view(1, C, 1, 1)
|
||||
mean_g = g.mean(dim=1, keepdim=True)
|
||||
|
||||
mean_gy = (g * y).mean(dim=1, keepdim=True)
|
||||
gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
|
||||
return (
|
||||
gx,
|
||||
(grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
|
||||
grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class LayerNorm2d(nn.Module):
|
||||
def __init__(self, channels, eps=1e-6):
|
||||
super(LayerNorm2d, self).__init__()
|
||||
self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
|
||||
self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
|
||||
|
||||
|
||||
class GRN(nn.Module):
|
||||
"""GRN (Global Response Normalization) layer"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
|
||||
Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
|
||||
return self.gamma * (x * Nx) + self.beta + x
|
@ -1,31 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: pixelshuffle.py
|
||||
# Created Date: Friday July 1st 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 1st July 2022 10:18:39 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def pixelshuffle_block(
|
||||
in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
|
||||
):
|
||||
"""
|
||||
Upsample features according to `upscale_factor`.
|
||||
"""
|
||||
padding = kernel_size // 2
|
||||
conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels * (upscale_factor**2),
|
||||
kernel_size,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||
return nn.Sequential(*[conv, pixel_shuffle])
|
@ -1,296 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import functools
|
||||
import math
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from . import block as B
|
||||
|
||||
|
||||
# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py
|
||||
# Which enhanced stuff that was already here
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
state_dict,
|
||||
norm=None,
|
||||
act: str = "leakyrelu",
|
||||
upsampler: str = "upconv",
|
||||
mode: B.ConvMode = "CNA",
|
||||
) -> None:
|
||||
"""
|
||||
ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
|
||||
By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
|
||||
and Chen Change Loy.
|
||||
This is old-arch Residual in Residual Dense Block Network and is not
|
||||
the newest revision that's available at github.com/xinntao/ESRGAN.
|
||||
This is on purpose, the newest Network has severely limited the
|
||||
potential use of the Network with no benefits.
|
||||
This network supports model files from both new and old-arch.
|
||||
Args:
|
||||
norm: Normalization layer
|
||||
act: Activation layer
|
||||
upsampler: Upsample layer. upconv, pixel_shuffle
|
||||
mode: Convolution mode
|
||||
"""
|
||||
super(RRDBNet, self).__init__()
|
||||
self.model_arch = "ESRGAN"
|
||||
self.sub_type = "SR"
|
||||
|
||||
self.state = state_dict
|
||||
self.norm = norm
|
||||
self.act = act
|
||||
self.upsampler = upsampler
|
||||
self.mode = mode
|
||||
|
||||
self.state_map = {
|
||||
# currently supports old, new, and newer RRDBNet arch models
|
||||
# ESRGAN, BSRGAN/RealSR, Real-ESRGAN
|
||||
"model.0.weight": ("conv_first.weight",),
|
||||
"model.0.bias": ("conv_first.bias",),
|
||||
"model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
|
||||
"model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
|
||||
r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
|
||||
r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
|
||||
r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
|
||||
),
|
||||
}
|
||||
if "params_ema" in self.state:
|
||||
self.state = self.state["params_ema"]
|
||||
# self.model_arch = "RealESRGAN"
|
||||
self.num_blocks = self.get_num_blocks()
|
||||
self.plus = any("conv1x1" in k for k in self.state.keys())
|
||||
if self.plus:
|
||||
self.model_arch = "ESRGAN+"
|
||||
|
||||
self.state = self.new_to_old_arch(self.state)
|
||||
|
||||
self.key_arr = list(self.state.keys())
|
||||
|
||||
self.in_nc: int = self.state[self.key_arr[0]].shape[1]
|
||||
self.out_nc: int = self.state[self.key_arr[-1]].shape[0]
|
||||
|
||||
self.scale: int = self.get_scale()
|
||||
self.num_filters: int = self.state[self.key_arr[0]].shape[0]
|
||||
|
||||
c2x2 = False
|
||||
if self.state["model.0.weight"].shape[-2] == 2:
|
||||
c2x2 = True
|
||||
self.scale = round(math.sqrt(self.scale / 4))
|
||||
self.model_arch = "ESRGAN-2c2"
|
||||
|
||||
self.supports_fp16 = True
|
||||
self.supports_bfp16 = True
|
||||
self.min_size_restriction = None
|
||||
|
||||
# Detect if pixelunshuffle was used (Real-ESRGAN)
|
||||
if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in (
|
||||
self.in_nc / 4,
|
||||
self.in_nc / 16,
|
||||
):
|
||||
self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
|
||||
else:
|
||||
self.shuffle_factor = None
|
||||
|
||||
upsample_block = {
|
||||
"upconv": B.upconv_block,
|
||||
"pixel_shuffle": B.pixelshuffle_block,
|
||||
}.get(self.upsampler)
|
||||
if upsample_block is None:
|
||||
raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")
|
||||
|
||||
if self.scale == 3:
|
||||
upsample_blocks = upsample_block(
|
||||
in_nc=self.num_filters,
|
||||
out_nc=self.num_filters,
|
||||
upscale_factor=3,
|
||||
act_type=self.act,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
else:
|
||||
upsample_blocks = [
|
||||
upsample_block(
|
||||
in_nc=self.num_filters,
|
||||
out_nc=self.num_filters,
|
||||
act_type=self.act,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
for _ in range(int(math.log(self.scale, 2)))
|
||||
]
|
||||
|
||||
self.model = B.sequential(
|
||||
# fea conv
|
||||
B.conv_block(
|
||||
in_nc=self.in_nc,
|
||||
out_nc=self.num_filters,
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
act_type=None,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
B.ShortcutBlock(
|
||||
B.sequential(
|
||||
# rrdb blocks
|
||||
*[
|
||||
B.RRDB(
|
||||
nf=self.num_filters,
|
||||
kernel_size=3,
|
||||
gc=32,
|
||||
stride=1,
|
||||
bias=True,
|
||||
pad_type="zero",
|
||||
norm_type=self.norm,
|
||||
act_type=self.act,
|
||||
mode="CNA",
|
||||
plus=self.plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
for _ in range(self.num_blocks)
|
||||
],
|
||||
# lr conv
|
||||
B.conv_block(
|
||||
in_nc=self.num_filters,
|
||||
out_nc=self.num_filters,
|
||||
kernel_size=3,
|
||||
norm_type=self.norm,
|
||||
act_type=None,
|
||||
mode=self.mode,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
)
|
||||
),
|
||||
*upsample_blocks,
|
||||
# hr_conv0
|
||||
B.conv_block(
|
||||
in_nc=self.num_filters,
|
||||
out_nc=self.num_filters,
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
act_type=self.act,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
# hr_conv1
|
||||
B.conv_block(
|
||||
in_nc=self.num_filters,
|
||||
out_nc=self.out_nc,
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
act_type=None,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
)
|
||||
|
||||
# Adjust these properties for calculations outside of the model
|
||||
if self.shuffle_factor:
|
||||
self.in_nc //= self.shuffle_factor**2
|
||||
self.scale //= self.shuffle_factor
|
||||
|
||||
self.load_state_dict(self.state, strict=False)
|
||||
|
||||
def new_to_old_arch(self, state):
|
||||
"""Convert a new-arch model state dictionary to an old-arch dictionary."""
|
||||
if "params_ema" in state:
|
||||
state = state["params_ema"]
|
||||
|
||||
if "conv_first.weight" not in state:
|
||||
# model is already old arch, this is a loose check, but should be sufficient
|
||||
return state
|
||||
|
||||
# add nb to state keys
|
||||
for kind in ("weight", "bias"):
|
||||
self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
|
||||
f"model.1.sub./NB/.{kind}"
|
||||
]
|
||||
del self.state_map[f"model.1.sub./NB/.{kind}"]
|
||||
|
||||
old_state = OrderedDict()
|
||||
for old_key, new_keys in self.state_map.items():
|
||||
for new_key in new_keys:
|
||||
if r"\1" in old_key:
|
||||
for k, v in state.items():
|
||||
sub = re.sub(new_key, old_key, k)
|
||||
if sub != k:
|
||||
old_state[sub] = v
|
||||
else:
|
||||
if new_key in state:
|
||||
old_state[old_key] = state[new_key]
|
||||
|
||||
# upconv layers
|
||||
max_upconv = 0
|
||||
for key in state.keys():
|
||||
match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key)
|
||||
if match is not None:
|
||||
_, key_num, key_type = match.groups()
|
||||
old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key]
|
||||
max_upconv = max(max_upconv, int(key_num) * 3)
|
||||
|
||||
# final layers
|
||||
for key in state.keys():
|
||||
if key in ("HRconv.weight", "conv_hr.weight"):
|
||||
old_state[f"model.{max_upconv + 2}.weight"] = state[key]
|
||||
elif key in ("HRconv.bias", "conv_hr.bias"):
|
||||
old_state[f"model.{max_upconv + 2}.bias"] = state[key]
|
||||
elif key in ("conv_last.weight",):
|
||||
old_state[f"model.{max_upconv + 4}.weight"] = state[key]
|
||||
elif key in ("conv_last.bias",):
|
||||
old_state[f"model.{max_upconv + 4}.bias"] = state[key]
|
||||
|
||||
# Sort by first numeric value of each layer
|
||||
def compare(item1, item2):
|
||||
parts1 = item1.split(".")
|
||||
parts2 = item2.split(".")
|
||||
int1 = int(parts1[1])
|
||||
int2 = int(parts2[1])
|
||||
return int1 - int2
|
||||
|
||||
sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))
|
||||
|
||||
# Rebuild the output dict in the right order
|
||||
out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)
|
||||
|
||||
return out_dict
|
||||
|
||||
def get_scale(self, min_part: int = 6) -> int:
|
||||
n = 0
|
||||
for part in list(self.state):
|
||||
parts = part.split(".")[1:]
|
||||
if len(parts) == 2:
|
||||
part_num = int(parts[0])
|
||||
if part_num > min_part and parts[1] == "weight":
|
||||
n += 1
|
||||
return 2**n
|
||||
|
||||
def get_num_blocks(self) -> int:
|
||||
nbs = []
|
||||
state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
|
||||
r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
|
||||
)
|
||||
for state_key in state_keys:
|
||||
for k in self.state:
|
||||
m = re.search(state_key, k)
|
||||
if m:
|
||||
nbs.append(int(m.group(1)))
|
||||
if nbs:
|
||||
break
|
||||
return max(*nbs) + 1
|
||||
|
||||
def forward(self, x):
|
||||
if self.shuffle_factor:
|
||||
_, _, h, w = x.size()
|
||||
mod_pad_h = (
|
||||
self.shuffle_factor - h % self.shuffle_factor
|
||||
) % self.shuffle_factor
|
||||
mod_pad_w = (
|
||||
self.shuffle_factor - w % self.shuffle_factor
|
||||
) % self.shuffle_factor
|
||||
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
|
||||
x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor)
|
||||
x = self.model(x)
|
||||
return x[:, :, : h * self.scale, : w * self.scale]
|
||||
return self.model(x)
|
@ -1,455 +0,0 @@
|
||||
# pylint: skip-file
|
||||
# -----------------------------------------------------------------------------------
|
||||
# SCUNet: Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis, https://arxiv.org/abs/2203.13278
|
||||
# Zhang, Kai and Li, Yawei and Liang, Jingyun and Cao, Jiezhang and Zhang, Yulun and Tang, Hao and Timofte, Radu and Van Gool, Luc
|
||||
# -----------------------------------------------------------------------------------
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
from .timm.drop import DropPath
|
||||
from .timm.weight_init import trunc_normal_
|
||||
|
||||
|
||||
# Borrowed from https://github.com/cszn/SCUNet/blob/main/models/network_scunet.py
|
||||
class WMSA(nn.Module):
|
||||
"""Self-attention module in Swin Transformer"""
|
||||
|
||||
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
|
||||
super(WMSA, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.head_dim = head_dim
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.n_heads = input_dim // head_dim
|
||||
self.window_size = window_size
|
||||
self.type = type
|
||||
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
|
||||
|
||||
self.relative_position_params = nn.Parameter(
|
||||
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
|
||||
)
|
||||
# TODO recover
|
||||
# self.relative_position_params = nn.Parameter(torch.zeros(self.n_heads, 2 * window_size - 1, 2 * window_size -1))
|
||||
self.relative_position_params = nn.Parameter(
|
||||
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
|
||||
)
|
||||
|
||||
self.linear = nn.Linear(self.input_dim, self.output_dim)
|
||||
|
||||
trunc_normal_(self.relative_position_params, std=0.02)
|
||||
self.relative_position_params = torch.nn.Parameter(
|
||||
self.relative_position_params.view(
|
||||
2 * window_size - 1, 2 * window_size - 1, self.n_heads
|
||||
)
|
||||
.transpose(1, 2)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
def generate_mask(self, h, w, p, shift):
|
||||
"""generating the mask of SW-MSA
|
||||
Args:
|
||||
shift: shift parameters in CyclicShift.
|
||||
Returns:
|
||||
attn_mask: should be (1 1 w p p),
|
||||
"""
|
||||
# supporting square.
|
||||
attn_mask = torch.zeros(
|
||||
h,
|
||||
w,
|
||||
p,
|
||||
p,
|
||||
p,
|
||||
p,
|
||||
dtype=torch.bool,
|
||||
device=self.relative_position_params.device,
|
||||
)
|
||||
if self.type == "W":
|
||||
return attn_mask
|
||||
|
||||
s = p - shift
|
||||
attn_mask[-1, :, :s, :, s:, :] = True
|
||||
attn_mask[-1, :, s:, :, :s, :] = True
|
||||
attn_mask[:, -1, :, :s, :, s:] = True
|
||||
attn_mask[:, -1, :, s:, :, :s] = True
|
||||
attn_mask = rearrange(
|
||||
attn_mask, "w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)"
|
||||
)
|
||||
return attn_mask
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass of Window Multi-head Self-attention module.
|
||||
Args:
|
||||
x: input tensor with shape of [b h w c];
|
||||
attn_mask: attention mask, fill -inf where the value is True;
|
||||
Returns:
|
||||
output: tensor shape [b h w c]
|
||||
"""
|
||||
if self.type != "W":
|
||||
x = torch.roll(
|
||||
x,
|
||||
shifts=(-(self.window_size // 2), -(self.window_size // 2)),
|
||||
dims=(1, 2),
|
||||
)
|
||||
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c",
|
||||
p1=self.window_size,
|
||||
p2=self.window_size,
|
||||
)
|
||||
h_windows = x.size(1)
|
||||
w_windows = x.size(2)
|
||||
# square validation
|
||||
# assert h_windows == w_windows
|
||||
|
||||
x = rearrange(
|
||||
x,
|
||||
"b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c",
|
||||
p1=self.window_size,
|
||||
p2=self.window_size,
|
||||
)
|
||||
qkv = self.embedding_layer(x)
|
||||
q, k, v = rearrange(
|
||||
qkv, "b nw np (threeh c) -> threeh b nw np c", c=self.head_dim
|
||||
).chunk(3, dim=0)
|
||||
sim = torch.einsum("hbwpc,hbwqc->hbwpq", q, k) * self.scale
|
||||
# Adding learnable relative embedding
|
||||
sim = sim + rearrange(self.relative_embedding(), "h p q -> h 1 1 p q")
|
||||
# Using Attn Mask to distinguish different subwindows.
|
||||
if self.type != "W":
|
||||
attn_mask = self.generate_mask(
|
||||
h_windows, w_windows, self.window_size, shift=self.window_size // 2
|
||||
)
|
||||
sim = sim.masked_fill_(attn_mask, float("-inf"))
|
||||
|
||||
probs = nn.functional.softmax(sim, dim=-1)
|
||||
output = torch.einsum("hbwij,hbwjc->hbwic", probs, v)
|
||||
output = rearrange(output, "h b w p c -> b w p (h c)")
|
||||
output = self.linear(output)
|
||||
output = rearrange(
|
||||
output,
|
||||
"b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c",
|
||||
w1=h_windows,
|
||||
p1=self.window_size,
|
||||
)
|
||||
|
||||
if self.type != "W":
|
||||
output = torch.roll(
|
||||
output,
|
||||
shifts=(self.window_size // 2, self.window_size // 2),
|
||||
dims=(1, 2),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def relative_embedding(self):
|
||||
cord = torch.tensor(
|
||||
np.array(
|
||||
[
|
||||
[i, j]
|
||||
for i in range(self.window_size)
|
||||
for j in range(self.window_size)
|
||||
]
|
||||
)
|
||||
)
|
||||
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
|
||||
# negative is allowed
|
||||
return self.relative_position_params[
|
||||
:, relation[:, :, 0].long(), relation[:, :, 1].long()
|
||||
]
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
output_dim,
|
||||
head_dim,
|
||||
window_size,
|
||||
drop_path,
|
||||
type="W",
|
||||
input_resolution=None,
|
||||
):
|
||||
"""SwinTransformer Block"""
|
||||
super(Block, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
assert type in ["W", "SW"]
|
||||
self.type = type
|
||||
if input_resolution <= window_size:
|
||||
self.type = "W"
|
||||
|
||||
self.ln1 = nn.LayerNorm(input_dim)
|
||||
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.ln2 = nn.LayerNorm(input_dim)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(input_dim, 4 * input_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(4 * input_dim, output_dim),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.msa(self.ln1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.ln2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class ConvTransBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
conv_dim,
|
||||
trans_dim,
|
||||
head_dim,
|
||||
window_size,
|
||||
drop_path,
|
||||
type="W",
|
||||
input_resolution=None,
|
||||
):
|
||||
"""SwinTransformer and Conv Block"""
|
||||
super(ConvTransBlock, self).__init__()
|
||||
self.conv_dim = conv_dim
|
||||
self.trans_dim = trans_dim
|
||||
self.head_dim = head_dim
|
||||
self.window_size = window_size
|
||||
self.drop_path = drop_path
|
||||
self.type = type
|
||||
self.input_resolution = input_resolution
|
||||
|
||||
assert self.type in ["W", "SW"]
|
||||
if self.input_resolution <= self.window_size:
|
||||
self.type = "W"
|
||||
|
||||
self.trans_block = Block(
|
||||
self.trans_dim,
|
||||
self.trans_dim,
|
||||
self.head_dim,
|
||||
self.window_size,
|
||||
self.drop_path,
|
||||
self.type,
|
||||
self.input_resolution,
|
||||
)
|
||||
self.conv1_1 = nn.Conv2d(
|
||||
self.conv_dim + self.trans_dim,
|
||||
self.conv_dim + self.trans_dim,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
bias=True,
|
||||
)
|
||||
self.conv1_2 = nn.Conv2d(
|
||||
self.conv_dim + self.trans_dim,
|
||||
self.conv_dim + self.trans_dim,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
conv_x, trans_x = torch.split(
|
||||
self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1
|
||||
)
|
||||
conv_x = self.conv_block(conv_x) + conv_x
|
||||
trans_x = Rearrange("b c h w -> b h w c")(trans_x)
|
||||
trans_x = self.trans_block(trans_x)
|
||||
trans_x = Rearrange("b h w c -> b c h w")(trans_x)
|
||||
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
|
||||
x = x + res
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SCUNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
state_dict,
|
||||
in_nc=3,
|
||||
config=[4, 4, 4, 4, 4, 4, 4],
|
||||
dim=64,
|
||||
drop_path_rate=0.0,
|
||||
input_resolution=256,
|
||||
):
|
||||
super(SCUNet, self).__init__()
|
||||
self.model_arch = "SCUNet"
|
||||
self.sub_type = "SR"
|
||||
|
||||
self.num_filters: int = 0
|
||||
|
||||
self.state = state_dict
|
||||
self.config = config
|
||||
self.dim = dim
|
||||
self.head_dim = 32
|
||||
self.window_size = 8
|
||||
|
||||
self.in_nc = in_nc
|
||||
self.out_nc = self.in_nc
|
||||
self.scale = 1
|
||||
self.supports_fp16 = True
|
||||
|
||||
# drop path rate for each layer
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
|
||||
|
||||
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
|
||||
|
||||
begin = 0
|
||||
self.m_down1 = [
|
||||
ConvTransBlock(
|
||||
dim // 2,
|
||||
dim // 2,
|
||||
self.head_dim,
|
||||
self.window_size,
|
||||
dpr[i + begin],
|
||||
"W" if not i % 2 else "SW",
|
||||
input_resolution,
|
||||
)
|
||||
for i in range(config[0])
|
||||
] + [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
|
||||
|
||||
begin += config[0]
|
||||
self.m_down2 = [
|
||||
ConvTransBlock(
|
||||
dim,
|
||||
dim,
|
||||
self.head_dim,
|
||||
self.window_size,
|
||||
dpr[i + begin],
|
||||
"W" if not i % 2 else "SW",
|
||||
input_resolution // 2,
|
||||
)
|
||||
for i in range(config[1])
|
||||
] + [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
|
||||
|
||||
begin += config[1]
|
||||
self.m_down3 = [
|
||||
ConvTransBlock(
|
||||
2 * dim,
|
||||
2 * dim,
|
||||
self.head_dim,
|
||||
self.window_size,
|
||||
dpr[i + begin],
|
||||
"W" if not i % 2 else "SW",
|
||||
input_resolution // 4,
|
||||
)
|
||||
for i in range(config[2])
|
||||
] + [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
|
||||
|
||||
begin += config[2]
|
||||
self.m_body = [
|
||||
ConvTransBlock(
|
||||
4 * dim,
|
||||
4 * dim,
|
||||
self.head_dim,
|
||||
self.window_size,
|
||||
dpr[i + begin],
|
||||
"W" if not i % 2 else "SW",
|
||||
input_resolution // 8,
|
||||
)
|
||||
for i in range(config[3])
|
||||
]
|
||||
|
||||
begin += config[3]
|
||||
self.m_up3 = [
|
||||
nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False),
|
||||
] + [
|
||||
ConvTransBlock(
|
||||
2 * dim,
|
||||
2 * dim,
|
||||
self.head_dim,
|
||||
self.window_size,
|
||||
dpr[i + begin],
|
||||
"W" if not i % 2 else "SW",
|
||||
input_resolution // 4,
|
||||
)
|
||||
for i in range(config[4])
|
||||
]
|
||||
|
||||
begin += config[4]
|
||||
self.m_up2 = [
|
||||
nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False),
|
||||
] + [
|
||||
ConvTransBlock(
|
||||
dim,
|
||||
dim,
|
||||
self.head_dim,
|
||||
self.window_size,
|
||||
dpr[i + begin],
|
||||
"W" if not i % 2 else "SW",
|
||||
input_resolution // 2,
|
||||
)
|
||||
for i in range(config[5])
|
||||
]
|
||||
|
||||
begin += config[5]
|
||||
self.m_up1 = [
|
||||
nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False),
|
||||
] + [
|
||||
ConvTransBlock(
|
||||
dim // 2,
|
||||
dim // 2,
|
||||
self.head_dim,
|
||||
self.window_size,
|
||||
dpr[i + begin],
|
||||
"W" if not i % 2 else "SW",
|
||||
input_resolution,
|
||||
)
|
||||
for i in range(config[6])
|
||||
]
|
||||
|
||||
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
|
||||
|
||||
self.m_head = nn.Sequential(*self.m_head)
|
||||
self.m_down1 = nn.Sequential(*self.m_down1)
|
||||
self.m_down2 = nn.Sequential(*self.m_down2)
|
||||
self.m_down3 = nn.Sequential(*self.m_down3)
|
||||
self.m_body = nn.Sequential(*self.m_body)
|
||||
self.m_up3 = nn.Sequential(*self.m_up3)
|
||||
self.m_up2 = nn.Sequential(*self.m_up2)
|
||||
self.m_up1 = nn.Sequential(*self.m_up1)
|
||||
self.m_tail = nn.Sequential(*self.m_tail)
|
||||
# self.apply(self._init_weights)
|
||||
self.load_state_dict(state_dict, strict=True)
|
||||
|
||||
def check_image_size(self, x):
|
||||
_, _, h, w = x.size()
|
||||
mod_pad_h = (64 - h % 64) % 64
|
||||
mod_pad_w = (64 - w % 64) % 64
|
||||
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
|
||||
return x
|
||||
|
||||
def forward(self, x0):
|
||||
h, w = x0.size()[-2:]
|
||||
x0 = self.check_image_size(x0)
|
||||
|
||||
x1 = self.m_head(x0)
|
||||
x2 = self.m_down1(x1)
|
||||
x3 = self.m_down2(x2)
|
||||
x4 = self.m_down3(x3)
|
||||
x = self.m_body(x4)
|
||||
x = self.m_up3(x + x4)
|
||||
x = self.m_up2(x + x3)
|
||||
x = self.m_up1(x + x2)
|
||||
x = self.m_tail(x + x1)
|
||||
|
||||
x = x[:, :, :h, :w]
|
||||
return x
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user