Compare commits

..

129 Commits

Author SHA1 Message Date
comfyanonymous
9ad792f927 Basic support for hidream i1 model. 2025-04-15 17:35:05 -04:00
comfyanonymous
6fc5dbd52a Cleanup. 2025-04-15 12:13:28 -04:00
comfyanonymous
3e8155f7a3 More flexible long clip support.
Add clip g long clip support.

Text encoder refactor.

Support llama models with different vocab sizes.
2025-04-15 10:32:21 -04:00
comfyanonymous
8a438115fb add RMSNorm to comfy.ops 2025-04-14 18:00:33 -04:00
comfyanonymous
a14c2fc356 ComfyUI version v0.3.28 2025-04-13 12:21:12 -07:00
JNP
9ee6ca99d8
add_optimalsteps (#7584)
Co-authored-by: bebebe666 <jianningpei@tencent.com>
2025-04-12 20:33:36 -04:00
comfyanonymous
bb495cc9b8 Print python version in log. 2025-04-12 18:58:34 -04:00
chaObserv
e51d9ba5fc
Add SEEDS (stage 2 & 3 DP) sampler (#7580)
* Add seeds stage 2 & 3 (DP) sampler

* Change the name to SEEDS in comment
2025-04-12 18:36:08 -04:00
Christian Byrne
c87a06f934
Update filter_files_content_types to support filtering 3d models (#7572)
* support 3d model filtering

* fix lint error: blank line contains whitespace

* add model extensions to test runner mimetype cache manually

* use unittest.mock.patch

* remove mtl file from testcase (actually plaintext support file)
2025-04-12 18:30:39 -04:00
catboxanon
1714a4c158
Add CublasOps support (#7574)
* CublasOps support

* Guard CublasOps behind --fast arg
2025-04-12 18:29:15 -04:00
Christian Byrne
73ecb75a3d
filter image files in load image dropdown (#7573) 2025-04-12 18:27:59 -04:00
comfyanonymous
22ad513c72 Refactor node cache code to more easily add other types of cache. 2025-04-11 07:16:52 -04:00
Chargeuk
ed945a1790
Dependency Aware Node Caching for low RAM/VRAM machines (#7509)
* add dependency aware cache that removed a cached node as soon as all of its decendents have executed. This allows users with lower RAM to run workflows they would otherwise not be able to run. The downside is that every workflow will fully run each time even if no nodes have changed.

* remove test code

* tidy code
2025-04-11 06:55:51 -04:00
Chenlei Hu
f9207c6936
Update frontend to 1.15 (#7564) 2025-04-11 06:46:20 -04:00
Christian Byrne
8ad7477647
dont cache templates index (#7569) 2025-04-11 06:06:53 -04:00
Chenlei Hu
98bdca4cb2
Deprecate InputTypeOptions.defaultInput (#7551)
* Deprecate InputTypeOptions.defaultInput

* nit

* nit
2025-04-10 06:57:06 -04:00
comfyanonymous
a26da20a76 Fix custom nodes not importing when path contains a dot. 2025-04-10 03:37:52 -04:00
Jedrzej Kosinski
e346d8584e
Add prepare_sampling wrapper allowing custom nodes to more accurately report noise_shape (#7500) 2025-04-09 09:43:35 -04:00
comfyanonymous
ab31b64412 Make "surface net" the default in the VoxelToMesh node. 2025-04-09 09:42:08 -04:00
thot experiment
fe29739c68
add VoxelToMesh node w/ surfacenet meshing (#7446)
* add VoxelToMesh node w/ surfacenet meshing

could delete the VoxelToMeshBasic node now probably?

* fix ruff
2025-04-09 09:41:03 -04:00
Chenlei Hu
e8345a9b7b
Align /prompt response schema (#7423) 2025-04-09 09:10:36 -04:00
comfyanonymous
8c6b9f4481
Prevent custom nodes from accidentally overwriting global modules. (#7167)
* Prevent custom nodes from accidentally overwriting global modules.

* Improve.
2025-04-09 09:08:57 -04:00
Christian Byrne
cc7e023a4a
handle palette mode in loadimage node (#7539) 2025-04-09 09:07:07 -04:00
comfyanonymous
2f7d8159c3 Show the user an error when the controlnet file is invalid. 2025-04-08 08:11:59 -04:00
comfyanonymous
70d7242e57 Support the wan fun reward loras. 2025-04-07 05:01:47 -04:00
comfyanonymous
49b732afd5 Show a proper error to the user when a vision model file is invalid. 2025-04-06 22:43:56 -04:00
comfyanonymous
3bfe4e5276 Support 512 siglip model. 2025-04-05 07:01:01 -04:00
Raphael Walker
89e4ea0175
Add activations_shape info in UNet models (#7482)
* Add activations_shape info in UNet models

* activations_shape should be a list
2025-04-04 21:27:54 -04:00
comfyanonymous
3a100b9a55 Disable partial offloading of audio VAE. 2025-04-04 21:24:56 -04:00
comfyanonymous
721253cb05 Fix problem. 2025-04-03 20:57:59 -04:00
comfyanonymous
3d2e3a6f29 Fix alpha image issue in more nodes. 2025-04-02 19:32:49 -04:00
BiologicalExplosion
2222cf67fd
MLU memory optimization (#7470)
Co-authored-by: huzhan <huzhan@cambricon.com>
2025-04-02 19:24:04 -04:00
comfyanonymous
ab5413351e Fix comment.
This function does not support quads.
2025-04-01 14:09:31 -04:00
Laurent Erignoux
2b71aab299
User missing (#7439)
* Ensuring a 401 error is returned when user data is not found in multi-user context.

* Returning a 401 error when provided comfy-user does not exists on server side.
2025-04-01 13:53:52 -04:00
BVH
301e26b131
Add option to store TE in bf16 (#7461) 2025-04-01 13:48:53 -04:00
comfyanonymous
548457bac4 Fix alpha channel mismatch on destination in ImageCompositeMasked 2025-03-31 20:59:12 -04:00
comfyanonymous
0b4584c741 Fix latent composite node not working when source has alpha. 2025-03-30 21:47:05 -04:00
comfyanonymous
a3100c8452 Remove useless code. 2025-03-29 20:12:56 -04:00
Michael Kupchick
832fc02330
ltxv: fix preprocessing exception when compression is 0. (#7431) 2025-03-29 20:03:02 -04:00
comfyanonymous
2d17d8910c Don't error if wan concat image has extra channels. 2025-03-28 08:49:29 -04:00
Chenlei Hu
a40fcfc2d5
Update frontend to 1.14.6 (#7416)
Cherry-pick the fix: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3252
2025-03-28 02:27:01 -04:00
comfyanonymous
0a1f8869c9 Add WanFunInpaintToVideo node for the Wan fun inpaint models. 2025-03-27 11:13:27 -04:00
comfyanonymous
3661c833bc Support the WAN 2.1 fun control models.
Use the new WanFunControlToVideo node.
2025-03-26 19:54:54 -04:00
comfyanonymous
84fdaf7b0e Add CFGZeroStar node.
Works on all models that use a negative prompt but is meant for rectified
flow models.
2025-03-26 05:09:52 -04:00
comfyanonymous
8edc1f44c1 Support more float8 types. 2025-03-25 05:23:49 -04:00
comfyanonymous
eade1551bb Add Hunyuan3D to readme. 2025-03-24 07:14:32 -04:00
comfyanonymous
581a9991ff Add model merging node for WAN 2.1 2025-03-23 08:06:36 -04:00
comfyanonymous
e471c726e5 Fallback to pytorch attention if sage attention fails. 2025-03-22 15:45:56 -04:00
comfyanonymous
75c1c757d9 ComfyUI version v0.3.27 2025-03-21 20:09:54 -04:00
Chenlei Hu
ce9b084279
[nit] Format error strings (#7345) 2025-03-21 19:08:25 -04:00
Terry Jia
2206246055
support output normal and lineart once (#7290) 2025-03-21 16:24:13 -04:00
comfyanonymous
d9fa9d307f Automatically set the right sampling type for lotus. 2025-03-21 14:19:37 -04:00
thot experiment
83e839a89b
Native LotusD Implementation (#7125)
* draft pass at a native comfy implementation of Lotus-D depth and normal est

* fix model_sampling kludges

* fix ruff

---------

Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
2025-03-21 14:04:15 -04:00
Chenlei Hu
0cf2274699
Update frontend to 1.14 (#7343) 2025-03-21 13:50:09 -04:00
comfyanonymous
0956107170 Nodes to convert images to YUV and back.
Can be used to convert an image to black and white.
2025-03-21 06:32:44 -04:00
Chenlei Hu
a4a956dbbd
Add backend primitive nodes (#7328)
* Add backend primitive nodes

* Add control after generate to int primitive
2025-03-21 01:47:18 -04:00
Chenlei Hu
8b9ce4ed18
Update frontend to 1.13 (#7331) 2025-03-21 00:17:36 -04:00
comfyanonymous
3872b43d4b A few fixes for the hunyuan3d models. 2025-03-20 04:52:31 -04:00
comfyanonymous
32ca0805b7 Fix orientation of hunyuan 3d model. 2025-03-19 19:55:24 -04:00
comfyanonymous
11f1b41bab Initial Hunyuan3Dv2 implementation.
Supports the multiview, mini, turbo models and VAEs.
2025-03-19 16:52:58 -04:00
comfyanonymous
3b19fc76e3 Allow disabling pe in flux code for some other models. 2025-03-18 05:09:25 -04:00
comfyanonymous
50614f1b79 Fix regression with clip vision. 2025-03-17 13:56:11 -04:00
comfyanonymous
6dc7b0bfe3 Add support for giant dinov2 image encoder. 2025-03-17 05:53:54 -04:00
comfyanonymous
e8e990d6b8 Cleanup code. 2025-03-16 06:29:12 -04:00
Jedrzej Kosinski
2e24a15905
Call unpatch_hooks at the start of ModelPatcher.partially_unload (#7253)
* Call unpatch_hooks at the start of ModelPatcher.partially_unload

* Only call unpatch_hooks in partially_unload if lowvram is possible
2025-03-16 06:02:45 -04:00
chaObserv
fd5297131f
Guard the edge cases of noise term in er_sde (#7265) 2025-03-16 06:02:25 -04:00
comfyanonymous
55a1b09ddc Allow loading diffusion model files with the "Load Checkpoint" node. 2025-03-15 08:27:49 -04:00
comfyanonymous
3c3988df45 Show a better error message if the VAE is invalid. 2025-03-15 08:26:36 -04:00
Christian Byrne
7ebd8087ff
hotfix fe (#7244) 2025-03-15 01:38:10 -04:00
Chenlei Hu
c624c29d66
Update frontend to 1.12.9 (#7236)
* Update frontend to 1.12.9

* Update requirements.txt
2025-03-14 18:17:26 -04:00
comfyanonymous
a2448fc527 Remove useless code. 2025-03-14 18:10:37 -04:00
comfyanonymous
6a0daa79b6 Make the SkipLayerGuidanceDIT node work on WAN. 2025-03-14 10:55:19 -04:00
FeepingCreature
9c98c6358b
Tolerate missing @torch.library.custom_op (#7234)
This can happen on Pytorch versions older than 2.4.
2025-03-14 09:51:26 -04:00
FeepingCreature
7aceb9f91c
Add --use-flash-attention flag. (#7223)
* Add --use-flash-attention flag.
This is useful on AMD systems, as FA builds are still 10% faster than Pytorch cross-attention.
2025-03-14 03:22:41 -04:00
comfyanonymous
35504e2f93 Fix. 2025-03-13 15:03:18 -04:00
comfyanonymous
299436cfed Print mac version. 2025-03-13 10:05:40 -04:00
Chenlei Hu
52e566d2bc
Add codeowner for comfy/comfy_types (#7213) 2025-03-12 17:30:00 -04:00
Chenlei Hu
9b6cd9b874
[NodeDef] Add documentation on multi_select input option (#7212) 2025-03-12 17:29:39 -04:00
chaObserv
3fc688aebd
Ensure the extra_args in dpmpp sde series (#7204) 2025-03-12 17:28:59 -04:00
comfyanonymous
f4411250f3 Repeat frontend version warning at the end.
This way someone running ComfyUI with the command line is more likely to
actually see it.
2025-03-12 07:13:40 -04:00
Chenlei Hu
d2a0fb6bb0
Add unwrap widget value support (#7197)
* Add unwrap widget value support

* nit
2025-03-12 06:39:14 -04:00
chaObserv
01015bff16
Add er_sde sampler (#7187) 2025-03-12 02:42:37 -04:00
comfyanonymous
2330754b0e Fix error saving some latents. 2025-03-11 15:07:16 -04:00
comfyanonymous
bc219a6487
Merge pull request #7143 from christian-byrne/fix-remote-widget-node
Fix LoadImageOutput node
2025-03-11 04:30:25 -04:00
comfyanonymous
94689766ad
Merge pull request #7179 from comfyanonymous/ignore_fe_package
Only check frontend package if using default frontend
2025-03-11 03:45:02 -04:00
huchenlei
cfbe4b49ca Access package version 2025-03-10 20:43:59 -04:00
comfyanonymous
ca8efab79f Support control loras on Wan. 2025-03-10 17:23:13 -04:00
Chenlei Hu
65ea778a5e nit 2025-03-10 15:19:59 -04:00
Chenlei Hu
db9f2a34fc Fix unit test 2025-03-10 15:19:52 -04:00
Chenlei Hu
7946049794 nit 2025-03-10 15:14:40 -04:00
Chenlei Hu
6f6349b6a7 nit 2025-03-10 15:10:40 -04:00
Chenlei Hu
1f138dd382 Only check frontend package if using default frontend 2025-03-10 15:07:44 -04:00
comfyanonymous
b779349b55 Temporarily revert fix to give time for people to update their nodes. 2025-03-10 06:30:17 -04:00
comfyanonymous
35e2dcf5d7 Hack to fix broken manager. 2025-03-10 06:15:17 -04:00
Andrew Kvochko
67c7184b74
ltxv: relax frame_idx divisibility for single frames. (#7146)
This commit relaxes divisibility constraint for single-frame
conditionings. For single frames, the index can be arbitrary, while
multi-frame conditionings (>= 9 frames) must still be aligned to 8
frames.

Co-authored-by: Andrew Kvochko <a.kvochko@lightricks.com>
2025-03-10 04:11:48 -04:00
comfyanonymous
6f8e766509 Prevent custom nodes from accidentally overwriting global modules. 2025-03-10 03:33:41 -04:00
Terry Jia
e1da98a14a
remove unused params (#6931) 2025-03-09 14:07:09 -04:00
bymyself
a73410aafa remove overrides 2025-03-09 03:46:08 -07:00
comfyanonymous
9aac21f894 Fix issues with new hunyuan img2vid model and bumb version to v0.3.26 2025-03-09 05:07:22 -04:00
Jedrzej Kosinski
528d1b3563
When cached_hook_patches contain weights for hooks, only use hook_backup for unused keys (#7067) 2025-03-09 04:26:31 -04:00
comfyanonymous
2bc4b5968f ComfyUI version v0.3.25 2025-03-09 03:30:20 -04:00
comfyanonymous
7395b0c0d1 Support new hunyuan video i2v model.
Use the new "v2 (replace)" guidance type in HunyuanImageToVideo and set
image_interleave to 4 on the "Text Encode Hunyuan Video" node.
2025-03-08 20:34:47 -05:00
comfyanonymous
0952569493 Fix stable cascade VAE on some lowvram machines. 2025-03-08 20:24:04 -05:00
comfyanonymous
29832b3b61 Warn if frontend package is older than the one in requirements.txt 2025-03-08 03:51:36 -05:00
comfyanonymous
be4e760648 Add an image_interleave option to the Hunyuan image to video encode node.
See the tooltip for what it does.
2025-03-07 19:56:26 -05:00
comfyanonymous
c3d9cc4592 Print the frontend version in the log. 2025-03-07 19:56:26 -05:00
Chenlei Hu
84cc9cb528
Update frontend to 1.11.8 (#7119)
* Update frontend to 1.11.7

* Update requirements.txt
2025-03-07 19:02:13 -05:00
comfyanonymous
ebbb920163 Add back taesd to nightly package. 2025-03-07 14:56:09 -05:00
comfyanonymous
d60fe0af4a Reduce size of nightly package. 2025-03-07 08:30:01 -05:00
comfyanonymous
5dbd250965 Update nightly instructions in readme. 2025-03-07 07:57:59 -05:00
comfyanonymous
4ab1875283 Add .bat file to nightly package to run with fp16 accumulation. 2025-03-07 07:45:40 -05:00
comfyanonymous
11b1f27cb1 Set WAN default compute dtype to fp16. 2025-03-07 04:52:36 -05:00
comfyanonymous
70e15fd743 No need for scale_input when fp8 matrix mult is disabled. 2025-03-07 04:49:20 -05:00
comfyanonymous
e1474150de Support fp8_scaled diffusion models that don't use fp8 matrix mult. 2025-03-07 04:39:21 -05:00
JettHu
e62d72e8ca
Typo in node_typing.py (#7092) 2025-03-06 15:24:04 -05:00
Dr.Lt.Data
1650cda030
Fixed: Incorrect guide message for missing frontend. (#7105)
`{sys.executable} -m pip` -> `{sys.executable} -s -m pip`

https://github.com/comfyanonymous/ComfyUI/pull/7047#issuecomment-2697876793
2025-03-06 15:23:23 -05:00
comfyanonymous
a13125840c ComfyUI version v0.3.24 2025-03-06 13:53:48 -05:00
comfyanonymous
dfa36e6855 Fix some things breaking when embeddings fail to apply. 2025-03-06 13:31:55 -05:00
comfyanonymous
0124be4d93 ComfyUI version v0.3.23 2025-03-06 04:10:12 -05:00
comfyanonymous
29a70ca101 Support HunyuanVideo image to video model. 2025-03-06 03:07:15 -05:00
comfyanonymous
0bef826a98 Support llava clip vision model. 2025-03-06 00:24:43 -05:00
comfyanonymous
85ef295069 Make applying embeddings more efficient.
Adding new tokens no longer makes a whole copy of the embeddings weight
which can be massive on certain models.
2025-03-05 17:34:38 -05:00
Chenlei Hu
5d84607bf3
Add type hint for FileLocator (#6968)
* Add type hint for FileLocator

* nit
2025-03-05 15:35:26 -05:00
Silver
c1909f350f
Better argument handling of front-end-root (#7043)
* Better argument handling of front-end-root

Improves handling of front-end-root launch argument. Several instances where users have set it and ComfyUI launches as normal and completely disregards the launch arg which doesn't make sense. Better to indicate to user that something is incorrect.

* Removed unused import

There was no real reason to use "Optional" typing in ther front-end-root argument.
2025-03-05 15:34:22 -05:00
Chenlei Hu
52b3469606
[NodeDef] Explicitly add control_after_generate to seed/noise_seed (#7059)
* [NodeDef] Explicitly add control_after_generate to seed/noise_seed

* Update comfy/comfy_types/node_typing.py

Co-authored-by: filtered <176114999+webfiltered@users.noreply.github.com>

---------

Co-authored-by: filtered <176114999+webfiltered@users.noreply.github.com>
2025-03-05 15:33:23 -05:00
comfyanonymous
889519971f Bump ComfyUI version to v0.3.22 2025-03-05 10:06:37 -05:00
comfyanonymous
76739c23c3 Revert "Partially revert last commit."
This reverts commit a80bc822a2.
2025-03-05 09:57:40 -05:00
comfyanonymous
a80bc822a2 Partially revert last commit. 2025-03-05 08:58:44 -05:00
Andrew Kvochko
872780d236
fix: ltxv crop guides works with 0 keyframes (#7085)
This patch fixes a bug in LTXVCropGuides when the latent has no
keyframes. Additionally, the first frame is always added as a keyframe.

Co-authored-by: Andrew Kvochko <a.kvochko@lightricks.com>
2025-03-05 08:47:32 -05:00
94 changed files with 4722 additions and 415 deletions

View File

@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
pause

View File

@ -7,7 +7,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "126"
default: "128"
python_minor:
description: 'python minor version'
@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "1"
default: "2"
# push:
# branches:
# - master
@ -34,7 +34,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
fetch-depth: 30
persist-credentials: false
- uses: actions/setup-python@v5
with:
@ -74,7 +74,7 @@ jobs:
pause" > ./update/update_comfyui_and_python_dependencies.bat
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
cd ComfyUI_windows_portable_nightly_pytorch

View File

@ -19,5 +19,6 @@
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
# Extra nodes
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
# Node developers
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered

View File

@ -69,6 +69,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- 3D Models
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
@ -215,9 +217,9 @@ Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
This is the command to install pytorch nightly instead which might have performance improvements:
This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
#### Troubleshooting

View File

@ -9,8 +9,14 @@ class AppSettings():
self.user_manager = user_manager
def get_settings(self, request):
try:
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
request,
"comfy.settings.json"
)
except KeyError as e:
logging.error("User settings not found.")
raise web.HTTPUnauthorized() from e
if os.path.isfile(file):
try:
with open(file) as f:

View File

@ -11,20 +11,61 @@ from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict, Optional
from importlib.metadata import version
import requests
from typing_extensions import NotRequired
from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger
# The path to the requirements.txt file
req_path = Path(__file__).parents[1] / "requirements.txt"
try:
import comfyui_frontend_package
except ImportError:
# TODO: Remove the check after roll out of 0.3.16
req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. Please install the updated requirements.txt file by running:\n{sys.executable} -m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n")
exit(-1)
def frontend_install_warning_message():
"""The warning message to display when the frontend version is not up to date."""
extra = ""
if sys.flags.no_user_site:
extra = "-s "
return f"""
Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {req_path}
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
""".strip()
def check_frontend_version():
"""Check if the frontend version is up to date."""
def parse_version(version: str) -> tuple[int, int, int]:
return tuple(map(int, version.split(".")))
try:
frontend_version_str = version("comfyui-frontend-package")
frontend_version = parse_version(frontend_version_str)
with open(req_path, "r", encoding="utf-8") as f:
required_frontend = parse_version(f.readline().split("=")[-1])
if frontend_version < required_frontend:
app.logger.log_startup_warning(
f"""
________________________________________________________________________
WARNING WARNING WARNING WARNING WARNING
Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
{frontend_install_warning_message()}
________________________________________________________________________
""".strip()
)
else:
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
except Exception as e:
logging.error(f"Failed to check frontend version: {e}")
REQUEST_TIMEOUT = 10 # seconds
@ -121,9 +162,28 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
DEFAULT_FRONTEND_PATH = str(importlib.resources.files(comfyui_frontend_package) / "static")
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
def default_frontend_path(cls) -> str:
try:
import comfyui_frontend_package
return str(importlib.resources.files(comfyui_frontend_package) / "static")
except ImportError:
logging.error(
f"""
********** ERROR ***********
comfyui-frontend-package is not installed.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
sys.exit(-1)
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
@ -144,7 +204,9 @@ class FrontendManager:
return match_result.group(1), match_result.group(2), match_result.group(3)
@classmethod
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
def init_frontend_unsafe(
cls, version_string: str, provider: Optional[FrontEndProvider] = None
) -> str:
"""
Initializes the frontend for the specified version.
@ -160,17 +222,26 @@ class FrontendManager:
main error source might be request timeout or invalid URL.
"""
if version_string == DEFAULT_VERSION_STRING:
return cls.DEFAULT_FRONTEND_PATH
check_frontend_version()
return cls.default_frontend_path()
repo_owner, repo_name, version = cls.parse_version_string(version_string)
if version.startswith("v"):
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
expected_path = str(
Path(cls.CUSTOM_FRONTENDS_ROOT)
/ f"{repo_owner}_{repo_name}"
/ version.lstrip("v")
)
if os.path.exists(expected_path):
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
logging.info(
f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
)
return expected_path
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
logging.info(
f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
)
provider = provider or FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version)
@ -213,4 +284,5 @@ class FrontendManager:
except Exception as e:
logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.")
return cls.DEFAULT_FRONTEND_PATH
check_frontend_version()
return cls.default_frontend_path()

View File

@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
logger.addHandler(stdout_handler)
logger.addHandler(stream_handler)
STARTUP_WARNINGS = []
def log_startup_warning(msg):
logging.warning(msg)
STARTUP_WARNINGS.append(msg)
def print_startup_warnings():
for s in STARTUP_WARNINGS:
logging.warning(s)
STARTUP_WARNINGS.clear()

View File

@ -1,7 +1,6 @@
import argparse
import enum
import os
from typing import Optional
import comfy.options
@ -80,6 +79,7 @@ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Stor
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
@ -101,12 +101,14 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
@ -134,8 +136,9 @@ parser.add_argument("--deterministic", action="store_true", help="Make pytorch u
class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@ -166,13 +169,14 @@ parser.add_argument(
""",
)
def is_valid_directory(path: Optional[str]) -> Optional[str]:
"""Validate if the given path is a directory."""
if path is None:
return None
def is_valid_directory(path: str) -> str:
"""Validate if the given path is a directory, and check permissions."""
if not os.path.exists(path):
raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
if not os.path.isdir(path):
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
if not os.access(path, os.R_OK):
raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
return path
parser.add_argument(

View File

@ -97,8 +97,12 @@ class CLIPTextModel_(torch.nn.Module):
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
if embeds is not None:
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
else:
x = self.embeddings(input_tokens, dtype=dtype)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
@ -116,6 +120,9 @@ class CLIPTextModel_(torch.nn.Module):
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
if num_tokens is not None:
pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
else:
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
return x, i, pooled_output
@ -204,6 +211,15 @@ class CLIPVision(torch.nn.Module):
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
class LlavaProjector(torch.nn.Module):
def __init__(self, in_dim, out_dim, dtype, device, operations):
super().__init__()
self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
def forward(self, x):
return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
@ -213,7 +229,16 @@ class CLIPVisionModelProjection(torch.nn.Module):
else:
self.visual_projection = lambda a: a
if "llava3" == config_dict.get("projector_type", None):
self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
else:
self.multi_modal_projector = None
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)
out = self.visual_projection(x[2])
return (x[0], x[1], out)
projected = None
if self.multi_modal_projector is not None:
projected = self.multi_modal_projector(x[1])
return (x[0], x[1], out, projected)

View File

@ -9,6 +9,7 @@ import comfy.model_patcher
import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
class Output:
def __getitem__(self, key):
@ -34,6 +35,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
IMAGE_ENCODERS = {
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
}
class ClipVisionModel():
def __init__(self, json_config):
with open(json_config) as f:
@ -42,10 +49,11 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
@ -65,6 +73,7 @@ class ClipVisionModel():
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
outputs["mm_projected"] = out[3]
return outputs
def convert_to_transformers(sd, prefix):
@ -101,12 +110,21 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
if embed_shape == 729:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
elif embed_shape == 1024:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
elif embed_shape == 577:
if "multi_modal_projector.linear_1.bias" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
elif "embeddings.patch_embeddings.projection.weight" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
else:
return None

View File

@ -0,0 +1,19 @@
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"image_size": 336,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-5,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"projector_type": "llava3",
"torch_dtype": "float32"
}

View File

@ -0,0 +1,13 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 512,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 16,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@ -1,6 +1,6 @@
import torch
from typing import Callable, Protocol, TypedDict, Optional, List
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin, FileLocator
class UnetApplyFunction(Protocol):
@ -42,4 +42,5 @@ __all__ = [
InputTypeDict.__name__,
ComfyNodeABC.__name__,
CheckLazyMixin.__name__,
FileLocator.__name__,
]

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from typing import Literal, TypedDict
from typing_extensions import NotRequired
from abc import ABC, abstractmethod
from enum import Enum
@ -26,6 +27,7 @@ class IO(StrEnum):
BOOLEAN = "BOOLEAN"
INT = "INT"
FLOAT = "FLOAT"
COMBO = "COMBO"
CONDITIONING = "CONDITIONING"
SAMPLER = "SAMPLER"
SIGMAS = "SIGMAS"
@ -66,6 +68,7 @@ class IO(StrEnum):
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class RemoteInputOptions(TypedDict):
route: str
"""The route to the remote source."""
@ -80,6 +83,14 @@ class RemoteInputOptions(TypedDict):
refresh: int
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
class MultiSelectOptions(TypedDict):
placeholder: NotRequired[str]
"""The placeholder text to display in the multi-select widget when no items are selected."""
chip: NotRequired[bool]
"""Specifies whether to use chips instead of comma separated values for the multi-select widget."""
class InputTypeOptions(TypedDict):
"""Provides type hinting for the return type of the INPUT_TYPES node function.
@ -91,9 +102,13 @@ class InputTypeOptions(TypedDict):
default: bool | str | float | int | list | tuple
"""The default value of the widget"""
defaultInput: bool
"""Defaults to an input slot rather than a widget"""
"""@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
- defaultInput on required inputs should be dropped.
- defaultInput on optional inputs should be replaced with forceInput.
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
"""
forceInput: bool
"""`defaultInput` and also don't allow converting to a widget"""
"""Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
lazy: bool
"""Declares that this input uses lazy evaluation"""
rawLink: bool
@ -114,7 +129,7 @@ class InputTypeOptions(TypedDict):
# default: bool
label_on: str
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
label_on: str
label_off: str
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
# class InputTypeString(InputTypeOptions):
# default: str
@ -133,7 +148,22 @@ class InputTypeOptions(TypedDict):
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
"""
remote: RemoteInputOptions
"""Specifies the configuration for a remote input."""
"""Specifies the configuration for a remote input.
Available after ComfyUI frontend v1.9.7
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
control_after_generate: bool
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
options: NotRequired[list[str | int | float]]
"""COMBO type only. Specifies the selectable options for the combo widget.
Prefer:
["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
Over:
[["Option 1", "Option 2", "Option 3"]]
"""
multi_select: NotRequired[MultiSelectOptions]
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
class HiddenInputTypeDict(TypedDict):
@ -293,3 +323,14 @@ class CheckLazyMixin:
need = [name for name in kwargs if kwargs[name] is None]
return need
class FileLocator(TypedDict):
"""Provides type hinting for the file location"""
filename: str
"""The filename of the file."""
subfolder: str
"""The subfolder of the file."""
type: Literal["input", "output", "temp"]
"""The root folder of the file."""

View File

@ -0,0 +1,141 @@
import torch
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
class Dino2AttentionOutput(torch.nn.Module):
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
def forward(self, x):
return self.dense(x)
class Dino2AttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
def forward(self, x, mask, optimized_attention):
return self.output(self.attention(x, mask, optimized_attention))
class LayerScale(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
in_features = out_features = dim
hidden_features = int(dim * 4)
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
def forward(self, x):
x = self.weights_in(x)
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
return self.weights_out(x)
class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, x, optimized_attention):
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
return x
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
for i, l in enumerate(self.layer):
x = l(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class Dino2PatchEmbeddings(torch.nn.Module):
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
super().__init__()
self.projection = operations.Conv2d(
in_channels=num_channels,
out_channels=dim,
kernel_size=patch_size,
stride=patch_size,
bias=True,
dtype=dtype,
device=device
)
def forward(self, pixel_values):
return self.projection(pixel_values).flatten(2).transpose(1, 2)
class Dino2Embeddings(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
patch_size = 14
image_size = 518
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
def forward(self, pixel_values):
x = self.patch_embeddings(pixel_values)
# TODO: mask_token?
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
return x
class Dinov2Model(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
num_layers = config_dict["num_hidden_layers"]
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values)
x, i = self.encoder(x, intermediate_output=intermediate_output)
x = self.layernorm(x)
pooled_output = x[:, 0, :]
return x, i, pooled_output, None

View File

@ -0,0 +1,21 @@
{
"attention_probs_dropout_prob": 0.0,
"drop_path_rate": 0.0,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 1536,
"image_size": 518,
"initializer_range": 0.02,
"layer_norm_eps": 1e-06,
"layerscale_value": 1.0,
"mlp_ratio": 4,
"model_type": "dinov2",
"num_attention_heads": 24,
"num_channels": 3,
"num_hidden_layers": 40,
"patch_size": 14,
"qkv_bias": true,
"use_swiglu_ffn": true,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}

View File

@ -688,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
seed = extra_args.get("seed", None)
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
@ -762,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
if solver_type not in {'heun', 'midpoint'}:
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
old_denoised = None
@ -808,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
denoised_1, denoised_2 = None, None
@ -858,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@ -867,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@ -876,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
@ -1366,3 +1366,157 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
x = x + d_bar * dt
old_d = d
return x
@torch.no_grad()
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
"""
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
def default_noise_scaler(sigma):
return sigma * ((sigma ** 0.3).exp() + 10.0)
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
num_integration_points = 200.0
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
old_denoised = None
old_denoised_d = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
stage_used = min(max_stage, i + 1)
if sigmas[i + 1] == 0:
x = denoised
elif stage_used == 1:
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
else:
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
dt = sigmas[i + 1] - sigmas[i]
sigma_step_size = -dt / num_integration_points
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
scaled_pos = noise_scaler(sigma_pos)
# Stage 2
s = torch.sum(1 / scaled_pos) * sigma_step_size
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
if stage_used >= 3:
# Stage 3
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
old_denoised_d = denoised_d
if s_noise != 0 and sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
old_denoised = denoised
return x
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
'''
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
Arxiv: https://arxiv.org/abs/2305.14267
'''
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
h_eta = h * (eta + 1)
s = t + r * h
fac = 1 / (2 * r)
sigma_s = s.neg().exp()
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
# Step 1
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
# Step 2
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
'''
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
Arxiv: https://arxiv.org/abs/2305.14267
'''
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
h_eta = h * (eta + 1)
s_1 = t + r_1 * h
s_2 = t + r_2 * h
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
# Step 1
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
if inject_noise:
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 3
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
return x

View File

@ -456,3 +456,13 @@ class Wan21(LatentFormat):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 0.9990943042622529
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 1.0188137142395404

View File

@ -19,6 +19,10 @@
import torch
from torch import nn
from torch.autograd import Function
import comfy.ops
ops = comfy.ops.disable_weight_init
class vector_quantize(Function):
@staticmethod
@ -121,15 +125,15 @@ class ResBlock(nn.Module):
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.depthwise = nn.Sequential(
nn.ReplicationPad2d(1),
nn.Conv2d(c, c, kernel_size=3, groups=c)
ops.Conv2d(c, c, kernel_size=3, groups=c)
)
# channelwise
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
nn.Linear(c, c_hidden),
ops.Linear(c, c_hidden),
nn.GELU(),
nn.Linear(c_hidden, c),
ops.Linear(c_hidden, c),
)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
@ -171,16 +175,16 @@ class StageA(nn.Module):
# Encoder blocks
self.in_block = nn.Sequential(
nn.PixelUnshuffle(2),
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
)
down_blocks = []
for i in range(levels):
if i > 0:
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
block = ResBlock(c_levels[i], c_levels[i] * 4)
down_blocks.append(block)
down_blocks.append(nn.Sequential(
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
))
self.down_blocks = nn.Sequential(*down_blocks)
@ -191,7 +195,7 @@ class StageA(nn.Module):
# Decoder blocks
up_blocks = [nn.Sequential(
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
)]
for i in range(levels):
for j in range(bottleneck_blocks if i == 0 else 1):
@ -199,11 +203,11 @@ class StageA(nn.Module):
up_blocks.append(block)
if i < levels - 1:
up_blocks.append(
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
padding=1))
self.up_blocks = nn.Sequential(*up_blocks)
self.out_block = nn.Sequential(
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
nn.PixelShuffle(2),
)
@ -232,17 +236,17 @@ class Discriminator(nn.Module):
super().__init__()
d = max(depth - 3, 3)
layers = [
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
nn.LeakyReLU(0.2),
]
for i in range(depth - 1):
c_in = c_hidden // (2 ** max((d - i), 0))
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
layers.append(nn.InstanceNorm2d(c_out))
layers.append(nn.LeakyReLU(0.2))
self.encoder = nn.Sequential(*layers)
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
self.logits = nn.Sigmoid()
def forward(self, x, cond=None):

View File

@ -19,6 +19,9 @@ import torch
import torchvision
from torch import nn
import comfy.ops
ops = comfy.ops.disable_weight_init
# EfficientNet
class EfficientNetEncoder(nn.Module):
@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module):
super().__init__()
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
self.mapper = nn.Sequential(
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
)
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module):
def forward(self, x):
x = x * 0.5 + 0.5
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
o = self.mapper(self.backbone(x))
return o
@ -44,39 +47,39 @@ class Previewer(nn.Module):
def __init__(self, c_in=16, c_hidden=512, c_out=3):
super().__init__()
self.blocks = nn.Sequential(
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
nn.GELU(),
nn.BatchNorm2d(c_hidden),
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden),
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
)
def forward(self, x):

View File

@ -1,5 +1,6 @@
import torch
import comfy.ops
import comfy.rmsnorm
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
@ -11,20 +12,5 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
return torch.nn.functional.pad(img, pad, mode=padding_mode)
try:
rms_norm_torch = torch.nn.functional.rms_norm
except:
rms_norm_torch = None
def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
rms_norm = comfy.rmsnorm.rms_norm

View File

@ -105,7 +105,9 @@ class Modulation(nn.Module):
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
if vec.ndim == 2:
vec = vec[:, None, :]
out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
@ -113,6 +115,20 @@ class Modulation(nn.Module):
)
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
if modulation_dims is None:
if m_add is not None:
return tensor * m_mult + m_add
else:
return tensor * m_mult
else:
for d in modulation_dims:
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
if m_add is not None:
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
return tensor
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
@ -143,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@ -179,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
# calculate the txt bloks
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@ -228,9 +244,9 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
@ -239,7 +255,7 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output
x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
@ -252,8 +268,11 @@ class LastLayer(nn.Module):
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
if vec.ndim == 2:
vec = vec[:, None, :]
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
x = self.linear(x)
return x

View File

@ -10,8 +10,9 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q_shape = q.shape
k_shape = k.shape
q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
if pe is not None:
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
@ -36,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@ -115,8 +115,11 @@ class Flux(nn.Module):
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
if img_ids is not None:
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
else:
pe = None
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):

828
comfy/ldm/hidream/model.py Normal file
View File

@ -0,0 +1,828 @@
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import einops
from einops import repeat
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
import torch.nn.functional as F
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND(nn.Module):
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(2)
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size=2,
in_channels=4,
out_channels=1024,
dtype=None, device=None, operations=None
):
super().__init__()
self.patch_size = patch_size
self.out_channels = out_channels
self.proj = operations.Linear(in_channels * patch_size * patch_size, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, latent):
latent = self.proj(latent)
return latent
class PooledEmbed(nn.Module):
def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
def forward(self, pooled_embed):
return self.pooled_embedder(pooled_embed)
class TimestepEmbed(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
def forward(self, timesteps, wdtype):
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
t_emb = self.timestep_embedder(t_emb)
return t_emb
class OutEmbed(nn.Module):
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, adaln_input):
shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = self.linear(x)
return x
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
class HiDreamAttnProcessor_flashattn:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __call__(
self,
attn,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
*args,
**kwargs,
) -> torch.FloatTensor:
dtype = image_tokens.dtype
batch_size = image_tokens.shape[0]
query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
value_i = attn.to_v(image_tokens)
inner_dim = key_i.shape[-1]
head_dim = inner_dim // attn.heads
query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
if image_tokens_masks is not None:
key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)
if not attn.single:
query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
value_t = attn.to_v_t(text_tokens)
query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
num_image_tokens = query_i.shape[1]
num_text_tokens = query_t.shape[1]
query = torch.cat([query_i, query_t], dim=1)
key = torch.cat([key_i, key_t], dim=1)
value = torch.cat([value_i, value_t], dim=1)
else:
query = query_i
key = key_i
value = value_i
if query.shape[-1] == rope.shape[-3] * 2:
query, key = apply_rope(query, key, rope)
else:
query_1, query_2 = query.chunk(2, dim=-1)
key_1, key_2 = key.chunk(2, dim=-1)
query_1, key_1 = apply_rope(query_1, key_1, rope)
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
hidden_states = attention(query, key, value)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
hidden_states_i = attn.to_out(hidden_states_i)
hidden_states_t = attn.to_out_t(hidden_states_t)
return hidden_states_i, hidden_states_t
else:
hidden_states = attn.to_out(hidden_states)
return hidden_states
class HiDreamAttention(nn.Module):
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
upcast_attention: bool = False,
upcast_softmax: bool = False,
scale_qk: bool = True,
eps: float = 1e-5,
processor = None,
out_dim: int = None,
single: bool = False,
dtype=None, device=None, operations=None
):
# super(Attention, self).__init__()
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.out_dim = out_dim if out_dim is not None else query_dim
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
self.sliceable_head_dim = heads
self.single = single
linear_cls = operations.Linear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
if not single:
self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
self.processor = processor
def forward(
self,
norm_image_tokens: torch.FloatTensor,
image_tokens_masks: torch.FloatTensor = None,
norm_text_tokens: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
) -> torch.Tensor:
return self.processor(
self,
image_tokens = norm_image_tokens,
image_tokens_masks = image_tokens_masks,
text_tokens = norm_text_tokens,
rope = rope,
)
class FeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
dtype=None, device=None, operations=None
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * (
(hidden_dim + multiple_of - 1) // multiple_of
)
self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
def forward(self, x):
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MoEGate(nn.Module):
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None):
super().__init__()
self.top_k = num_activated_experts
self.n_routed_experts = num_routed_experts
self.scoring_func = 'softmax'
self.alpha = aux_loss_alpha
self.seq_aux = False
# topk selection algorithm
self.norm_topk_prob = False
self.gating_dim = embed_dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device))
self.reset_parameters()
def reset_parameters(self) -> None:
pass
# import torch.nn.init as init
# init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
### select top-k experts
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
aux_loss = None
return topk_idx, topk_weight, aux_loss
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MOEFeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
num_routed_experts: int,
num_activated_experts: int,
dtype=None, device=None, operations=None
):
super().__init__()
self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations)
self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)])
self.gate = MoEGate(
embed_dim = dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
dtype=dtype, device=device, operations=operations
)
self.num_activated_experts = num_activated_experts
def forward(self, x):
wtype = x.dtype
identity = x
orig_shape = x.shape
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if True: # self.training: # TODO: check which branch performs faster
x = x.repeat_interleave(self.num_activated_experts, dim=0)
y = torch.empty_like(x, dtype=wtype)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape).to(dtype=wtype)
#y = AddAuxiliaryLoss.apply(y, aux_loss)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
y = y + self.shared_experts(identity)
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.num_activated_experts
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# for fp16 and other dtype
expert_cache = expert_cache.to(expert_out.dtype)
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
return expert_cache
class TextProjection(nn.Module):
def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device)
def forward(self, caption):
hidden_states = self.linear(caption)
return hidden_states
class BlockType:
TransformerBlock = 1
SingleTransformerBlock = 2
class HiDreamImageSingleTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
dtype=None, device=None, operations=None
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device)
)
# 1. Attention
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor = HiDreamAttnProcessor_flashattn(),
single = True,
dtype=dtype, device=device, operations=operations
)
# 3. Feed-forward
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim = dim,
hidden_dim = 4 * dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
dtype=dtype, device=device, operations=operations
)
else:
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)
# 1. MM-Attention
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
attn_output_i = self.attn1(
norm_image_tokens,
image_tokens_masks,
rope = rope,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
# 2. Feed-forward
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
image_tokens = ff_output_i + image_tokens
return image_tokens
class HiDreamImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
dtype=None, device=None, operations=None
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device)
)
# nn.init.zeros_(self.adaLN_modulation[1].weight)
# nn.init.zeros_(self.adaLN_modulation[1].bias)
# 1. Attention
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor = HiDreamAttnProcessor_flashattn(),
single = False,
dtype=dtype, device=device, operations=operations
)
# 3. Feed-forward
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim = dim,
hidden_dim = 4 * dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
dtype=dtype, device=device, operations=operations
)
else:
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)
# 1. MM-Attention
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
attn_output_i, attn_output_t = self.attn1(
norm_image_tokens,
image_tokens_masks,
norm_text_tokens,
rope = rope,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
text_tokens = gate_msa_t * attn_output_t + text_tokens
# 2. Feed-forward
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
image_tokens = ff_output_i + image_tokens
text_tokens = ff_output_t + text_tokens
return image_tokens, text_tokens
class HiDreamImageBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
block_type: BlockType = BlockType.TransformerBlock,
dtype=None, device=None, operations=None
):
super().__init__()
block_classes = {
BlockType.TransformerBlock: HiDreamImageTransformerBlock,
BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
}
self.block = block_classes[block_type](
dim,
num_attention_heads,
attention_head_dim,
num_routed_experts,
num_activated_experts,
dtype=dtype, device=device, operations=operations
)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
return self.block(
image_tokens,
image_tokens_masks,
text_tokens,
adaln_input,
rope,
)
class HiDreamImageTransformer2DModel(nn.Module):
def __init__(
self,
patch_size: Optional[int] = None,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 16,
num_single_layers: int = 32,
attention_head_dim: int = 128,
num_attention_heads: int = 20,
caption_channels: List[int] = None,
text_emb_dim: int = 2048,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
axes_dims_rope: Tuple[int, int] = (32, 32),
max_resolution: Tuple[int, int] = (128, 128),
llama_layers: List[int] = None,
image_model=None,
dtype=None, device=None, operations=None
):
self.patch_size = patch_size
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.num_layers = num_layers
self.num_single_layers = num_single_layers
self.gradient_checkpointing = False
super().__init__()
self.dtype = dtype
self.out_channels = out_channels or in_channels
self.inner_dim = self.num_attention_heads * self.attention_head_dim
self.llama_layers = llama_layers
self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations)
self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.x_embedder = PatchEmbed(
patch_size = patch_size,
in_channels = in_channels,
out_channels = self.inner_dim,
dtype=dtype, device=device, operations=operations
)
self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)
self.double_stream_blocks = nn.ModuleList(
[
HiDreamImageBlock(
dim = self.inner_dim,
num_attention_heads = self.num_attention_heads,
attention_head_dim = self.attention_head_dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
block_type = BlockType.TransformerBlock,
dtype=dtype, device=device, operations=operations
)
for i in range(self.num_layers)
]
)
self.single_stream_blocks = nn.ModuleList(
[
HiDreamImageBlock(
dim = self.inner_dim,
num_attention_heads = self.num_attention_heads,
attention_head_dim = self.attention_head_dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
block_type = BlockType.SingleTransformerBlock,
dtype=dtype, device=device, operations=operations
)
for i in range(self.num_single_layers)
]
)
self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
caption_projection = []
for caption_channel in caption_channels:
caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations))
self.caption_projection = nn.ModuleList(caption_projection)
self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
def expand_timesteps(self, timesteps, batch_size, device):
if not torch.is_tensor(timesteps):
is_mps = device.type == "mps"
if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(batch_size)
return timesteps
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]:
x_arr = []
for i, img_size in enumerate(img_sizes):
pH, pW = img_size
x_arr.append(
einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
p1=self.patch_size, p2=self.patch_size)
)
x = torch.cat(x_arr, dim=0)
return x
def patchify(self, x, max_seq, img_sizes=None):
pz2 = self.patch_size * self.patch_size
if isinstance(x, torch.Tensor):
B = x.shape[0]
device = x.device
dtype = x.dtype
else:
B = len(x)
device = x[0].device
dtype = x[0].dtype
x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
if img_sizes is not None:
for i, img_size in enumerate(img_sizes):
x_masks[i, 0:img_size[0] * img_size[1]] = 1
x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
elif isinstance(x, torch.Tensor):
pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size)
img_sizes = [[pH, pW]] * B
x_masks = None
else:
raise NotImplementedError
return x, x_masks, img_sizes
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
encoder_hidden_states_llama3=None,
control = None,
transformer_options = {},
) -> torch.Tensor:
hidden_states = x
timesteps = t
pooled_embeds = y
T5_encoder_hidden_states = context
img_sizes = None
# spatial forward
batch_size = hidden_states.shape[0]
hidden_states_type = hidden_states.dtype
# 0. time
timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
timesteps = self.t_embedder(timesteps, hidden_states_type)
p_embedder = self.p_embedder(pooled_embeds)
adaln_input = timesteps + p_embedder
hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
if image_tokens_masks is None:
pH, pW = img_sizes[0]
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
hidden_states = self.x_embedder(hidden_states)
# T5_encoder_hidden_states = encoder_hidden_states[0]
encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0)
encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
if self.caption_projection is not None:
new_encoder_hidden_states = []
for i, enc_hidden_state in enumerate(encoder_hidden_states):
enc_hidden_state = self.caption_projection[i](enc_hidden_state)
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
new_encoder_hidden_states.append(enc_hidden_state)
encoder_hidden_states = new_encoder_hidden_states
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states.append(T5_encoder_hidden_states)
txt_ids = torch.zeros(
batch_size,
encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
3,
device=img_ids.device, dtype=img_ids.dtype
)
ids = torch.cat((img_ids, txt_ids), dim=1)
rope = self.pe_embedder(ids)
# 2. Blocks
block_id = 0
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
for bid, block in enumerate(self.double_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
hidden_states, initial_encoder_hidden_states = block(
image_tokens = hidden_states,
image_tokens_masks = image_tokens_masks,
text_tokens = cur_encoder_hidden_states,
adaln_input = adaln_input,
rope = rope,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
image_tokens_seq_len = hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
hidden_states_seq_len = hidden_states.shape[1]
if image_tokens_masks is not None:
encoder_attention_mask_ones = torch.ones(
(batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
)
image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
for bid, block in enumerate(self.single_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
hidden_states = block(
image_tokens=hidden_states,
image_tokens_masks=image_tokens_masks,
text_tokens=None,
adaln_input=adaln_input,
rope=rope,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
output = self.final_layer(hidden_states, adaln_input)
output = self.unpatchify(output, img_sizes)
return -output

View File

@ -0,0 +1,135 @@
import torch
from torch import nn
from comfy.ldm.flux.layers import (
DoubleStreamBlock,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
class Hunyuan3Dv2(nn.Module):
def __init__(
self,
in_channels=64,
context_in_dim=1536,
hidden_size=1024,
mlp_ratio=4.0,
num_heads=16,
depth=16,
depth_single_blocks=32,
qkv_bias=True,
guidance_embed=False,
image_model=None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.dtype = dtype
if hidden_size % num_heads != 0:
raise ValueError(
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
)
self.max_period = 1000 # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead
self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None
)
self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
dtype=dtype, device=device, operations=operations
)
for _ in range(depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
dtype=dtype, device=device, operations=operations
)
for _ in range(depth_single_blocks)
]
)
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
x = x.movedim(-1, -2)
timestep = 1.0 - timestep
txt = context
img = self.latent_in(x)
vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype))
if self.guidance_in is not None:
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype))
txt = self.cond_in(txt)
pe = None
attn_mask = None
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = img[:, txt.shape[1]:, ...]
img = self.final_layer(img, vec)
return img.movedim(-2, -1) * (-1.0)

587
comfy/ldm/hunyuan3d/vae.py Normal file
View File

@ -0,0 +1,587 @@
# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, Tuple, List, Callable, Optional
import numpy as np
from einops import repeat, rearrange
from tqdm import tqdm
import logging
import comfy.ops
ops = comfy.ops.disable_weight_init
def generate_dense_grid_points(
bbox_min: np.ndarray,
bbox_max: np.ndarray,
octree_resolution: int,
indexing: str = "ij",
):
length = bbox_max - bbox_min
num_cells = octree_resolution
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
xyz = np.stack((xs, ys, zs), axis=-1)
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
return xyz, grid_size, length
class VanillaVolumeDecoder:
@torch.no_grad()
def __call__(
self,
latents: torch.FloatTensor,
geo_decoder: Callable,
bounds: Union[Tuple[float], List[float], float] = 1.01,
num_chunks: int = 10000,
octree_resolution: int = None,
enable_pbar: bool = True,
**kwargs,
):
device = latents.device
dtype = latents.dtype
batch_size = latents.shape[0]
# 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_resolution=octree_resolution,
indexing="ij"
)
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
# 2. latents to 3d volume
batch_logits = []
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
disable=not enable_pbar):
chunk_queries = xyz_samples[start: start + num_chunks, :]
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
logits = geo_decoder(queries=chunk_queries, latents=latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim=1)
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
return grid_logits
class FourierEmbedder(nn.Module):
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
each feature dimension of `x[..., i]` into:
[
sin(x[..., i]),
sin(f_1*x[..., i]),
sin(f_2*x[..., i]),
...
sin(f_N * x[..., i]),
cos(x[..., i]),
cos(f_1*x[..., i]),
cos(f_2*x[..., i]),
...
cos(f_N * x[..., i]),
x[..., i] # only present if include_input is True.
], here f_i is the frequency.
Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
Args:
num_freqs (int): the number of frequencies, default is 6;
logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
input_dim (int): the input dimension, default is 3;
include_input (bool): include the input tensor or not, default is True.
Attributes:
frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1);
out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
otherwise, it is input_dim * num_freqs * 2.
"""
def __init__(self,
num_freqs: int = 6,
logspace: bool = True,
input_dim: int = 3,
include_input: bool = True,
include_pi: bool = True) -> None:
"""The initialization"""
super().__init__()
if logspace:
frequencies = 2.0 ** torch.arange(
num_freqs,
dtype=torch.float32
)
else:
frequencies = torch.linspace(
1.0,
2.0 ** (num_freqs - 1),
num_freqs,
dtype=torch.float32
)
if include_pi:
frequencies *= torch.pi
self.register_buffer("frequencies", frequencies, persistent=False)
self.include_input = include_input
self.num_freqs = num_freqs
self.out_dim = self.get_dims(input_dim)
def get_dims(self, input_dim):
temp = 1 if self.include_input or self.num_freqs == 0 else 0
out_dim = input_dim * (self.num_freqs * 2 + temp)
return out_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
""" Forward process.
Args:
x: tensor of shape [..., dim]
Returns:
embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
where temp is 1 if include_input is True and 0 otherwise.
"""
if self.num_freqs > 0:
embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
if self.include_input:
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
else:
return torch.cat((embed.sin(), embed.cos()), dim=-1)
else:
return x
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = F.scaled_dot_product_attention(q, k, v)
return out
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if self.drop_prob == 0. or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and self.scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
def extra_repr(self):
return f'drop_prob={round(self.drop_prob, 3):0.3f}'
class MLP(nn.Module):
def __init__(
self, *,
width: int,
expand_ratio: int = 4,
output_width: int = None,
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.c_fc = ops.Linear(width, width * expand_ratio)
self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
self.gelu = nn.GELU()
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def forward(self, x):
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
class QKVMultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
heads: int,
width=None,
qk_norm=False,
norm_layer=ops.LayerNorm
):
super().__init__()
self.heads = heads
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.attn_processor = CrossAttentionProcessor()
def forward(self, q, kv):
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
attn_ch = width // self.heads // 2
q = q.view(bs, n_ctx, self.heads, -1)
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = self.attn_processor(self, q, k, v)
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
return out
class MultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
qkv_bias: bool = True,
data_width: Optional[int] = None,
norm_layer=ops.LayerNorm,
qk_norm: bool = False,
kv_cache: bool = False,
):
super().__init__()
self.width = width
self.heads = heads
self.data_width = width if data_width is None else data_width
self.c_q = ops.Linear(width, width, bias=qkv_bias)
self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
self.c_proj = ops.Linear(width, width)
self.attention = QKVMultiheadCrossAttention(
heads=heads,
width=width,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.kv_cache = kv_cache
self.data = None
def forward(self, x, data):
x = self.c_q(x)
if self.kv_cache:
if self.data is None:
self.data = self.c_kv(data)
logging.info('Save kv cache,this should be called only once for one mesh')
data = self.data
else:
data = self.c_kv(data)
x = self.attention(x, data)
x = self.c_proj(x)
return x
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
mlp_expand_ratio: int = 4,
data_width: Optional[int] = None,
qkv_bias: bool = True,
norm_layer=ops.LayerNorm,
qk_norm: bool = False
):
super().__init__()
if data_width is None:
data_width = width
self.attn = MultiheadCrossAttention(
width=width,
heads=heads,
data_width=data_width,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
def forward(self, x: torch.Tensor, data: torch.Tensor):
x = x + self.attn(self.ln_1(x), self.ln_2(data))
x = x + self.mlp(self.ln_3(x))
return x
class QKVMultiheadAttention(nn.Module):
def __init__(
self,
*,
heads: int,
width=None,
qk_norm=False,
norm_layer=ops.LayerNorm
):
super().__init__()
self.heads = heads
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
def forward(self, qkv):
bs, n_ctx, width = qkv.shape
attn_ch = width // self.heads // 3
qkv = qkv.view(bs, n_ctx, self.heads, -1)
q, k, v = torch.split(qkv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
return out
class MultiheadAttention(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
qkv_bias: bool,
norm_layer=ops.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.heads = heads
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
self.c_proj = ops.Linear(width, width)
self.attention = QKVMultiheadAttention(
heads=heads,
width=width,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def forward(self, x):
x = self.c_qkv(x)
x = self.attention(x)
x = self.drop_path(self.c_proj(x))
return x
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
qkv_bias: bool = True,
norm_layer=ops.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0,
):
super().__init__()
self.attn = MultiheadAttention(
width=width,
heads=heads,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
def forward(self, x: torch.Tensor):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(
self,
*,
width: int,
layers: int,
heads: int,
qkv_bias: bool = True,
norm_layer=ops.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
width=width,
heads=heads,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor):
for block in self.resblocks:
x = block(x)
return x
class CrossAttentionDecoder(nn.Module):
def __init__(
self,
*,
out_channels: int,
fourier_embedder: FourierEmbedder,
width: int,
heads: int,
mlp_expand_ratio: int = 4,
downsample_ratio: int = 1,
enable_ln_post: bool = True,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary"
):
super().__init__()
self.enable_ln_post = enable_ln_post
self.fourier_embedder = fourier_embedder
self.downsample_ratio = downsample_ratio
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
if self.downsample_ratio != 1:
self.latents_proj = ops.Linear(width * downsample_ratio, width)
if self.enable_ln_post == False:
qk_norm = False
self.cross_attn_decoder = ResidualCrossAttentionBlock(
width=width,
mlp_expand_ratio=mlp_expand_ratio,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm
)
if self.enable_ln_post:
self.ln_post = ops.LayerNorm(width)
self.output_proj = ops.Linear(width, out_channels)
self.label_type = label_type
self.count = 0
def forward(self, queries=None, query_embeddings=None, latents=None):
if query_embeddings is None:
query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
self.count += query_embeddings.shape[1]
if self.downsample_ratio != 1:
latents = self.latents_proj(latents)
x = self.cross_attn_decoder(query_embeddings, latents)
if self.enable_ln_post:
x = self.ln_post(x)
occ = self.output_proj(x)
return occ
class ShapeVAE(nn.Module):
def __init__(
self,
*,
embed_dim: int,
width: int,
heads: int,
num_decoder_layers: int,
geo_decoder_downsample_ratio: int = 1,
geo_decoder_mlp_expand_ratio: int = 4,
geo_decoder_ln_post: bool = True,
num_freqs: int = 8,
include_pi: bool = True,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary",
drop_path_rate: float = 0.0,
scale_factor: float = 1.0,
):
super().__init__()
self.geo_decoder_ln_post = geo_decoder_ln_post
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
self.post_kl = ops.Linear(embed_dim, width)
self.transformer = Transformer(
width=width,
layers=num_decoder_layers,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
self.geo_decoder = CrossAttentionDecoder(
fourier_embedder=self.fourier_embedder,
out_channels=1,
mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
downsample_ratio=geo_decoder_downsample_ratio,
enable_ln_post=self.geo_decoder_ln_post,
width=width // geo_decoder_downsample_ratio,
heads=heads // geo_decoder_downsample_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
label_type=label_type,
)
self.volume_decoder = VanillaVolumeDecoder()
self.scale_factor = scale_factor
def decode(self, latents, **kwargs):
latents = self.post_kl(latents.movedim(-2, -1))
latents = self.transformer(latents)
bounds = kwargs.get("bounds", 1.01)
num_chunks = kwargs.get("num_chunks", 8000)
octree_resolution = kwargs.get("octree_resolution", 256)
enable_pbar = kwargs.get("enable_pbar", True)
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
return grid_logits.movedim(-2, -1)
def encode(self, x):
return None

View File

@ -227,6 +227,7 @@ class HunyuanVideo(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
guiding_frame_index=None,
control=None,
transformer_options={},
) -> Tensor:
@ -237,7 +238,17 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
modulation_dims_txt = [(0, None, 1)]
else:
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
modulation_dims = None
modulation_dims_txt = None
if self.params.guidance_embed:
if guidance is not None:
@ -264,14 +275,14 @@ class HunyuanVideo(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
if control is not None: # Controlnet
control_i = control.get("input")
@ -286,13 +297,13 @@ class HunyuanVideo(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
if control is not None: # Controlnet
control_o = control.get("output")
@ -303,7 +314,7 @@ class HunyuanVideo(nn.Module):
img = img[:, : img_len]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
shape = initial_shape[-3:]
for i in range(len(shape)):
@ -313,7 +324,7 @@ class HunyuanVideo(nn.Module):
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
return img
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@ -325,5 +336,5 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
return out

View File

@ -24,6 +24,13 @@ if model_management.sage_attention_enabled():
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
exit(-1)
if model_management.flash_attention_enabled():
try:
from flash_attn import flash_attn_func
except ModuleNotFoundError:
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
@ -464,7 +471,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
tensor_layout = "HND"
else:
b, _, dim_head = q.shape
dim_head //= heads
@ -472,7 +479,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.view(b, -1, heads, dim_head),
(q, k, v),
)
tensor_layout="NHD"
tensor_layout = "NHD"
if mask is not None:
# add a batch dimension if there isn't already one
@ -482,7 +489,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
if mask.ndim == 3:
mask = mask.unsqueeze(1)
try:
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),
(q, k, v),
)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
if tensor_layout == "HND":
if not skip_output_reshape:
out = (
@ -496,6 +513,63 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
return out
try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
@flash_attn_wrapper.register_fake
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
# Output shape is the same as q
return q.new_empty(q.shape)
except AttributeError as error:
FLASH_ATTN_ERROR = error
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
try:
assert mask is None
out = flash_attn_wrapper(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
dropout_p=0.0,
causal=False,
).transpose(1, 2)
except Exception as e:
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
return out
optimized_attention = attention_basic
if model_management.sage_attention_enabled():
@ -504,6 +578,9 @@ if model_management.sage_attention_enabled():
elif model_management.xformers_enabled():
logging.info("Using xformers attention")
optimized_attention = attention_xformers
elif model_management.flash_attention_enabled():
logging.info("Using Flash Attention")
optimized_attention = attention_flash
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention")
optimized_attention = attention_pytorch
@ -770,6 +847,7 @@ class SpatialTransformer(nn.Module):
if not isinstance(context, list):
context = [context] * len(self.transformer_blocks)
b, c, h, w = x.shape
transformer_options["activations_shape"] = list(x.shape)
x_in = x
x = self.norm(x)
if not self.use_linear:
@ -885,6 +963,7 @@ class SpatialVideoTransformer(SpatialTransformer):
transformer_options={}
) -> torch.Tensor:
_, _, h, w = x.shape
transformer_options["activations_shape"] = list(x.shape)
x_in = x
spatial_context = None
if exists(context):

View File

@ -384,6 +384,7 @@ class WanModel(torch.nn.Module):
context,
clip_fea=None,
freqs=None,
transformer_options={},
):
r"""
Forward pass through the diffusion model
@ -423,14 +424,18 @@ class WanModel(torch.nn.Module):
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
freqs=freqs,
context=context)
for block in self.blocks:
x = block(x, **kwargs)
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
# head
x = self.head(x, e)
@ -439,7 +444,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def forward(self, x, timestep, context, clip_fea=None, **kwargs):
def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
@ -453,7 +458,7 @@ class WanModel(torch.nn.Module):
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
freqs = self.rope_embedder(img_ids).movedim(1, 2)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
r"""

View File

@ -1,4 +1,5 @@
import torch
import comfy.utils
def convert_lora_bfl_control(sd): #BFL loras for Flux
@ -11,7 +12,13 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
return sd_out
def convert_lora_wan_fun(sd): #Wan Fun loras
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
return convert_lora_wan_fun(sd)
return sd

View File

@ -36,6 +36,8 @@ import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.model_management
import comfy.patcher_extension
@ -58,6 +60,7 @@ class ModelType(Enum):
FLOW = 6
V_PREDICTION_CONTINUOUS = 7
FLUX = 8
IMG_TO_IMG = 9
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
@ -88,6 +91,8 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.FLUX:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingFlux
elif model_type == ModelType.IMG_TO_IMG:
c = comfy.model_sampling.IMG_TO_IMG
class ModelSampling(s, c):
pass
@ -108,7 +113,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
else:
operations = model_config.custom_operations
@ -139,6 +144,7 @@ class BaseModel(torch.nn.Module):
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:
xc = torch.cat([xc] + [c_concat], dim=1)
@ -600,6 +606,19 @@ class SDXL_instructpix2pix(IP2P, SDXL):
else:
self.process_ip2p_image_in = lambda image: image #diffusers ip2p
class Lotus(BaseModel):
def extra_conds(self, **kwargs):
out = {}
cross_attn = kwargs.get("cross_attn", None)
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
device = kwargs["device"]
task_emb = torch.tensor([1, 0]).float().to(device)
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)]).unsqueeze(0)
out['y'] = comfy.conds.CONDRegular(task_emb)
return out
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None):
super().__init__(model_config, model_type, device=device)
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
@ -898,13 +917,31 @@ class HunyuanVideo(BaseModel):
guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
guiding_frame_index = kwargs.get("guiding_frame_index", None)
if guiding_frame_index is not None:
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
return out
def scale_latent_inpaint(self, latent_image, **kwargs):
return latent_image
class HunyuanVideoI2V(HunyuanVideo):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
self.concat_keys = ("concat_image", "mask_inverted")
def scale_latent_inpaint(self, latent_image, **kwargs):
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
self.concat_keys = ("concat_image",)
def scale_latent_inpaint(self, latent_image, **kwargs):
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
class CosmosVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
@ -955,28 +992,41 @@ class WAN21(BaseModel):
self.image_to_video = image_to_video
def concat_cond(self, **kwargs):
if not self.image_to_video:
noise = kwargs.get("noise", None)
extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
if extra_channels == 0:
return None
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if image is None:
image = torch.zeros_like(noise)
shape_image = list(noise.shape)
shape_image[1] = extra_channels
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = self.process_latent_in(image)
for i in range(0, image.shape[1], 16):
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
image = utils.resize_to_batch_size(image, noise.shape[0])
if not self.image_to_video or extra_channels == image.shape[1]:
return image
if image.shape[1] > (extra_channels - 4):
image = image[:, :(extra_channels - 4)]
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :4]
else:
mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
if mask.shape[1] != 4:
mask = torch.mean(mask, dim=1, keepdim=True)
mask = 1.0 - mask
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
if mask.shape[1] == 1:
mask = mask.repeat(1, 4, 1, 1, 1)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
@ -992,3 +1042,35 @@ class WAN21(BaseModel):
if clip_vision_output is not None:
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
return out
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guidance = kwargs.get("guidance", 5.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_llama3 = kwargs.get("conditioning_llama3", None)
if conditioning_llama3 is not None:
out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
return out

View File

@ -154,7 +154,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16
@ -323,6 +323,40 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "t2v"
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
dit_config = {}
dit_config["image_model"] = "hunyuan3d2"
dit_config["in_channels"] = in_shape[1]
dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
dit_config["hidden_size"] = in_shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["qkv_bias"] = True
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"
dit_config["attention_head_dim"] = 128
dit_config["axes_dims_rope"] = [64, 32, 32]
dit_config["caption_channels"] = [4096, 4096]
dit_config["max_resolution"] = [128, 128]
dit_config["in_channels"] = 16
dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]
dit_config["num_attention_heads"] = 20
dit_config["num_routed_experts"] = 4
dit_config["num_activated_experts"] = 2
dit_config["num_layers"] = 16
dit_config["num_single_layers"] = 32
dit_config["out_channels"] = 16
dit_config["patch_size"] = 2
dit_config["text_emb_dim"] = 2048
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@ -471,6 +505,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn
if scaled_fp8_weight.nelement() == 2:
model_config.optimizations["fp8"] = False
else:
model_config.optimizations["fp8"] = True
return model_config
@ -663,8 +701,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
for unet_config in supported_models:
matches = True

View File

@ -46,6 +46,32 @@ cpu_state = CPUState.GPU
total_vram = 0
def get_supported_float8_types():
float8_types = []
try:
float8_types.append(torch.float8_e4m3fn)
except:
pass
try:
float8_types.append(torch.float8_e4m3fnuz)
except:
pass
try:
float8_types.append(torch.float8_e5m2)
except:
pass
try:
float8_types.append(torch.float8_e5m2fnuz)
except:
pass
try:
float8_types.append(torch.float8_e8m0fnu)
except:
pass
return float8_types
FLOAT8_TYPES = get_supported_float8_types()
xpu_available = False
torch_version = ""
try:
@ -186,12 +212,21 @@ def get_total_memory(dev=None, torch_total_too=False):
else:
return mem_total
def mac_version():
try:
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
except:
return None
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
try:
logging.info("pytorch version: {}".format(torch_version))
mac_ver = mac_version()
if mac_ver is not None:
logging.info("Mac Version {}".format(mac_ver))
except:
pass
@ -581,7 +616,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
if vram_set_state == VRAMState.NO_VRAM:
@ -692,11 +727,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
return torch.float8_e5m2
fp8_dtype = None
try:
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if weight_dtype in FLOAT8_TYPES:
fp8_dtype = weight_dtype
except:
pass
if fp8_dtype is not None:
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
@ -791,6 +823,8 @@ def text_encoder_dtype(device=None):
return torch.float8_e5m2
elif args.fp16_text_enc:
return torch.float16
elif args.bf16_text_enc:
return torch.bfloat16
elif args.fp32_text_enc:
return torch.float32
@ -921,6 +955,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
def sage_attention_enabled():
return args.use_sage_attention
def flash_attention_enabled():
return args.use_flash_attention
def xformers_enabled():
global directml_enabled
global cpu_state
@ -969,12 +1006,6 @@ def pytorch_attention_flash_attention():
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
return False
def mac_version():
try:
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
except:
return None
def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
@ -1206,6 +1237,8 @@ def soft_empty_cache(force=False):
torch.xpu.empty_cache()
elif is_ascend_npu():
torch.npu.empty_cache()
elif is_mlu():
torch.mlu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

View File

@ -747,6 +747,7 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
patch_counter = 0
unload_list = self._load_list()
@ -770,6 +771,10 @@ class ModelPatcher:
move_weight = False
break
if not hooks_unpatched:
self.unpatch_hooks()
hooks_unpatched = True
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
@ -1089,7 +1094,6 @@ class ModelPatcher:
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected():
self.unpatch_hooks()
if hooks is not None:
model_sd_keys = list(self.model_state_dict().keys())
memory_counter = None
@ -1100,12 +1104,16 @@ class ModelPatcher:
# if have cached weights for hooks, use it
cached_weights = self.cached_hook_patches.get(hooks, None)
if cached_weights is not None:
model_sd_keys_set = set(model_sd_keys)
for key in cached_weights:
if key not in model_sd_keys:
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
model_sd_keys_set.remove(key)
self.unpatch_hooks(model_sd_keys_set)
else:
self.unpatch_hooks()
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
original_weights = None
if len(relevant_patches) > 0:
@ -1116,6 +1124,8 @@ class ModelPatcher:
continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
memory_counter=memory_counter)
else:
self.unpatch_hooks()
self.current_hooks = hooks
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
@ -1172,12 +1182,18 @@ class ModelPatcher:
del out_weight
del weight
def unpatch_hooks(self) -> None:
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
with self.use_ejected():
if len(self.hook_backup) == 0:
self.current_hooks = None
return
keys = list(self.hook_backup.keys())
if whitelist_keys_set:
for k in keys:
if k in whitelist_keys_set:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.pop(k)
else:
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))

View File

@ -69,6 +69,15 @@ class CONST:
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
return latent / (1.0 - sigma)
class X0(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output
class IMG_TO_IMG(X0):
def calculate_input(self, sigma, noise):
return noise
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None, zsnr=None):
super().__init__()

View File

@ -17,9 +17,11 @@
"""
import torch
import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@ -145,6 +147,25 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
def reset_parameters(self):
self.bias = None
return None
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
else:
weight = None
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -242,6 +263,9 @@ class manual_cast(disable_weight_init):
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
comfy_cast_weights = True
class RMSNorm(disable_weight_init.RMSNorm):
comfy_cast_weights = True
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True
@ -308,6 +332,7 @@ class fp8_ops(manual_cast):
return torch.nn.functional.linear(input, weight, bias)
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
class scaled_fp8_op(manual_cast):
class Linear(manual_cast.Linear):
def __init__(self, *args, **kwargs):
@ -355,10 +380,29 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
return scaled_fp8_op
CUBLAS_IS_AVAILABLE = False
try:
from cublas_ops import CublasLinear
CUBLAS_IS_AVAILABLE = True
except ImportError:
pass
if CUBLAS_IS_AVAILABLE:
class cublas_ops(disable_weight_init):
class Linear(CublasLinear, disable_weight_init.Linear):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
return super().forward(input)
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
if (
fp8_compute and
@ -367,6 +411,15 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
):
return fp8_ops
if (
PerformanceFeature.CublasOps in args.fast and
CUBLAS_IS_AVAILABLE and
weight_dtype == torch.float16 and
(compute_dtype == torch.float16 or compute_dtype is None)
):
logging.info("Using cublas ops")
return cublas_ops
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init

View File

@ -48,6 +48,7 @@ def get_all_callbacks(call_type: str, transformer_options: dict, is_model_option
class WrappersMP:
OUTER_SAMPLE = "outer_sample"
PREPARE_SAMPLING = "prepare_sampling"
SAMPLER_SAMPLE = "sampler_sample"
CALC_COND_BATCH = "calc_cond_batch"
APPLY_MODEL = "apply_model"

54
comfy/rmsnorm.py Normal file
View File

@ -0,0 +1,54 @@
import torch
import comfy.model_management
import numbers
RMSNorm = None
try:
rms_norm_torch = torch.nn.functional.rms_norm
RMSNorm = torch.nn.RMSNorm
except:
rms_norm_torch = None
def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
if RMSNorm is None:
class RMSNorm(torch.nn.Module):
def __init__(
self,
normalized_shape,
eps=None,
elementwise_affine=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = torch.nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
def forward(self, x):
return rms_norm(x, self.weight, self.eps)

View File

@ -106,6 +106,13 @@ def cleanup_additional_models(models):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)

View File

@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation"]
"gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

View File

@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.hunyuan3d.vae
import yaml
import math
@ -40,6 +41,7 @@ import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.model_patcher
import comfy.lora
@ -264,6 +266,7 @@ class VAE:
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.downscale_index_formula = None
self.upscale_index_formula = None
@ -336,6 +339,7 @@ class VAE:
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.disable_offload = True
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
if "blocks.2.blocks.3.stack.5.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
@ -412,6 +416,17 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
self.latent_dim = 1
ln_post = "geo_decoder.ln_post.weight" in sd
inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -440,6 +455,10 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
downscale_ratio = self.spacial_compression_encode()
@ -494,18 +513,19 @@ class VAE:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in):
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
@ -525,8 +545,9 @@ class VAE:
return pixel_samples
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
dims = samples.ndim - 2
args = {}
if tile_x is not None:
@ -553,13 +574,14 @@ class VAE:
return output.movedim(1, -1)
def encode(self, pixel_samples):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
if self.latent_dim == 3 and pixel_samples.ndim < 5:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
@ -585,6 +607,7 @@ class VAE:
return samples
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
@ -592,7 +615,7 @@ class VAE:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
args = {}
if tile_x is not None:
@ -831,6 +854,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif len(clip_data) == 3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif len(clip_data) == 4:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), **llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
parameters = 0
for c in clip_data:
@ -899,7 +925,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
if diffusion_model is None:
return None
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:

View File

@ -82,7 +82,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
LAYERS = [
"last",
"pooled",
"hidden"
"hidden",
"all"
]
def __init__(self, device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
@ -93,6 +94,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
if "model_name" not in model_options:
model_options = {**model_options, "model_name": "clip_l"}
if isinstance(textmodel_json_config, dict):
config = textmodel_json_config
@ -100,6 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
with open(textmodel_json_config) as f:
config = json.load(f)
te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {})
for k, v in te_model_options.items():
config[k] = v
operations = model_options.get("custom_operations", None)
scaled_fp8 = None
@ -147,7 +154,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if layer_idx is None or abs(layer_idx) > self.num_layers:
if self.layer == "all":
pass
elif layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
@ -158,71 +167,98 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
def set_up_textual_embeddings(self, tokens, current_embeds):
out_tokens = []
next_new_token = token_dict_size = current_embeds.weight.shape[0]
embedding_weights = []
for x in tokens:
tokens_temp = []
for y in x:
if isinstance(y, numbers.Integral):
tokens_temp += [int(y)]
else:
if y.shape[0] == current_embeds.weight.shape[1]:
embedding_weights += [y]
tokens_temp += [next_new_token]
next_new_token += 1
else:
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
while len(tokens_temp) < len(x):
tokens_temp += [self.special_tokens["pad"]]
out_tokens += [tokens_temp]
n = token_dict_size
if len(embedding_weights) > 0:
new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
new_embedding.weight[:token_dict_size] = current_embeds.weight
for x in embedding_weights:
new_embedding.weight[n] = x
n += 1
self.transformer.set_input_embeddings(new_embedding)
processed_tokens = []
for x in out_tokens:
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
return processed_tokens
def forward(self, tokens):
backup_embeds = self.transformer.get_input_embeddings()
device = backup_embeds.weight.device
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
attention_mask = None
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
attention_mask = torch.zeros_like(tokens)
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
if end_token is None:
cmp_token = self.special_tokens.get("pad", -1)
else:
cmp_token = end_token
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == cmp_token:
embeds_out = []
attention_masks = []
num_tokens = []
for x in tokens:
attention_mask = []
tokens_temp = []
other_embeds = []
eos = False
index = 0
for y in x:
if isinstance(y, numbers.Integral):
if eos:
attention_mask.append(0)
else:
attention_mask.append(1)
token = int(y)
tokens_temp += [token]
if not eos and token == cmp_token:
if end_token is None:
attention_mask[x, y] = 0
break
attention_mask[-1] = 0
eos = True
else:
other_embeds.append((index, y))
index += 1
tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long)
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
index = 0
pad_extra = 0
for o in other_embeds:
emb = o[1]
if torch.is_tensor(emb):
emb = {"type": "embedding", "data": emb}
emb_type = emb.get("type", None)
if emb_type == "embedding":
emb = emb.get("data", None)
else:
if hasattr(self.transformer, "preprocess_embed"):
emb = self.transformer.preprocess_embed(emb, device=device)
else:
emb = None
if emb is None:
index += -1
continue
ind = index + o[0]
emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
emb_shape = emb.shape[1]
if emb.shape[-1] == tokens_embed.shape[-1]:
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
index += emb_shape - 1
else:
index += -1
pad_extra += emb_shape
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))
if pad_extra > 0:
padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
attention_mask = attention_mask + [0] * pad_extra
embeds_out.append(tokens_embed)
attention_masks.append(attention_mask)
num_tokens.append(sum(attention_mask))
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
def forward(self, tokens):
device = self.transformer.get_input_embeddings().weight.device
embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
attention_mask_model = None
if self.enable_attention_masks:
attention_mask_model = attention_mask
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "all":
intermediate_output = "all"
else:
intermediate_output = self.layer_idx
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
if self.layer == "last":
z = outputs[0].float()
@ -425,7 +461,7 @@ class SDTokenizer:
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
self.max_length = max_length
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
self.min_length = min_length
self.end_token = None
@ -623,6 +659,7 @@ class SD1ClipModel(torch.nn.Module):
self.clip = "clip_{}".format(self.clip_name)
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
model_options = {**model_options, "model_name": self.clip}
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
self.dtypes = set()

View File

@ -9,6 +9,7 @@ class SDXLClipG(sd1_clip.SDClipModel):
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
model_options = {**model_options, "model_name": "clip_g"}
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
@ -17,14 +18,13 @@ class SDXLClipG(sd1_clip.SDClipModel):
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)
class SDXLTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
@ -41,8 +41,7 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set([dtype])
@ -75,7 +74,7 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -84,6 +83,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
model_options = {**model_options, "model_name": "clip_g"}
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)

View File

@ -506,6 +506,22 @@ class SDXL_instructpix2pix(SDXL):
def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
class LotusD(SD20):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"use_temporal_attention": False,
"adm_in_channels": 4,
"in_channels": 4,
}
unet_extra_config = {
"num_classes": 'sequential'
}
def get_model(self, state_dict, prefix="", device=None):
return model_base.Lotus(self, device=device)
class SD3(supported_models_base.BASE):
unet_config = {
"in_channels": 16,
@ -826,6 +842,16 @@ class HunyuanVideo(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
class HunyuanVideoI2V(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"in_channels": 33,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideoI2V(self, device=device)
return out
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
@ -921,7 +947,7 @@ class WAN21_T2V(supported_models_base.BASE):
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
@ -943,12 +969,92 @@ class WAN21_I2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "i2v",
"in_dim": 36,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, image_to_video=True, device=device)
return out
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
class WAN21_FunControl2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "i2v",
"in_dim": 48,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, image_to_video=False, device=device)
return out
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
}
unet_extra_config = {}
sampling_settings = {
"multiplier": 1.0,
"shift": 1.0,
}
memory_usage_factor = 3.5
clip_vision_prefix = "conditioner.main_image_encoder.model."
vae_key_prefix = ["vae."]
latent_format = latent_formats.Hunyuan3Dv2
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Hunyuan3Dv2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None
class Hunyuan3Dv2mini(Hunyuan3Dv2):
unet_config = {
"image_model": "hunyuan3d2",
"depth": 8,
}
latent_format = latent_formats.Hunyuan3Dv2mini
class HiDream(supported_models_base.BASE):
unet_config = {
"image_model": "hidream",
}
sampling_settings = {
"shift": 3.0,
}
sampling_settings = {
}
# memory_usage_factor = 1.2 # TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HiDream(self, device=device)
return out
def clip_target(self, state_dict={}):
return None # TODO
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
models += [SVD_img2vid]

View File

@ -11,7 +11,7 @@ class PT5XlModel(sd1_clip.SDClipModel):
class PT5XlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data)
class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):

View File

@ -93,7 +93,10 @@ class BertEmbeddings(torch.nn.Module):
self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, input_tokens, token_type_ids=None, dtype=None):
def forward(self, input_tokens, embeds=None, token_type_ids=None, dtype=None):
if embeds is not None:
x = embeds
else:
x = self.word_embeddings(input_tokens, out_dtype=dtype)
x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
if token_type_ids is not None:
@ -113,8 +116,8 @@ class BertModel_(torch.nn.Module):
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
x = self.embeddings(input_tokens, dtype=dtype)
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])

View File

@ -22,7 +22,7 @@ class CosmosT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, tokenizer_data=tokenizer_data)
class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):

View File

@ -9,14 +9,13 @@ import os
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
class FluxTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
@ -35,8 +34,7 @@ class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5])

View File

@ -18,7 +18,7 @@ class MochiT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):

View File

@ -0,0 +1,150 @@
from . import hunyuan_video
from . import sd3_clip
from comfy import sd1_clip
from comfy import sdxl_clip
import comfy.model_management
import torch
import logging
class HiDreamTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, tokenizer_data=tokenizer_data)
self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)
def state_dict(self):
return {}
class HiDreamTEModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None
if clip_g:
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_g = None
if t5:
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True)
self.dtypes.add(dtype_t5)
else:
self.t5xxl = None
if llama:
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
if "vocab_size" not in model_options:
model_options["vocab_size"] = 128256
self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009})
self.dtypes.add(dtype_llama)
else:
self.llama = None
logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))
def set_clip_options(self, options):
if self.clip_l is not None:
self.clip_l.set_clip_options(options)
if self.clip_g is not None:
self.clip_g.set_clip_options(options)
if self.t5xxl is not None:
self.t5xxl.set_clip_options(options)
if self.llama is not None:
self.llama.set_clip_options(options)
def reset_clip_options(self):
if self.clip_l is not None:
self.clip_l.reset_clip_options()
if self.clip_g is not None:
self.clip_g.reset_clip_options()
if self.t5xxl is not None:
self.t5xxl.reset_clip_options()
if self.llama is not None:
self.llama.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_g = token_weight_pairs["g"]
token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
token_weight_pairs_llama = token_weight_pairs["llama"]
lg_out = None
pooled = None
extra = {}
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
if self.clip_l is not None:
lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
else:
l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
if self.clip_g is not None:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
else:
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
if self.t5xxl is not None:
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
t5_out, t5_pooled = t5_output[:2]
if self.llama is not None:
ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
ll_out, ll_pooled = ll_output[:2]
ll_out = ll_out[:, 1:]
if t5_out is None:
t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device())
if ll_out is None:
ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
if pooled is None:
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
extra["conditioning_llama3"] = ll_out
return t5_out, pooled, extra
def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
return self.clip_g.load_sd(sd)
elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
return self.t5xxl.load_sd(sd)
else:
return self.llama.load_sd(sd)
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
class HiDreamTEModel_(HiDreamTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HiDreamTEModel_

View File

@ -4,6 +4,7 @@ import comfy.text_encoders.llama
from transformers import LlamaTokenizerFast
import torch
import os
import numbers
def llama_detect(state_dict, prefix=""):
@ -20,36 +21,49 @@ def llama_detect(state_dict, prefix=""):
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=pad_token, min_length=min_length, tokenizer_data=tokenizer_data)
class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
if llama_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
textmodel_json_config = {}
vocab_size = model_options.get("vocab_size", None)
if vocab_size is not None:
textmodel_json_config["vocab_size"] = vocab_size
model_options = {**model_options, "model_name": "llama"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class HunyuanVideoTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, llama_template=None, **kwargs):
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
if llama_template is None:
llama_text = "{}{}".format(self.llama_template, text)
llama_text = self.llama_template.format(text)
else:
llama_text = "{}{}".format(llama_template, text)
out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids)
llama_text = llama_template.format(text)
llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
embed_count = 0
for r in llama_text_tokens:
for i in range(len(r)):
if r[i][0] == 128257:
if image_embeds is not None and embed_count < image_embeds.shape[0]:
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:]
embed_count += 1
out["llama"] = llama_text_tokens
return out
def untokenize(self, token_weight_pair):
@ -63,8 +77,7 @@ class HunyuanVideoClipModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
self.dtypes = set([dtype, dtype_llama])
@ -83,20 +96,51 @@ class HunyuanVideoClipModel(torch.nn.Module):
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
template_end = 0
for i, v in enumerate(token_weight_pairs_llama[0]):
if v[0] == 128007: # <|end_header_id|>
template_end = i
extra_template_end = 0
extra_sizes = 0
user_end = 9999999999999
images = []
tok_pairs = token_weight_pairs_llama[0]
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 128006:
if tok_pairs[i + 1][0] == 882:
if tok_pairs[i + 2][0] == 128007:
template_end = i + 2
user_end = -1
if elem == 128009 and user_end == -1:
user_end = i + 1
else:
if elem.get("original_type") == "image":
elem_size = elem.get("data").shape[0]
if template_end > 0:
if user_end == -1:
extra_template_end += elem_size - 1
else:
image_start = i + extra_sizes
image_end = i + elem_size + extra_sizes
images.append((image_start, image_end, elem.get("image_interleave", 1)))
extra_sizes += elem_size - 1
if llama_out.shape[1] > (template_end + 2):
if token_weight_pairs_llama[0][template_end + 1][0] == 271:
if tok_pairs[template_end + 1][0] == 271:
template_end += 2
llama_out = llama_out[:, template_end:]
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
if len(images) > 0:
out = []
for i in images:
out.append(llama_out[:, i[0]: i[1]: i[2]])
llama_output = torch.cat(out + [llama_output], dim=1)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return llama_out, l_pooled, llama_extra_out
return llama_output, l_pooled, llama_extra_out
def load_sd(self, sd):
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:

View File

@ -9,24 +9,26 @@ import torch
class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
model_options = {**model_options, "model_name": "hydit_clip"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77)
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data)
class MT5XLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
model_options = {**model_options, "model_name": "mt5xl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class MT5XLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
#tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
@ -35,7 +37,7 @@ class HyditTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}

View File

@ -241,7 +241,10 @@ class Llama2_(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
if embeds is not None:
x = embeds
else:
x = self.embed_tokens(x, out_dtype=dtype)
if self.normalize_in:
@ -265,11 +268,17 @@ class Llama2_(nn.Module):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
intermediate = None
all_intermediate = None
if intermediate_output is not None:
if intermediate_output < 0:
if intermediate_output == "all":
all_intermediate = []
intermediate_output = None
elif intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
for i, layer in enumerate(self.layers):
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
x = layer(
x=x,
attention_mask=mask,
@ -280,6 +289,12 @@ class Llama2_(nn.Module):
intermediate = x.clone()
x = self.norm(x)
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1)
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.norm(intermediate)

View File

@ -1,30 +1,27 @@
from comfy import sd1_clip
import os
class LongClipTokenizer_(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
class LongClipModel_(sd1_clip.SDClipModel):
def __init__(self, *args, **kwargs):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
class LongClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
def model_options_long_clip(sd, tokenizer_data, model_options):
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
if w is None:
w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None)
else:
model_name = "clip_g"
if w is None:
w = sd.get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
if w is not None:
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
model_name = "clip_g"
elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
model_name = "clip_l"
else:
model_name = "clip_l"
if w is not None:
tokenizer_data = tokenizer_data.copy()
model_options = model_options.copy()
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
model_options["clip_l_class"] = LongClipModel_
model_config = model_options.get("model_config", {})
model_config["max_position_embeddings"] = w.shape[0]
model_options["{}_model_config".format(model_name)] = model_config
tokenizer_data["{}_max_length".format(model_name)] = w.shape[0]
return tokenizer_data, model_options

View File

@ -6,7 +6,7 @@ import comfy.text_encoders.genmo
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128?
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):

View File

@ -6,7 +6,7 @@ import comfy.text_encoders.llama
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False})
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}

View File

@ -24,7 +24,7 @@ class PixArtT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):

View File

@ -11,7 +11,7 @@ class T5BaseModel(sd1_clip.SDClipModel):
class T5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data)
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):

View File

@ -12,7 +12,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='clip_h', tokenizer_data=tokenizer_data)
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):

View File

@ -15,6 +15,7 @@ class T5XXLModel(sd1_clip.SDClipModel):
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
model_options = {**model_options, "model_name": "t5xxl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@ -31,17 +32,16 @@ def t5_xxl_detect(state_dict, prefix=""):
return out
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=min_length, tokenizer_data=tokenizer_data)
class SD3Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
@ -61,8 +61,7 @@ class SD3ClipModel(torch.nn.Module):
super().__init__()
self.dtypes = set()
if clip_l:
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None

View File

@ -239,8 +239,11 @@ class T5(torch.nn.Module):
def set_input_embeddings(self, embeddings):
self.shared = embeddings
def forward(self, input_ids, *args, **kwargs):
def forward(self, input_ids, attention_mask, embeds=None, num_tokens=None, **kwargs):
if input_ids is None:
x = embeds
else:
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
x = torch.nan_to_num(x) #Fix for fp8 T5 base
return self.encoder(x, *args, **kwargs)
return self.encoder(x, attention_mask=attention_mask, **kwargs)

View File

@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel):
class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}

View File

@ -316,3 +316,156 @@ class LRUCache(BasicCache):
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self
class DependencyAwareCache(BasicCache):
"""
A cache implementation that tracks dependencies between nodes and manages
their execution and caching accordingly. It extends the BasicCache class.
Nodes are removed from this cache once all of their descendants have been
executed.
"""
def __init__(self, key_class):
"""
Initialize the DependencyAwareCache.
Args:
key_class: The class used for generating cache keys.
"""
super().__init__(key_class)
self.descendants = {} # Maps node_id -> set of descendant node_ids
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
self.executed_nodes = set() # Tracks nodes that have been executed
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
"""
Clear the entire cache and rebuild the dependency graph.
Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to initialize the cache for.
is_changed_cache: Flag indicating if the cache has changed.
"""
# Clear all existing cache data
self.cache.clear()
self.subcaches.clear()
self.descendants.clear()
self.ancestors.clear()
self.executed_nodes.clear()
# Call the parent method to initialize the cache with the new prompt
super().set_prompt(dynprompt, node_ids, is_changed_cache)
# Rebuild the dependency graph
self._build_dependency_graph(dynprompt, node_ids)
def _build_dependency_graph(self, dynprompt, node_ids):
"""
Build the dependency graph for all nodes.
Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to build the graph for.
"""
self.descendants.clear()
self.ancestors.clear()
for node_id in node_ids:
self.descendants[node_id] = set()
self.ancestors[node_id] = set()
for node_id in node_ids:
inputs = dynprompt.get_node(node_id)["inputs"]
for input_data in inputs.values():
if is_link(input_data): # Check if the input is a link to another node
ancestor_id = input_data[0]
self.descendants[ancestor_id].add(node_id)
self.ancestors[node_id].add(ancestor_id)
def set(self, node_id, value):
"""
Mark a node as executed and store its value in the cache.
Args:
node_id: The ID of the node to store.
value: The value to store for the node.
"""
self._set_immediate(node_id, value)
self.executed_nodes.add(node_id)
self._cleanup_ancestors(node_id)
def get(self, node_id):
"""
Retrieve the cached value for a node.
Args:
node_id: The ID of the node to retrieve.
Returns:
The cached value for the node.
"""
return self._get_immediate(node_id)
def ensure_subcache_for(self, node_id, children_ids):
"""
Ensure a subcache exists for a node and update dependencies.
Args:
node_id: The ID of the parent node.
children_ids: List of child node IDs to associate with the parent node.
Returns:
The subcache object for the node.
"""
subcache = super()._ensure_subcache(node_id, children_ids)
for child_id in children_ids:
self.descendants[node_id].add(child_id)
self.ancestors[child_id].add(node_id)
return subcache
def _cleanup_ancestors(self, node_id):
"""
Check if ancestors of a node can be removed from the cache.
Args:
node_id: The ID of the node whose ancestors are to be checked.
"""
for ancestor_id in self.ancestors.get(node_id, []):
if ancestor_id in self.executed_nodes:
# Remove ancestor if all its descendants have been executed
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
self._remove_node(ancestor_id)
def _remove_node(self, node_id):
"""
Remove a node from the cache.
Args:
node_id: The ID of the node to remove.
"""
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
del self.cache[cache_key]
subcache_key = self.cache_key_set.get_subcache_key(node_id)
if subcache_key in self.subcaches:
del self.subcaches[subcache_key]
def clean_unused(self):
"""
Clean up unused nodes. This is a no-op for this cache implementation.
"""
pass
def recursive_debug_dump(self):
"""
Dump the cache and dependency graph for debugging.
Returns:
A list containing the cache state and dependency graph.
"""
result = super().recursive_debug_dump()
result.append({
"descendants": self.descendants,
"ancestors": self.ancestors,
"executed_nodes": list(self.executed_nodes),
})
return result

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import torchaudio
import torch
import comfy.model_management
@ -10,6 +12,7 @@ import random
import hashlib
import node_helpers
from comfy.cli_args import args
from comfy.comfy_types import FileLocator
class EmptyLatentAudio:
def __init__(self):
@ -164,7 +167,7 @@ class SaveAudio:
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results = list()
results: list[FileLocator] = []
metadata = {}
if not args.disable_metadata:

45
comfy_extras/nodes_cfg.py Normal file
View File

@ -0,0 +1,45 @@
import torch
# https://github.com/WeichenFan/CFG-Zero-star
def optimized_scale(positive, negative):
positive_flat = positive.reshape(positive.shape[0], -1)
negative_flat = negative.reshape(negative.shape[0], -1)
# Calculate dot production
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of uncondition
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm
return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
class CFGZeroStar:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "advanced/guidance"
def patch(self, model):
m = model.clone()
def cfg_zero_star(args):
guidance_scale = args['cond_scale']
x = args['input']
cond_p = args['cond_denoised']
uncond_p = args['uncond_denoised']
out = args["denoised"]
alpha = optimized_scale(x - cond_p, x - uncond_p)
return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
m.set_model_sampler_post_cfg_function(cfg_zero_star)
return (m, )
NODE_CLASS_MAPPINGS = {
"CFGZeroStar": CFGZeroStar
}

View File

@ -454,7 +454,7 @@ class SamplerCustom:
return {"required":
{"model": ("MODEL",),
"add_noise": ("BOOLEAN", {"default": True}),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
@ -605,8 +605,14 @@ class DisableNoise:
class RandomNoise(DisableNoise):
@classmethod
def INPUT_TYPES(s):
return {"required":{
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
return {
"required": {
"noise_seed": ("INT", {
"default": 0,
"min": 0,
"max": 0xffffffffffffffff,
"control_after_generate": True,
}),
}
}

View File

@ -0,0 +1,32 @@
import folder_paths
import comfy.sd
import comfy.model_management
class QuadrupleCLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name3": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name4": (folder_paths.get_filename_list("text_encoders"), )
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
NODE_CLASS_MAPPINGS = {
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
}

View File

@ -1,4 +1,5 @@
import nodes
import node_helpers
import torch
import comfy.model_management
@ -38,7 +39,83 @@ class EmptyHunyuanLatentVideo:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
class TextEncodeHunyuanVideo_ImageToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, clip_vision_output, prompt, image_interleave):
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
return (clip.encode_from_tokens_scheduled(tokens), )
class HunyuanImageToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
},
"optional": {"start_image": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
out_latent = {}
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image)
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
if guidance_type == "v1 (concat)":
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
else:
cond = {'guiding_frame_index': 0}
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
out_latent["noise_mask"] = mask
positive = node_helpers.conditioning_set_values(positive, cond)
out_latent["samples"] = latent
return (positive, out_latent)
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
"HunyuanImageToVideo": HunyuanImageToVideo,
}

View File

@ -0,0 +1,634 @@
import torch
import os
import json
import struct
import numpy as np
from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch
import folder_paths
import comfy.model_management
from comfy.cli_args import args
class EmptyLatentHunyuan3Dv2:
@classmethod
def INPUT_TYPES(s):
return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/3d"
def generate(self, resolution, batch_size):
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
return ({"samples": latent, "type": "hunyuan3dv2"}, )
class Hunyuan3Dv2Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision_output):
embeds = clip_vision_output.last_hidden_state
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return (positive, negative)
class Hunyuan3Dv2ConditioningMultiView:
@classmethod
def INPUT_TYPES(s):
return {"required": {},
"optional": {"front": ("CLIP_VISION_OUTPUT",),
"left": ("CLIP_VISION_OUTPUT",),
"back": ("CLIP_VISION_OUTPUT",),
"right": ("CLIP_VISION_OUTPUT",), }}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, front=None, left=None, back=None, right=None):
all_embeds = [front, left, back, right]
out = []
pos_embeds = None
for i, e in enumerate(all_embeds):
if e is not None:
if pos_embeds is None:
pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4))
out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1))
embeds = torch.cat(out, dim=1)
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return (positive, negative)
class VOXEL:
def __init__(self, data):
self.data = data
class VAEDecodeHunyuan3D:
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT", ),
"vae": ("VAE", ),
"num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}),
"octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}),
}}
RETURN_TYPES = ("VOXEL",)
FUNCTION = "decode"
CATEGORY = "latent/3d"
def decode(self, vae, samples, num_chunks, octree_resolution):
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
return (voxels, )
def voxel_to_mesh(voxels, threshold=0.5, device=None):
if device is None:
device = torch.device("cpu")
voxels = voxels.to(device)
binary = (voxels > threshold).float()
padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0)
D, H, W = binary.shape
neighbors = torch.tensor([
[0, 0, 1],
[0, 0, -1],
[0, 1, 0],
[0, -1, 0],
[1, 0, 0],
[-1, 0, 0]
], device=device)
z, y, x = torch.meshgrid(
torch.arange(D, device=device),
torch.arange(H, device=device),
torch.arange(W, device=device),
indexing='ij'
)
voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
solid_mask = binary.flatten() > 0
solid_indices = voxel_indices[solid_mask]
corner_offsets = [
torch.tensor([
[0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1]
], device=device),
torch.tensor([
[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]
], device=device),
torch.tensor([
[0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1]
], device=device),
torch.tensor([
[0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]
], device=device),
torch.tensor([
[1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]
], device=device),
torch.tensor([
[0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0]
], device=device)
]
all_vertices = []
all_indices = []
vertex_count = 0
for face_idx, offset in enumerate(neighbors):
neighbor_indices = solid_indices + offset
padded_indices = neighbor_indices + 1
is_exposed = padded[
padded_indices[:, 0],
padded_indices[:, 1],
padded_indices[:, 2]
] == 0
if not is_exposed.any():
continue
exposed_indices = solid_indices[is_exposed]
corners = corner_offsets[face_idx].unsqueeze(0)
face_vertices = exposed_indices.unsqueeze(1) + corners
all_vertices.append(face_vertices.reshape(-1, 3))
num_faces = exposed_indices.shape[0]
face_indices = torch.arange(
vertex_count,
vertex_count + 4 * num_faces,
device=device
).reshape(-1, 4)
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1))
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1))
vertex_count += 4 * num_faces
if len(all_vertices) > 0:
vertices = torch.cat(all_vertices, dim=0)
faces = torch.cat(all_indices, dim=0)
else:
vertices = torch.zeros((1, 3))
faces = torch.zeros((1, 3))
v_min = 0
v_max = max(voxels.shape)
vertices = vertices - (v_min + v_max) / 2
scale = (v_max - v_min) / 2
if scale > 0:
vertices = vertices / scale
vertices = torch.fliplr(vertices)
return vertices, faces
def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
if device is None:
device = torch.device("cpu")
voxels = voxels.to(device)
D, H, W = voxels.shape
padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0)
z, y, x = torch.meshgrid(
torch.arange(D, device=device),
torch.arange(H, device=device),
torch.arange(W, device=device),
indexing='ij'
)
cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
corner_offsets = torch.tensor([
[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
], device=device)
corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
for c, (dz, dy, dx) in enumerate(corner_offsets):
corner_values[:, c] = padded[
cell_positions[:, 0] + dz,
cell_positions[:, 1] + dy,
cell_positions[:, 2] + dx
]
corner_signs = corner_values > threshold
has_inside = torch.any(corner_signs, dim=1)
has_outside = torch.any(~corner_signs, dim=1)
contains_surface = has_inside & has_outside
active_cells = cell_positions[contains_surface]
active_signs = corner_signs[contains_surface]
active_values = corner_values[contains_surface]
if active_cells.shape[0] == 0:
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
edges = torch.tensor([
[0, 1], [0, 2], [0, 4], [1, 3],
[1, 5], [2, 3], [2, 6], [3, 7],
[4, 5], [4, 6], [5, 7], [6, 7]
], device=device)
cell_vertices = {}
progress = comfy.utils.ProgressBar(100)
for edge_idx, (e1, e2) in enumerate(edges):
progress.update(1)
crossing = active_signs[:, e1] != active_signs[:, e2]
if not crossing.any():
continue
cell_indices = torch.nonzero(crossing, as_tuple=True)[0]
v1 = active_values[cell_indices, e1]
v2 = active_values[cell_indices, e2]
t = torch.zeros_like(v1, device=device)
denom = v2 - v1
valid = denom != 0
t[valid] = (threshold - v1[valid]) / denom[valid]
t[~valid] = 0.5
p1 = corner_offsets[e1].float()
p2 = corner_offsets[e2].float()
intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0))
for i, point in zip(cell_indices.tolist(), intersection):
if i not in cell_vertices:
cell_vertices[i] = []
cell_vertices[i].append(point)
# Calculate the final vertices as the average of intersection points for each cell
vertices = []
vertex_lookup = {}
vert_progress_mod = round(len(cell_vertices)/50)
for i, points in cell_vertices.items():
if not i % vert_progress_mod:
progress.update(1)
if points:
vertex = torch.stack(points).mean(dim=0)
vertex = vertex + active_cells[i].float()
vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices)
vertices.append(vertex)
if not vertices:
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
final_vertices = torch.stack(vertices)
inside_corners_mask = active_signs
outside_corners_mask = ~active_signs
inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float()
outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float()
inside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
outside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
for i in range(8):
mask_inside = inside_corners_mask[:, i].unsqueeze(1)
mask_outside = outside_corners_mask[:, i].unsqueeze(1)
inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside
outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside
inside_pos /= inside_counts
outside_pos /= outside_counts
gradients = inside_pos - outside_pos
pos_dirs = torch.tensor([
[1, 0, 0],
[0, 1, 0],
[0, 0, 1]
], device=device)
cross_products = [
torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float())
for i in range(3) for j in range(i+1, 3)
]
faces = []
all_keys = set(vertex_lookup.keys())
face_progress_mod = round(len(active_cells)/38*3)
for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]):
dir_i = pos_dirs[i]
dir_j = pos_dirs[j]
cross_product = cross_products[pair_idx]
ni_positions = active_cells + dir_i
nj_positions = active_cells + dir_j
diag_positions = active_cells + dir_i + dir_j
alignments = torch.matmul(gradients, cross_product)
valid_quads = []
quad_indices = []
for idx, active_cell in enumerate(active_cells):
if not idx % face_progress_mod:
progress.update(1)
cell_key = tuple(active_cell.tolist())
ni_key = tuple(ni_positions[idx].tolist())
nj_key = tuple(nj_positions[idx].tolist())
diag_key = tuple(diag_positions[idx].tolist())
if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys:
v0 = vertex_lookup[cell_key]
v1 = vertex_lookup[ni_key]
v2 = vertex_lookup[nj_key]
v3 = vertex_lookup[diag_key]
valid_quads.append((v0, v1, v2, v3))
quad_indices.append(idx)
for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads):
cell_idx = quad_indices[q_idx]
if alignments[cell_idx] > 0:
faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long))
faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long))
else:
faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long))
faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long))
if faces:
faces = torch.stack(faces)
else:
faces = torch.zeros((0, 3), dtype=torch.long, device=device)
v_min = 0
v_max = max(D, H, W)
final_vertices = final_vertices - (v_min + v_max) / 2
scale = (v_max - v_min) / 2
if scale > 0:
final_vertices = final_vertices / scale
final_vertices = torch.fliplr(final_vertices)
return final_vertices, faces
class MESH:
def __init__(self, vertices, faces):
self.vertices = vertices
self.faces = faces
class VoxelToMeshBasic:
@classmethod
def INPUT_TYPES(s):
return {"required": {"voxel": ("VOXEL", ),
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MESH",)
FUNCTION = "decode"
CATEGORY = "3d"
def decode(self, voxel, threshold):
vertices = []
faces = []
for x in voxel.data:
v, f = voxel_to_mesh(x, threshold=threshold, device=None)
vertices.append(v)
faces.append(f)
return (MESH(torch.stack(vertices), torch.stack(faces)), )
class VoxelToMesh:
@classmethod
def INPUT_TYPES(s):
return {"required": {"voxel": ("VOXEL", ),
"algorithm": (["surface net", "basic"], ),
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MESH",)
FUNCTION = "decode"
CATEGORY = "3d"
def decode(self, voxel, algorithm, threshold):
vertices = []
faces = []
if algorithm == "basic":
mesh_function = voxel_to_mesh
elif algorithm == "surface net":
mesh_function = voxel_to_mesh_surfnet
for x in voxel.data:
v, f = mesh_function(x, threshold=threshold, device=None)
vertices.append(v)
faces.append(f)
return (MESH(torch.stack(vertices), torch.stack(faces)), )
def save_glb(vertices, faces, filepath, metadata=None):
"""
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
Parameters:
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
filepath: str - Output filepath (should end with .glb)
"""
# Convert tensors to numpy arrays
vertices_np = vertices.cpu().numpy().astype(np.float32)
faces_np = faces.cpu().numpy().astype(np.uint32)
vertices_buffer = vertices_np.tobytes()
indices_buffer = faces_np.tobytes()
def pad_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
return buffer + b'\x00' * padding_length
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
buffer_data = vertices_buffer_padded + indices_buffer_padded
vertices_byte_length = len(vertices_buffer)
vertices_byte_offset = 0
indices_byte_length = len(indices_buffer)
indices_byte_offset = len(vertices_buffer_padded)
gltf = {
"asset": {"version": "2.0", "generator": "ComfyUI"},
"buffers": [
{
"byteLength": len(buffer_data)
}
],
"bufferViews": [
{
"buffer": 0,
"byteOffset": vertices_byte_offset,
"byteLength": vertices_byte_length,
"target": 34962 # ARRAY_BUFFER
},
{
"buffer": 0,
"byteOffset": indices_byte_offset,
"byteLength": indices_byte_length,
"target": 34963 # ELEMENT_ARRAY_BUFFER
}
],
"accessors": [
{
"bufferView": 0,
"byteOffset": 0,
"componentType": 5126, # FLOAT
"count": len(vertices_np),
"type": "VEC3",
"max": vertices_np.max(axis=0).tolist(),
"min": vertices_np.min(axis=0).tolist()
},
{
"bufferView": 1,
"byteOffset": 0,
"componentType": 5125, # UNSIGNED_INT
"count": faces_np.size,
"type": "SCALAR"
}
],
"meshes": [
{
"primitives": [
{
"attributes": {
"POSITION": 0
},
"indices": 1,
"mode": 4 # TRIANGLES
}
]
}
],
"nodes": [
{
"mesh": 0
}
],
"scenes": [
{
"nodes": [0]
}
],
"scene": 0
}
if metadata is not None:
gltf["asset"]["extras"] = metadata
# Convert the JSON to bytes
gltf_json = json.dumps(gltf).encode('utf8')
def pad_json_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
return buffer + b' ' * padding_length
gltf_json_padded = pad_json_to_4_bytes(gltf_json)
# Create the GLB header
# Magic glTF
glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
# Create JSON chunk header (chunk type 0)
json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A) # "JSON" in little endian
# Create BIN chunk header (chunk type 1)
bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942) # "BIN\0" in little endian
# Write the GLB file
with open(filepath, 'wb') as f:
f.write(glb_header)
f.write(json_chunk_header)
f.write(gltf_json_padded)
f.write(bin_chunk_header)
f.write(buffer_data)
return filepath
class SaveGLB:
@classmethod
def INPUT_TYPES(s):
return {"required": {"mesh": ("MESH", ),
"filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, }
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "3d"
def save(self, mesh, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
results = []
metadata = {}
if not args.disable_metadata:
if prompt is not None:
metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
for i in range(mesh.vertices.shape[0]):
f = f"{filename}_{counter:05}_.glb"
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
results.append({
"filename": f,
"subfolder": subfolder,
"type": "output"
})
counter += 1
return {"ui": {"3d": results}}
NODE_CLASS_MAPPINGS = {
"EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2,
"Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning,
"Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
"VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
"VoxelToMeshBasic": VoxelToMeshBasic,
"VoxelToMesh": VoxelToMesh,
"SaveGLB": SaveGLB,
}

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import nodes
import folder_paths
from comfy.cli_args import args
@ -9,6 +11,8 @@ import numpy as np
import json
import os
from comfy.comfy_types import FileLocator
MAX_RESOLUTION = nodes.MAX_RESOLUTION
class ImageCrop:
@ -99,7 +103,7 @@ class SaveAnimatedWEBP:
method = self.methods.get(method)
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
results: list[FileLocator] = []
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()

View File

@ -19,12 +19,10 @@ class Load3D():
"image": ("LOAD_3D", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"material": (["original", "normal", "wireframe", "depth"],),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
RETURN_NAMES = ("image", "mask", "mesh_path")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
FUNCTION = "process"
EXPERIMENTAL = True
@ -34,12 +32,16 @@ class Load3D():
def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
return output_image, output_mask, model_file,
return output_image, output_mask, model_file, normal_image, lineart_image
class Load3DAnimation():
@classmethod
@ -55,12 +57,10 @@ class Load3DAnimation():
"image": ("LOAD_3D_ANIMATION", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"material": (["original", "normal", "wireframe", "depth"],),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
RETURN_NAMES = ("image", "mask", "mesh_path")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
FUNCTION = "process"
EXPERIMENTAL = True
@ -70,20 +70,20 @@ class Load3DAnimation():
def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
return output_image, output_mask, model_file,
return output_image, output_mask, model_file, normal_image
class Preview3D():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
"material": (["original", "normal", "wireframe", "depth"],),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}}
OUTPUT_NODE = True
@ -102,8 +102,6 @@ class Preview3DAnimation():
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
"material": (["original", "normal", "wireframe", "depth"],),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}}
OUTPUT_NODE = True

File diff suppressed because one or more lines are too long

View File

@ -99,12 +99,13 @@ class LTXVAddGuide:
"negative": ("CONDITIONING", ),
"vae": ("VAE",),
"latent": ("LATENT",),
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames."
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
"tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \
"If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \
"Negative values are counted from the end of the video."}),
"tooltip": "Frame index to start the conditioning at. For single-frame images or "
"videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
"frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
"the nearest multiple of 8. Negative values are counted from the end of the video."}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
@ -127,11 +128,12 @@ class LTXVAddGuide:
t = vae.encode(encode_pixels)
return encode_pixels, t
def get_latent_index(self, cond, latent_length, frame_idx, scale_factors):
def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0)
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
if guide_length > 1:
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
@ -191,14 +193,9 @@ class LTXVAddGuide:
_, _, latent_length, latent_height, latent_width = latent_image.shape
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors)
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
if frame_idx == 0:
latent_image, noise_mask = self.replace_latent_frames(latent_image, noise_mask, t, latent_idx, strength)
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
positive, negative, latent_image, noise_mask = self.append_keyframe(
@ -252,6 +249,8 @@ class LTXVCropGuides:
noise_mask = get_noise_mask(latent)
_, num_keyframes = get_keyframe_idxs(positive)
if num_keyframes == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes]
@ -447,7 +446,6 @@ class LTXVPreprocess:
CATEGORY = "image"
def preprocess(self, image, img_compression):
if img_compression > 0:
output_images = []
for i in range(image.shape[0]):
output_images.append(preprocess(image[i], img_compression))

View File

@ -2,6 +2,7 @@ import numpy as np
import scipy.ndimage
import torch
import comfy.utils
import node_helpers
from nodes import MAX_RESOLUTION
@ -87,6 +88,7 @@ class ImageCompositeMasked:
CATEGORY = "image"
def composite(self, destination, source, x, y, resize_source, mask = None):
destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
return (output,)

View File

@ -20,10 +20,6 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input
class X0(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50
@ -56,7 +52,7 @@ class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm", "x0"],),
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}
@ -77,7 +73,9 @@ class ModelSamplingDiscrete:
sampling_type = LCM
sampling_base = ModelSamplingDiscreteDistilled
elif sampling == "x0":
sampling_type = X0
sampling_type = comfy.model_sampling.X0
elif sampling == "img_to_img":
sampling_type = comfy.model_sampling.IMG_TO_IMG
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass

View File

@ -244,6 +244,30 @@ class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["patch_embedding."] = argument
arg_dict["time_embedding."] = argument
arg_dict["time_projection."] = argument
arg_dict["text_embedding."] = argument
arg_dict["img_emb."] = argument
for i in range(40):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["head."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@ -256,4 +280,5 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeLTXV": ModelMergeLTXV,
"ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B,
"ModelMergeWAN2_1": ModelMergeWAN2_1,
}

View File

@ -2,6 +2,7 @@ import torch
import comfy.model_management
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
import kornia.color
class Morphology:
@ -40,8 +41,45 @@ class Morphology:
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
return (img_out,)
class ImageRGBToYUV:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
}}
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
RETURN_NAMES = ("Y", "U", "V")
FUNCTION = "execute"
CATEGORY = "image/batch"
def execute(self, image):
out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1)
return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))
class ImageYUVToRGB:
@classmethod
def INPUT_TYPES(s):
return {"required": {"Y": ("IMAGE",),
"U": ("IMAGE",),
"V": ("IMAGE",),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "image/batch"
def execute(self, Y, U, V):
image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1)
out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1)
return (out,)
NODE_CLASS_MAPPINGS = {
"Morphology": Morphology,
"ImageRGBToYUV": ImageRGBToYUV,
"ImageYUVToRGB": ImageYUVToRGB,
}
NODE_DISPLAY_NAME_MAPPINGS = {

View File

@ -0,0 +1,56 @@
# from https://github.com/bebebe666/OptimalSteps
import numpy as np
import torch
def loglinear_interp(t_steps, num_steps):
"""
Performs log-linear interpolation of a given array of decreasing numbers.
"""
xs = np.linspace(0, 1, len(t_steps))
ys = np.log(t_steps[::-1])
new_xs = np.linspace(0, 1, num_steps)
new_ys = np.interp(new_xs, xs, ys)
interped_ys = np.exp(new_ys)[::-1].copy()
return interped_ys
NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
"Wan":[1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
}
class OptimalStepsScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model_type": (["FLUX", "Wan"], ),
"steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, model_type, steps, denoise):
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return (torch.FloatTensor([]),)
total_steps = round(steps * denoise)
sigmas = NOISE_LEVELS[model_type][:]
if (steps + 1) != len(sigmas):
sigmas = loglinear_interp(sigmas, steps + 1)
sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0
return (torch.FloatTensor(sigmas), )
NODE_CLASS_MAPPINGS = {
"OptimalStepsScheduler": OptimalStepsScheduler,
}

View File

@ -6,7 +6,7 @@ import math
import comfy.utils
import comfy.model_management
import node_helpers
class Blend:
def __init__(self):
@ -34,6 +34,7 @@ class Blend:
CATEGORY = "image/postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
image1, image2 = node_helpers.image_alpha_fix(image1, image2)
image2 = image2.to(image1.device)
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)

View File

@ -0,0 +1,79 @@
# Primitive nodes that are evaluated at backend.
from __future__ import annotations
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
class String(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.STRING, {})},
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: str) -> tuple[str]:
return (value,)
class Int(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.INT, {"control_after_generate": True})},
}
RETURN_TYPES = (IO.INT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: int) -> tuple[int]:
return (value,)
class Float(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.FLOAT, {})},
}
RETURN_TYPES = (IO.FLOAT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: float) -> tuple[float]:
return (value,)
class Boolean(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.BOOLEAN, {})},
}
RETURN_TYPES = (IO.BOOLEAN,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: bool) -> tuple[bool]:
return (value,)
NODE_CLASS_MAPPINGS = {
"PrimitiveString": String,
"PrimitiveInt": Int,
"PrimitiveFloat": Float,
"PrimitiveBoolean": Boolean,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PrimitiveString": "String",
"PrimitiveInt": "Int",
"PrimitiveFloat": "Float",
"PrimitiveBoolean": "Boolean",
}

View File

@ -1,9 +1,12 @@
from __future__ import annotations
import os
import av
import torch
import folder_paths
import json
from fractions import Fraction
from comfy.comfy_types import FileLocator
class SaveWEBM:
@ -62,7 +65,7 @@ class SaveWEBM:
container.mux(stream.encode())
container.close()
results = [{
results: list[FileLocator] = [{
"filename": file,
"subfolder": subfolder,
"type": self.type

View File

@ -3,6 +3,7 @@ import node_helpers
import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats
class WanImageToVideo:
@ -49,6 +50,110 @@ class WanImageToVideo:
return (positive, negative, out_latent)
class WanFunControlToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"control_video": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(control_video[:, :, :, :3])
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanFunInpaintToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, 3)) * 0.5
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
if start_image is not None:
image[:start_image.shape[0]] = start_image
mask[:, :, :start_image.shape[0] + 3] = 0.0
if end_image is not None:
image[-end_image.shape[0]:] = end_image
mask[:, :, -end_image.shape[0]:] = 0.0
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo,
}

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.21"
__version__ = "0.3.28"

View File

@ -15,7 +15,7 @@ import nodes
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input
class ExecutionResult(Enum):
@ -59,20 +59,27 @@ class IsChangedCache:
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
class CacheSet:
def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0:
self.init_classic_cache()
else:
self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects]
# Useful for those with ample RAM/VRAM -- allows experimenting without
# blowing away the cache every time
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
class CacheType(Enum):
CLASSIC = 0
LRU = 1
DEPENDENCY_AWARE = 2
class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
if cache_type == CacheType.DEPENDENCY_AWARE:
self.init_dependency_aware_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
cache_size = 0
self.init_lru_cache(cache_size)
logging.info("Using LRU cache")
else:
self.init_classic_cache()
self.all = [self.outputs, self.ui, self.objects]
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
@ -80,6 +87,17 @@ class CacheSet:
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# only hold cached items while the decendents have not executed
def init_dependency_aware_cache(self):
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
self.objects = DependencyAwareCache(CacheKeySetID)
def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
@ -414,13 +432,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor:
def __init__(self, server, lru_size=None):
self.lru_size = lru_size
def __init__(self, server, cache_type=False, cache_size=None):
self.cache_size = cache_size
self.cache_type = cache_type
self.server = server
self.reset()
def reset(self):
self.caches = CacheSet(self.lru_size)
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
self.status_messages = []
self.success = True
@ -634,6 +653,13 @@ def validate_inputs(prompt, item, validated):
continue
else:
try:
# Unwraps values wrapped in __value__ key. This is used to pass
# list widget value to execution, as by default list value is
# reserved to represent the connection between nodes.
if isinstance(val, dict) and "__value__" in val:
val = val["__value__"]
inputs[x] = val
if type_input == "INT":
val = int(val)
inputs[x] = val
@ -768,7 +794,7 @@ def validate_prompt(prompt):
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], [])
return (False, error, [], {})
class_type = prompt[x]['class_type']
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
@ -779,7 +805,7 @@ def validate_prompt(prompt):
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], [])
return (False, error, [], {})
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
outputs.add(x)
@ -791,7 +817,7 @@ def validate_prompt(prompt):
"details": "",
"extra_info": {}
}
return (False, error, [], [])
return (False, error, [], {})
good_outputs = set()
errors = []

View File

@ -85,6 +85,7 @@ cache_helper = CacheHelper()
extension_mimetypes_cache = {
"webp" : "image",
"fbx" : "model",
}
def map_legacy(folder_name: str) -> str:
@ -140,11 +141,14 @@ def get_directory_by_type(type_name: str) -> str | None:
return get_input_directory()
return None
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio", "model"]) -> list[str]:
"""
Example:
files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"])
videos = filter_files_content_types(files, ["video"])
Note:
- 'model' in MIME context refers to 3D models, not files containing trained weights and parameters
"""
global extension_mimetypes_cache
result = []

16
main.py
View File

@ -10,6 +10,7 @@ from app.logger import setup_logger
import itertools
import utils.extra_config
import logging
import sys
if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes.
@ -139,6 +140,7 @@ from server import BinaryEventTypes
import nodes
import comfy.model_management
import comfyui_version
import app.logger
def cuda_malloc_warning():
@ -155,7 +157,13 @@ def cuda_malloc_warning():
def prompt_worker(q, server_instance):
current_time: float = 0.0
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_none:
cache_type = execution.CacheType.DEPENDENCY_AWARE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
@ -294,10 +302,14 @@ def start_comfyui(asyncio_loop=None):
if __name__ == "__main__":
# Running directly, just start ComfyUI.
logging.info("Python version: {}".format(sys.version))
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
event_loop, _, start_all_func = start_comfyui()
try:
event_loop.run_until_complete(start_all_func())
x = start_all_func()
app.logger.print_startup_warnings()
event_loop.run_until_complete(x)
except KeyboardInterrupt:
logging.info("\nStopped server")

View File

@ -44,3 +44,11 @@ def string_to_torch_dtype(string):
return torch.float16
if string == "bf16":
return torch.bfloat16
def image_alpha_fix(destination, source):
if destination.shape[-1] < source.shape[-1]:
source = source[...,:destination.shape[-1]]
elif destination.shape[-1] > source.shape[-1]:
destination = torch.nn.functional.pad(destination, (0, 1))
destination[..., -1] = 1.0
return destination, source

View File

@ -25,7 +25,7 @@ import comfy.sample
import comfy.sd
import comfy.utils
import comfy.controlnet
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
import comfy.clip_vision
@ -479,7 +479,7 @@ class SaveLatent:
file = f"{filename}_{counter:05}_.latent"
results = list()
results: list[FileLocator] = []
results.append({
"filename": file,
"subfolder": subfolder,
@ -489,7 +489,7 @@ class SaveLatent:
file = os.path.join(full_output_folder, file)
output = {}
output["latent_tensor"] = samples["samples"]
output["latent_tensor"] = samples["samples"].contiguous()
output["latent_format_version_0"] = torch.tensor([])
comfy.utils.save_torch_file(output, file, metadata=metadata)
@ -770,6 +770,7 @@ class VAELoader:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
vae.throw_exception_if_invalid()
return (vae,)
class ControlNetLoader:
@ -785,6 +786,8 @@ class ControlNetLoader:
def load_controlnet(self, control_net_name):
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
if controlnet is None:
raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.")
return (controlnet,)
class DiffControlNetLoader:
@ -1005,6 +1008,8 @@ class CLIPVisionLoader:
def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
clip_vision = comfy.clip_vision.load(clip_path)
if clip_vision is None:
raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.")
return (clip_vision,)
class CLIPVisionEncode:
@ -1519,7 +1524,7 @@ class KSampler:
return {
"required": {
"model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "tooltip": "The random seed used for creating the noise."}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}),
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}),
@ -1547,7 +1552,7 @@ class KSamplerAdvanced:
return {"required":
{"model": ("MODEL",),
"add_noise": (["enable", "disable"], ),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
@ -1649,6 +1654,7 @@ class LoadImage:
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
files = folder_paths.filter_files_content_types(files, ["image"])
return {"required":
{"image": (sorted(files), {"image_upload": True})},
}
@ -1687,6 +1693,9 @@ class LoadImage:
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
@ -1785,14 +1794,7 @@ class LoadImageOutput(LoadImage):
DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
EXPERIMENTAL = True
FUNCTION = "load_image_output"
def load_image_output(self, image):
return self.load_image(f"{image} [output]")
@classmethod
def VALIDATE_INPUTS(s, image):
return True
FUNCTION = "load_image"
class ImageScale:
@ -2129,21 +2131,25 @@ def get_module_name(module_path: str) -> str:
def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
module_name = os.path.basename(module_path)
module_name = get_module_name(module_path)
if os.path.isfile(module_path):
sp = os.path.splitext(module_path)
module_name = sp[0]
sys_module_name = module_name
elif os.path.isdir(module_path):
sys_module_name = module_path.replace(".", "_x_")
try:
logging.debug("Trying to load custom node {}".format(module_path))
if os.path.isfile(module_path):
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
module_spec = importlib.util.spec_from_file_location(sys_module_name, module_path)
module_dir = os.path.split(module_path)[0]
else:
module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
module_spec = importlib.util.spec_from_file_location(sys_module_name, os.path.join(module_path, "__init__.py"))
module_dir = module_path
module = importlib.util.module_from_spec(module_spec)
sys.modules[module_name] = module
sys.modules[sys_module_name] = module
module_spec.loader.exec_module(module)
LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
@ -2270,6 +2276,12 @@ def init_builtin_extra_nodes():
"nodes_video.py",
"nodes_lumina2.py",
"nodes_wan.py",
"nodes_lotus.py",
"nodes_hunyuan3d.py",
"nodes_primitive.py",
"nodes_cfg.py",
"nodes_optimalsteps.py",
"nodes_hidream.py"
]
import_failed = []

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.21"
version = "0.3.28"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@ -1,4 +1,4 @@
comfyui-frontend-package==1.10.17
comfyui-frontend-package==1.15.13
torch
torchsde
torchvision

View File

@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message):
@web.middleware
async def cache_control(request: web.Request, handler):
response: web.Response = await handler(request)
if request.path.endswith('.js') or request.path.endswith('.css'):
if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
response.headers.setdefault('Cache-Control', 'no-cache')
return response
@ -657,7 +657,13 @@ class PromptServer():
logging.warning("invalid prompt: {}".format(valid[1]))
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
else:
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
error = {
"type": "no_prompt",
"message": "No prompt provided",
"details": "No prompt provided",
"extra_info": {}
}
return web.json_response({"error": error, "node_errors": {}}, status=400)
@routes.post("/queue")
async def post_queue(request):

View File

@ -70,7 +70,7 @@ def test_get_release_invalid_version(mock_provider):
def test_init_frontend_default():
version_string = DEFAULT_VERSION_STRING
frontend_path = FrontendManager.init_frontend(version_string)
assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
assert frontend_path == FrontendManager.default_frontend_path()
def test_init_frontend_invalid_version():
@ -84,24 +84,29 @@ def test_init_frontend_invalid_provider():
with pytest.raises(HTTPError):
FrontendManager.init_frontend_unsafe(version_string)
@pytest.fixture
def mock_os_functions():
with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
patch('app.frontend_management.os.listdir') as mock_listdir, \
patch('app.frontend_management.os.rmdir') as mock_rmdir:
with (
patch("app.frontend_management.os.makedirs") as mock_makedirs,
patch("app.frontend_management.os.listdir") as mock_listdir,
patch("app.frontend_management.os.rmdir") as mock_rmdir,
):
mock_listdir.return_value = [] # Simulate empty directory
yield mock_makedirs, mock_listdir, mock_rmdir
@pytest.fixture
def mock_download():
with patch('app.frontend_management.download_release_asset_zip') as mock:
with patch("app.frontend_management.download_release_asset_zip") as mock:
mock.side_effect = Exception("Download failed") # Simulate download failure
yield mock
def test_finally_block(mock_os_functions, mock_download, mock_provider):
# Arrange
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
version_string = 'test-owner/test-repo@1.0.0'
version_string = "test-owner/test-repo@1.0.0"
# Act & Assert
with pytest.raises(Exception):
@ -128,3 +133,42 @@ def test_parse_version_string_invalid():
version_string = "invalid"
with pytest.raises(argparse.ArgumentTypeError):
FrontendManager.parse_version_string(version_string)
def test_init_frontend_default_with_mocks():
# Arrange
version_string = DEFAULT_VERSION_STRING
# Act
with (
patch("app.frontend_management.check_frontend_version") as mock_check,
patch.object(
FrontendManager, "default_frontend_path", return_value="/mocked/path"
),
):
frontend_path = FrontendManager.init_frontend(version_string)
# Assert
assert frontend_path == "/mocked/path"
mock_check.assert_called_once()
def test_init_frontend_fallback_on_error():
# Arrange
version_string = "test-owner/test-repo@1.0.0"
# Act
with (
patch.object(
FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error")
),
patch("app.frontend_management.check_frontend_version") as mock_check,
patch.object(
FrontendManager, "default_frontend_path", return_value="/default/path"
),
):
frontend_path = FrontendManager.init_frontend(version_string)
# Assert
assert frontend_path == "/default/path"
mock_check.assert_called_once()

View File

@ -1,14 +1,17 @@
import pytest
import os
import tempfile
from folder_paths import filter_files_content_types
from folder_paths import filter_files_content_types, extension_mimetypes_cache
from unittest.mock import patch
@pytest.fixture(scope="module")
def file_extensions():
return {
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv'],
'model': ['gltf', 'glb', 'obj', 'fbx', 'stl']
}
@ -22,7 +25,18 @@ def mock_dir(file_extensions):
yield directory
def test_categorizes_all_correctly(mock_dir, file_extensions):
@pytest.fixture
def patched_mimetype_cache(file_extensions):
# Mock model file extensions since they may not be in the test-runner system's mimetype cache
new_cache = extension_mimetypes_cache.copy()
for extension in file_extensions["model"]:
new_cache[extension] = "model"
with patch("folder_paths.extension_mimetypes_cache", new_cache):
yield
def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_cache):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
@ -30,7 +44,7 @@ def test_categorizes_all_correctly(mock_dir, file_extensions):
assert f"sample_{content_type}.{extension}" in filtered_files
def test_categorizes_all_uniquely(mock_dir, file_extensions):
def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_cache):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])