Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-16 08:33:29 +00:00)
Compare commits
No commits in common. "master" and "v0.3.25" have entirely different histories.
@ -19,6 +19,5 @@
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata

# Node developers
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered

# Extra nodes
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
@ -69,8 +69,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- 3D Models
  - [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that change between executions.
@ -9,14 +9,8 @@ class AppSettings():
        self.user_manager = user_manager

    def get_settings(self, request):
        try:
            file = self.user_manager.get_request_user_filepath(
                request,
                "comfy.settings.json"
            )
        except KeyError as e:
            logging.error("User settings not found.")
            raise web.HTTPUnauthorized() from e
        file = self.user_manager.get_request_user_filepath(
            request, "comfy.settings.json")
        if os.path.isfile(file):
            try:
                with open(file) as f:
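For reference, a minimal standalone sketch of the settings lookup shown in this hunk. The filename comfy.settings.json comes from the diff; the fallback to an empty dict is an assumption for illustration, not necessarily what AppSettings does on a missing file.

import json
import os

def load_user_settings(settings_dir: str) -> dict:
    # Resolve the per-user settings file and return its parsed contents.
    file = os.path.join(settings_dir, "comfy.settings.json")
    if not os.path.isfile(file):
        return {}  # assumption: a missing settings file simply means "no overrides"
    with open(file) as f:
        return json.load(f)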
@ -11,62 +11,33 @@ from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict, Optional
from importlib.metadata import version

import requests
from typing_extensions import NotRequired

from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger

# The path to the requirements.txt file
req_path = Path(__file__).parents[1] / "requirements.txt"


def frontend_install_warning_message():
    """The warning message to display when the frontend version is not up to date."""

    req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
    extra = ""
    if sys.flags.no_user_site:
        extra = "-s "
    return f"""
Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {req_path}
    return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem"

This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.

If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
""".strip()
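As a quick illustration of the message above, a sketch of how the reinstall command ends up looking; the -s flag is only added when Python was started with the no-user-site flag, and the exact output depends on the interpreter path.

import os
import sys

# Rebuild the same command string the warning interpolates.
req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
extra = "-s " if sys.flags.no_user_site else ""
command = f"{sys.executable} {extra}-m pip install -r {req_path}"
print(command)  # e.g. /usr/bin/python3 -m pip install -r /path/to/requirements.txt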
try:
    import comfyui_frontend_package
except ImportError:
    # TODO: Remove the check after roll out of 0.3.16
    logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
    exit(-1)


def check_frontend_version():
    """Check if the frontend version is up to date."""

    def parse_version(version: str) -> tuple[int, int, int]:
        return tuple(map(int, version.split(".")))

    try:
        frontend_version_str = version("comfyui-frontend-package")
        frontend_version = parse_version(frontend_version_str)
        with open(req_path, "r", encoding="utf-8") as f:
            required_frontend = parse_version(f.readline().split("=")[-1])
        if frontend_version < required_frontend:
            app.logger.log_startup_warning(
                f"""
________________________________________________________________________
WARNING WARNING WARNING WARNING WARNING

Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.

{frontend_install_warning_message()}
________________________________________________________________________
""".strip()
            )
        else:
            logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
    except Exception as e:
        logging.error(f"Failed to check frontend version: {e}")

try:
    frontend_version = tuple(map(int, comfyui_frontend_package.__version__.split(".")))
except:
    frontend_version = (0,)
    pass

REQUEST_TIMEOUT = 10 # seconds
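The version check above relies on plain tuple comparison of (major, minor, patch); a small self-contained sketch of that behaviour:

def parse_version(version: str) -> tuple[int, int, int]:
    # Same idea as parse_version() above: "1.14.6" -> (1, 14, 6).
    return tuple(map(int, version.split(".")))

installed = parse_version("1.10.17")
required = parse_version("1.14.6")
# Tuples compare element by element, so 1.10.x is correctly seen as older than 1.14.x.
assert (installed < required) is True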
@ -162,28 +133,9 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:


class FrontendManager:
    DEFAULT_FRONTEND_PATH = str(importlib.resources.files(comfyui_frontend_package) / "static")
    CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")

    @classmethod
    def default_frontend_path(cls) -> str:
        try:
            import comfyui_frontend_package

            return str(importlib.resources.files(comfyui_frontend_package) / "static")
        except ImportError:
            logging.error(
                f"""
********** ERROR ***********

comfyui-frontend-package is not installed.

{frontend_install_warning_message()}

********** ERROR ***********
""".strip()
            )
            sys.exit(-1)

    @classmethod
    def parse_version_string(cls, value: str) -> tuple[str, str, str]:
        """
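The actual pattern used by parse_version_string() is not shown in this hunk; the regex below is an illustrative assumption that matches the "owner/repo@version" shape implied by the three returned groups.

import re

# Hypothetical pattern, assumed only for illustration: "owner/repo@version".
VERSION_STRING_RE = re.compile(r"^([\w-]+)/([\w\.-]+)@(.+)$")

def split_version_string(value: str) -> tuple[str, str, str]:
    match_result = VERSION_STRING_RE.match(value)
    if match_result is None:
        raise ValueError(f"Invalid version string: {value}")
    return match_result.group(1), match_result.group(2), match_result.group(3)

assert split_version_string("Comfy-Org/ComfyUI_frontend@v1.2.3") == ("Comfy-Org", "ComfyUI_frontend", "v1.2.3")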
@ -204,9 +156,7 @@ comfyui-frontend-package is not installed.
        return match_result.group(1), match_result.group(2), match_result.group(3)

    @classmethod
    def init_frontend_unsafe(
        cls, version_string: str, provider: Optional[FrontEndProvider] = None
    ) -> str:
    def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
        """
        Initializes the frontend for the specified version.

@ -222,26 +172,17 @@ comfyui-frontend-package is not installed.
        main error source might be request timeout or invalid URL.
        """
        if version_string == DEFAULT_VERSION_STRING:
            check_frontend_version()
            return cls.default_frontend_path()
            return cls.DEFAULT_FRONTEND_PATH

        repo_owner, repo_name, version = cls.parse_version_string(version_string)

        if version.startswith("v"):
            expected_path = str(
                Path(cls.CUSTOM_FRONTENDS_ROOT)
                / f"{repo_owner}_{repo_name}"
                / version.lstrip("v")
            )
            expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
            if os.path.exists(expected_path):
                logging.info(
                    f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
                )
                logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
                return expected_path

        logging.info(
            f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
        )
        logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")

        provider = provider or FrontEndProvider(repo_owner, repo_name)
        release = provider.get_release(version)
@ -284,5 +225,4 @@ comfyui-frontend-package is not installed.
        except Exception as e:
            logging.error("Failed to initialize frontend: %s", e)
            logging.info("Falling back to the default frontend.")
            check_frontend_version()
            return cls.default_frontend_path()
            return cls.DEFAULT_FRONTEND_PATH
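From the path construction above, pinned frontend versions are cached under web_custom_versions/{owner}_{repo}/{version-without-v}. A short sketch of that lookup; the directory names come from the hunk, the helper itself is illustrative.

import os
from pathlib import Path

def cached_frontend_path(custom_frontends_root: str, repo_owner: str, repo_name: str, version: str):
    # Mirrors the expected_path computation above: strip the leading "v" from the tag.
    expected_path = str(Path(custom_frontends_root) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
    return expected_path if os.path.exists(expected_path) else None

# e.g. cached_frontend_path("web_custom_versions", "Comfy-Org", "ComfyUI_frontend", "v1.2.3")
# looks for web_custom_versions/Comfy-Org_ComfyUI_frontend/1.2.3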
@ -82,17 +82,3 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
        logger.addHandler(stdout_handler)

    logger.addHandler(stream_handler)


STARTUP_WARNINGS = []


def log_startup_warning(msg):
    logging.warning(msg)
    STARTUP_WARNINGS.append(msg)


def print_startup_warnings():
    for s in STARTUP_WARNINGS:
        logging.warning(s)
    STARTUP_WARNINGS.clear()
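A minimal usage sketch of the warning buffer above: warnings raised during startup are logged immediately and replayed once more (then cleared) when print_startup_warnings() is called later. This assumes it runs from the repository root, where the app package shown in this diff is importable.

import app.logger

app.logger.log_startup_warning("Installed frontend is older than the recommended version.")
# ... later, once startup has finished and the log is visible to the user:
app.logger.print_startup_warnings()  # re-emits buffered warnings, then clears STARTUP_WARNINGS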
@ -79,7 +79,6 @@ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Stor
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")

parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")

@ -101,14 +100,12 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")

parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")

@ -136,9 +133,8 @@ parser.add_argument("--deterministic", action="store_true", help="Make pytorch u
class PerformanceFeature(enum.Enum):
    Fp16Accumulation = "fp16_accumulation"
    Fp8MatrixMultiplication = "fp8_matrix_mult"
    CublasOps = "cublas_ops"

parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")

parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
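A self-contained sketch of the two argparse patterns used above: a mutually exclusive group (only one attention backend may be selected) and an enum-valued nargs="*" flag like --fast. The parser and enum names here are illustrative stand-ins, not the real cli_args objects.

import argparse
import enum

class PerfFeature(enum.Enum):
    Fp16Accumulation = "fp16_accumulation"
    Fp8MatrixMultiplication = "fp8_matrix_mult"
    CublasOps = "cublas_ops"

parser = argparse.ArgumentParser()
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true")
attn_group.add_argument("--use-sage-attention", action="store_true")
parser.add_argument("--fast", nargs="*", type=PerfFeature, help="An empty list means: enable everything.")

args = parser.parse_args(["--use-sage-attention", "--fast", "fp16_accumulation"])
assert args.fast == [PerfFeature.Fp16Accumulation]
# Passing both attention flags at once would make parse_args() exit with an error.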
@ -9,7 +9,6 @@ import comfy.model_patcher
import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2

class Output:
    def __getitem__(self, key):
@ -35,12 +34,6 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
    image = torch.clip((255. * image), 0, 255).round() / 255.0
    return (image - mean.view([3,1,1])) / std.view([3,1,1])

IMAGE_ENCODERS = {
    "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
    "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
    "dinov2": comfy.image_encoders.dino2.Dinov2Model,
}

class ClipVisionModel():
    def __init__(self, json_config):
        with open(json_config) as f:
@ -49,11 +42,10 @@ class ClipVisionModel():
        self.image_size = config.get("image_size", 224)
        self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
        self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
        model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
        self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
        self.model.eval()

        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
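The IMAGE_ENCODERS table above turns the config's model_type string into an encoder class; a stripped-down sketch of that dispatch. The stand-in classes below are placeholders, not the real comfy modules.

class CLIPVisionStub:
    pass

class Dinov2Stub:
    pass

# Same pattern as IMAGE_ENCODERS above: a missing model_type falls back to
# clip_vision_model, while an unrecognized one yields None.
IMAGE_ENCODER_CLASSES = {
    "clip_vision_model": CLIPVisionStub,
    "siglip_vision_model": CLIPVisionStub,
    "dinov2": Dinov2Stub,
}

def encoder_class_for(config: dict):
    return IMAGE_ENCODER_CLASSES.get(config.get("model_type", "clip_vision_model"))

assert encoder_class_for({"model_type": "dinov2"}) is Dinov2Stub
assert encoder_class_for({}) is CLIPVisionStub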
@ -110,21 +102,15 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
    elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
    elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
        embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
        if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
            if embed_shape == 729:
                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
            elif embed_shape == 1024:
                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
        elif embed_shape == 577:
            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
        elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
            if "multi_modal_projector.linear_1.bias" in sd:
                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
            else:
                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
        else:
            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
    elif "embeddings.patch_embeddings.projection.weight" in sd:
        json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
    else:
        return None
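The branch above picks a JSON config purely from which keys exist in the state dict and from tensor shapes. A compact sketch of that idea; the key names are copied from the hunk, the selection table is deliberately simplified, and for brevity the dict maps key name to a position-embedding length instead of a real tensor.

def guess_vision_config(sd: dict):
    # Simplified version of the selection above: key presence and the
    # position-embedding length identify the architecture variant.
    if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
        return "clip_vision_config_h.json"
    if "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
        embed_len = sd["vision_model.embeddings.position_embedding.weight"]
        if embed_len == 729:
            return "clip_vision_siglip_384.json"
        if embed_len == 577:
            return "clip_vision_config_vitl_336.json"
        return "clip_vision_config_vitl.json"
    if "embeddings.patch_embeddings.projection.weight" in sd:
        return "dino2_giant.json"
    return None

assert guess_vision_config({"embeddings.patch_embeddings.projection.weight": 0}) == "dino2_giant.json"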
@ -1,13 +0,0 @@
{
    "num_channels": 3,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "image_size": 512,
    "intermediate_size": 4304,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 27,
    "patch_size": 16,
    "image_mean": [0.5, 0.5, 0.5],
    "image_std": [0.5, 0.5, 0.5]
}
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Literal, TypedDict
|
||||
from typing_extensions import NotRequired
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
@ -27,7 +26,6 @@ class IO(StrEnum):
|
||||
BOOLEAN = "BOOLEAN"
|
||||
INT = "INT"
|
||||
FLOAT = "FLOAT"
|
||||
COMBO = "COMBO"
|
||||
CONDITIONING = "CONDITIONING"
|
||||
SAMPLER = "SAMPLER"
|
||||
SIGMAS = "SIGMAS"
|
||||
@ -68,7 +66,6 @@ class IO(StrEnum):
|
||||
b = frozenset(value.split(","))
|
||||
return not (b.issubset(a) or a.issubset(b))
|
||||
|
||||
|
||||
class RemoteInputOptions(TypedDict):
|
||||
route: str
|
||||
"""The route to the remote source."""
|
||||
@ -83,14 +80,6 @@ class RemoteInputOptions(TypedDict):
|
||||
refresh: int
|
||||
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
|
||||
|
||||
|
||||
class MultiSelectOptions(TypedDict):
|
||||
placeholder: NotRequired[str]
|
||||
"""The placeholder text to display in the multi-select widget when no items are selected."""
|
||||
chip: NotRequired[bool]
|
||||
"""Specifies whether to use chips instead of comma separated values for the multi-select widget."""
|
||||
|
||||
|
||||
class InputTypeOptions(TypedDict):
|
||||
"""Provides type hinting for the return type of the INPUT_TYPES node function.
|
||||
|
||||
@ -102,13 +91,9 @@ class InputTypeOptions(TypedDict):
|
||||
default: bool | str | float | int | list | tuple
|
||||
"""The default value of the widget"""
|
||||
defaultInput: bool
|
||||
"""@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
|
||||
- defaultInput on required inputs should be dropped.
|
||||
- defaultInput on optional inputs should be replaced with forceInput.
|
||||
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
|
||||
"""
|
||||
"""Defaults to an input slot rather than a widget"""
|
||||
forceInput: bool
|
||||
"""Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
|
||||
"""`defaultInput` and also don't allow converting to a widget"""
|
||||
lazy: bool
|
||||
"""Declares that this input uses lazy evaluation"""
|
||||
rawLink: bool
|
||||
@ -148,22 +133,9 @@ class InputTypeOptions(TypedDict):
|
||||
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
|
||||
"""
|
||||
remote: RemoteInputOptions
|
||||
"""Specifies the configuration for a remote input.
|
||||
Available after ComfyUI frontend v1.9.7
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
|
||||
"""Specifies the configuration for a remote input."""
|
||||
    control_after_generate: bool
    """Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
    options: NotRequired[list[str | int | float]]
    """COMBO type only. Specifies the selectable options for the combo widget.
    Prefer:
    ["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
    Over:
    [["Option 1", "Option 2", "Option 3"]]
    """
    multi_select: NotRequired[MultiSelectOptions]
    """COMBO type only. Specifies the configuration for a multi-select widget.
    Available after ComfyUI frontend v1.13.4
    https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""


class HiddenInputTypeDict(TypedDict):
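To make the COMBO documentation above concrete, a sketch of a custom node's INPUT_TYPES using the preferred ["COMBO", {...}] form plus multi_select. The node class itself is hypothetical; only the option keys come from the docstrings above.

class PickSamplerNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # Preferred COMBO form: explicit type plus an options dict.
                "mode": ("COMBO", {"options": ["Option 1", "Option 2", "Option 3"], "default": "Option 1"}),
                # multi_select turns the combo into a multi-select widget (frontend >= v1.13.4).
                "tags": ("COMBO", {
                    "options": ["portrait", "landscape", "macro"],
                    "multi_select": {"placeholder": "pick any", "chip": True},
                }),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"

    def run(self, mode, tags):
        return (f"{mode}: {','.join(tags)}",)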
@ -1,141 +0,0 @@
|
||||
import torch
|
||||
from comfy.text_encoders.bert import BertAttention
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
|
||||
class Dino2AttentionOutput(torch.nn.Module):
|
||||
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.dense(x)
|
||||
|
||||
|
||||
class Dino2AttentionBlock(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
|
||||
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
|
||||
def forward(self, x, mask, optimized_attention):
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
|
||||
|
||||
class LayerScale(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x):
|
||||
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
|
||||
|
||||
|
||||
class SwiGLUFFN(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
in_features = out_features = dim
|
||||
hidden_features = int(dim * 4)
|
||||
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
||||
|
||||
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
|
||||
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.weights_in(x)
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x = torch.nn.functional.silu(x1) * x2
|
||||
return self.weights_out(x)
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, optimized_attention):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
|
||||
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layer) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layer):
|
||||
x = l(x, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class Dino2PatchEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.projection = operations.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
return self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
class Dino2Embeddings(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
patch_size = 14
|
||||
image_size = 518
|
||||
|
||||
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
||||
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, pixel_values):
|
||||
x = self.patch_embeddings(pixel_values)
|
||||
# TODO: mask_token?
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||
return x
|
||||
|
||||
|
||||
class Dinov2Model(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
x = self.embeddings(pixel_values)
|
||||
x, i = self.encoder(x, intermediate_output=intermediate_output)
|
||||
x = self.layernorm(x)
|
||||
pooled_output = x[:, 0, :]
|
||||
return x, i, pooled_output, None
|
@ -1,21 +0,0 @@
{
    "attention_probs_dropout_prob": 0.0,
    "drop_path_rate": 0.0,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.0,
    "hidden_size": 1536,
    "image_size": 518,
    "initializer_range": 0.02,
    "layer_norm_eps": 1e-06,
    "layerscale_value": 1.0,
    "mlp_ratio": 4,
    "model_type": "dinov2",
    "num_attention_heads": 24,
    "num_channels": 3,
    "num_hidden_layers": 40,
    "patch_size": 14,
    "qkv_bias": true,
    "use_swiglu_ffn": true,
    "image_mean": [0.485, 0.456, 0.406],
    "image_std": [0.229, 0.224, 0.225]
}
@ -688,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
    if len(sigmas) <= 1:
        return x

    extra_args = {} if extra_args is None else extra_args
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    seed = extra_args.get("seed", None)
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
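The reordering above just makes sure the seed is read from extra_args before the noise sampler is built. A sketch of that setup in isolation; the BrownianTreeNoiseSampler import path is assumed from ComfyUI's k-diffusion port.

import torch
# Assumed import path; in this repo the sampler lives in the k-diffusion port.
from comfy.k_diffusion.sampling import BrownianTreeNoiseSampler

def make_noise_sampler(x: torch.Tensor, sigmas: torch.Tensor, extra_args):
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    # Same call as in sample_dpmpp_sde above: a seeded, CPU-backed Brownian tree sampler.
    return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True)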
@ -762,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
if solver_type not in {'heun', 'midpoint'}:
|
||||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
old_denoised = None
|
||||
@ -808,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
denoised_1, denoised_2 = None, None
|
||||
@ -858,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
@ -867,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
@ -876,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
||||
@ -1366,157 +1366,3 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
|
||||
x = x + d_bar * dt
|
||||
old_d = d
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""
|
||||
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
|
||||
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
def default_noise_scaler(sigma):
|
||||
return sigma * ((sigma ** 0.3).exp() + 10.0)
|
||||
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
|
||||
num_integration_points = 200.0
|
||||
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
|
||||
|
||||
old_denoised = None
|
||||
old_denoised_d = None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
stage_used = min(max_stage, i + 1)
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
elif stage_used == 1:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
else:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
sigma_step_size = -dt / num_integration_points
|
||||
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
|
||||
scaled_pos = noise_scaler(sigma_pos)
|
||||
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * sigma_step_size
|
||||
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
|
||||
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
|
||||
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
|
||||
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
|
||||
if s_noise != 0 and sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
'''
|
||||
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
h_eta = h * (eta + 1)
|
||||
s = t + r * h
|
||||
fac = 1 / (2 * r)
|
||||
sigma_s = s.neg().exp()
|
||||
|
||||
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
'''
|
||||
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
h_eta = h * (eta + 1)
|
||||
s_1 = t + r_1 * h
|
||||
s_2 = t + r_2 * h
|
||||
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
|
||||
|
||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
|
||||
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||
if inject_noise:
|
||||
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
|
||||
# Step 3
|
||||
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||
return x
|
||||
|
@ -456,13 +456,3 @@ class Wan21(LatentFormat):
        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
        latents_std = self.latents_std.to(latent.device, latent.dtype)
        return latent * latents_std / self.scale_factor + latents_mean

class Hunyuan3Dv2(LatentFormat):
    latent_channels = 64
    latent_dimensions = 1
    scale_factor = 0.9990943042622529

class Hunyuan3Dv2mini(LatentFormat):
    latent_channels = 64
    latent_dimensions = 1
    scale_factor = 1.0188137142395404
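The two Hunyuan3D formats above only override metadata and scale_factor. A sketch of what that scale factor is used for, assuming the usual LatentFormat base behaviour of multiplying on the way in and dividing on the way out (the base class is not shown in this hunk).

import torch

class LatentFormatSketch:
    scale_factor = 0.9990943042622529  # value from Hunyuan3Dv2 above

    def process_in(self, latent: torch.Tensor) -> torch.Tensor:
        # Assumption: plain scaling, as in the default LatentFormat.
        return latent * self.scale_factor

    def process_out(self, latent: torch.Tensor) -> torch.Tensor:
        return latent / self.scale_factor

fmt = LatentFormatSketch()
x = torch.randn(1, 64, 1024)
assert torch.allclose(fmt.process_out(fmt.process_in(x)), x, atol=1e-6)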
@ -1,6 +1,5 @@
import torch
import comfy.rmsnorm

import comfy.ops

def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
    if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
@ -12,5 +11,20 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):

    return torch.nn.functional.pad(img, pad, mode=padding_mode)

try:
    rms_norm_torch = torch.nn.functional.rms_norm
except:
    rms_norm_torch = None

rms_norm = comfy.rmsnorm.rms_norm
def rms_norm(x, weight=None, eps=1e-6):
    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
        if weight is None:
            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
        else:
            return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
    else:
        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
        if weight is None:
            return r
        else:
            return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
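A quick self-check of the fallback branch above: the manual rsqrt formulation should match torch.nn.functional.rms_norm where that function exists (newer PyTorch releases); purely illustrative.

import torch

x = torch.randn(2, 8, 16)
weight = torch.randn(16)
eps = 1e-6

# Manual path: normalize by the root mean square of the last dimension, then scale.
manual = x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight

if hasattr(torch.nn.functional, "rms_norm"):
    fused = torch.nn.functional.rms_norm(x, weight.shape, weight=weight, eps=eps)
    assert torch.allclose(manual, fused, atol=1e-5)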
@ -159,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
        )
        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None):
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims)
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims)
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@ -195,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        # calculate the img bloks
        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims)
        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims)), img_mod2.gate, None, modulation_dims)

        # calculate the txt bloks
        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims)
        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims)), txt_mod2.gate, None, modulation_dims)

        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
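The qkv handling above packs Q, K and V into one projection and then splits heads with a single view/permute. A toy-sized sketch of that reshape; the dimensions are made up.

import torch

batch, seq_len, num_heads, head_dim = 2, 7, 4, 8
qkv = torch.randn(batch, seq_len, 3 * num_heads * head_dim)  # output of a fused qkv linear

# Same pattern as img_qkv/txt_qkv above: (B, L, 3*H*D) -> 3 x (B, H, L, D)
q, k, v = qkv.view(batch, seq_len, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
assert q.shape == (batch, num_heads, seq_len, head_dim)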
@ -10,11 +10,10 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
    q_shape = q.shape
    k_shape = k.shape

    if pe is not None:
        q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
        k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
        q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
        k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
@ -37,8 +36,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
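apply_rope above treats the last dimension as interleaved (real, imaginary) pairs and applies a 2x2 rotation taken from freqs_cis. A toy demonstration with an identity rotation; the shapes are chosen only to line up with the reshape, not taken from the model.

import torch

def apply_rope_sketch(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Same math as apply_rope above, for a single tensor.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq)

tokens, dim = 5, 8
xq = torch.randn(1, 1, tokens, dim)
# Identity rotation: cos=1, sin=0 for every pair, so the output must equal the input.
identity = torch.tensor([[1.0, 0.0], [0.0, 1.0]]).expand(1, 1, tokens, dim // 2, 2, 2)
assert torch.allclose(apply_rope_sketch(xq, identity), xq)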
@ -115,11 +115,8 @@ class Flux(nn.Module):
        vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
        txt = self.txt_in(txt)

        if img_ids is not None:
            ids = torch.cat((txt_ids, img_ids), dim=1)
            pe = self.pe_embedder(ids)
        else:
            pe = None
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
@ -1,828 +0,0 @@
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import einops
|
||||
from einops import repeat
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
|
||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0, "The dimension must be even."
|
||||
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
|
||||
batch_size, seq_length = pos.shape
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
cos_out = torch.cos(out)
|
||||
sin_out = torch.sin(out)
|
||||
|
||||
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
||||
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
||||
return out.float()
|
||||
|
||||
|
||||
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
return emb.unsqueeze(2)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
out_channels=1024,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = out_channels
|
||||
self.proj = operations.Linear(in_channels * patch_size * patch_size, out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, latent):
|
||||
latent = self.proj(latent)
|
||||
return latent
|
||||
|
||||
|
||||
class PooledEmbed(nn.Module):
|
||||
def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, pooled_embed):
|
||||
return self.pooled_embedder(pooled_embed)
|
||||
|
||||
|
||||
class TimestepEmbed(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, timesteps, wdtype):
|
||||
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
|
||||
t_emb = self.timestep_embedder(t_emb)
|
||||
return t_emb
|
||||
|
||||
|
||||
class OutEmbed(nn.Module):
|
||||
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, adaln_input):
|
||||
shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
|
||||
x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
|
||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
|
||||
|
||||
|
||||
class HiDreamAttnProcessor_flashattn:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
dtype = image_tokens.dtype
|
||||
batch_size = image_tokens.shape[0]
|
||||
|
||||
query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
|
||||
key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
|
||||
value_i = attn.to_v(image_tokens)
|
||||
|
||||
inner_dim = key_i.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
|
||||
key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
|
||||
value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
|
||||
if image_tokens_masks is not None:
|
||||
key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)
|
||||
|
||||
if not attn.single:
|
||||
query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
|
||||
key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
|
||||
value_t = attn.to_v_t(text_tokens)
|
||||
|
||||
query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
|
||||
key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
|
||||
value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
num_image_tokens = query_i.shape[1]
|
||||
num_text_tokens = query_t.shape[1]
|
||||
query = torch.cat([query_i, query_t], dim=1)
|
||||
key = torch.cat([key_i, key_t], dim=1)
|
||||
value = torch.cat([value_i, value_t], dim=1)
|
||||
else:
|
||||
query = query_i
|
||||
key = key_i
|
||||
value = value_i
|
||||
|
||||
if query.shape[-1] == rope.shape[-3] * 2:
|
||||
query, key = apply_rope(query, key, rope)
|
||||
else:
|
||||
query_1, query_2 = query.chunk(2, dim=-1)
|
||||
key_1, key_2 = key.chunk(2, dim=-1)
|
||||
query_1, key_1 = apply_rope(query_1, key_1, rope)
|
||||
query = torch.cat([query_1, query_2], dim=-1)
|
||||
key = torch.cat([key_1, key_2], dim=-1)
|
||||
|
||||
hidden_states = attention(query, key, value)
|
||||
|
||||
if not attn.single:
|
||||
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
|
||||
hidden_states_i = attn.to_out(hidden_states_i)
|
||||
hidden_states_t = attn.to_out_t(hidden_states_t)
|
||||
return hidden_states_i, hidden_states_t
|
||||
else:
|
||||
hidden_states = attn.to_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class HiDreamAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
upcast_attention: bool = False,
|
||||
upcast_softmax: bool = False,
|
||||
scale_qk: bool = True,
|
||||
eps: float = 1e-5,
|
||||
processor = None,
|
||||
out_dim: int = None,
|
||||
single: bool = False,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
# super(Attention, self).__init__()
|
||||
super().__init__()
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.upcast_attention = upcast_attention
|
||||
self.upcast_softmax = upcast_softmax
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
|
||||
self.scale_qk = scale_qk
|
||||
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
||||
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
self.sliceable_head_dim = heads
|
||||
self.single = single
|
||||
|
||||
linear_cls = operations.Linear
|
||||
self.linear_cls = linear_cls
|
||||
self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
|
||||
self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
|
||||
self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
|
||||
|
||||
if not single:
|
||||
self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
|
||||
self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
|
||||
self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
|
||||
|
||||
self.processor = processor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
norm_image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: torch.FloatTensor = None,
|
||||
norm_text_tokens: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
image_tokens = norm_image_tokens,
|
||||
image_tokens_masks = image_tokens_masks,
|
||||
text_tokens = norm_text_tokens,
|
||||
rope = rope,
|
||||
)
|
||||
|
||||
|
||||
class FeedForwardSwiGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * (
|
||||
(hidden_dim + multiple_of - 1) // multiple_of
|
||||
)
|
||||
|
||||
self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
|
||||
self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.top_k = num_activated_experts
|
||||
self.n_routed_experts = num_routed_experts
|
||||
|
||||
self.scoring_func = 'softmax'
|
||||
self.alpha = aux_loss_alpha
|
||||
self.seq_aux = False
|
||||
|
||||
# topk selection algorithm
|
||||
self.norm_topk_prob = False
|
||||
self.gating_dim = embed_dim
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
pass
|
||||
# import torch.nn.init as init
|
||||
# init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
bsz, seq_len, h = hidden_states.shape
|
||||
|
||||
### compute gating score
|
||||
hidden_states = hidden_states.view(-1, h)
|
||||
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
|
||||
if self.scoring_func == 'softmax':
|
||||
scores = logits.softmax(dim=-1)
|
||||
else:
|
||||
raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
|
||||
|
||||
### select top-k experts
|
||||
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
|
||||
|
||||
### norm gate to sum 1
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weight = topk_weight / denominator
|
||||
|
||||
aux_loss = None
|
||||
return topk_idx, topk_weight, aux_loss
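# --- Illustrative sketch (not part of the original file) ---
# A toy run of the gating math above with 4 routed experts and top-2 routing: each
# token gets a softmax score per expert and keeps the two best. The weights stay
# unnormalized because norm_topk_prob is False by default.
import torch

scores = torch.randn(3, 4).softmax(dim=-1)                        # 3 tokens, 4 experts
topk_weight, topk_idx = torch.topk(scores, k=2, dim=-1, sorted=False)
print(topk_idx.shape, topk_weight.shape)                          # torch.Size([3, 2]) twice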
|
||||
|
||||
|
||||
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
||||
class MOEFeedForwardSwiGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
num_routed_experts: int,
|
||||
num_activated_experts: int,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations)
|
||||
self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)])
|
||||
self.gate = MoEGate(
|
||||
embed_dim = dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.num_activated_experts = num_activated_experts
|
||||
|
||||
def forward(self, x):
|
||||
wtype = x.dtype
|
||||
identity = x
|
||||
orig_shape = x.shape
|
||||
topk_idx, topk_weight, aux_loss = self.gate(x)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
if True: # self.training: # TODO: check which branch performs faster
|
||||
x = x.repeat_interleave(self.num_activated_experts, dim=0)
|
||||
y = torch.empty_like(x, dtype=wtype)
|
||||
for i, expert in enumerate(self.experts):
|
||||
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
y = y.view(*orig_shape).to(dtype=wtype)
|
||||
#y = AddAuxiliaryLoss.apply(y, aux_loss)
|
||||
else:
|
||||
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
|
||||
y = y + self.shared_experts(identity)
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||
expert_cache = torch.zeros_like(x)
|
||||
idxs = flat_expert_indices.argsort()
|
||||
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
|
||||
token_idxs = idxs // self.num_activated_experts
|
||||
for i, end_idx in enumerate(tokens_per_expert):
|
||||
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
|
||||
if start_idx == end_idx:
|
||||
continue
|
||||
expert = self.experts[i]
|
||||
exp_token_idx = token_idxs[start_idx:end_idx]
|
||||
expert_tokens = x[exp_token_idx]
|
||||
expert_out = expert(expert_tokens)
|
||||
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||
|
||||
# for fp16 and other dtype
|
||||
expert_cache = expert_cache.to(expert_out.dtype)
|
||||
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
|
||||
return expert_cache
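# --- Illustrative usage sketch (not part of the original file) ---
# Shape-level smoke test of the MoE feed-forward, assuming the module's own imports
# (torch, comfy.model_management) are available and using comfy.ops.disable_weight_init
# as the `operations` namespace, the same way ComfyUI wires it elsewhere. The sizes are
# made up for the example.
import torch
import comfy.ops

moe = MOEFeedForwardSwiGLU(
    dim=64, hidden_dim=256,
    num_routed_experts=4, num_activated_experts=2,
    dtype=torch.float32, device="cpu", operations=comfy.ops.disable_weight_init,
)
x = torch.randn(2, 10, 64)   # (batch, tokens, dim)
assert moe(x).shape == x.shape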
|
||||
|
||||
|
||||
class TextProjection(nn.Module):
|
||||
def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear(caption)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BlockType:
|
||||
TransformerBlock = 1
|
||||
SingleTransformerBlock = 2
|
||||
|
||||
|
||||
class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
# 1. Attention
|
||||
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
self.attn1 = HiDreamAttention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor = HiDreamAttnProcessor_flashattn(),
|
||||
single = True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
if num_routed_experts > 0:
|
||||
self.ff_i = MOEFeedForwardSwiGLU(
|
||||
dim = dim,
|
||||
hidden_dim = 4 * dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
|
||||
self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)
|
||||
|
||||
# 1. MM-Attention
|
||||
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
|
||||
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
|
||||
attn_output_i = self.attn1(
|
||||
norm_image_tokens,
|
||||
image_tokens_masks,
|
||||
rope = rope,
|
||||
)
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
|
||||
# 2. Feed-forward
|
||||
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
|
||||
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
|
||||
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
|
||||
image_tokens = ff_output_i + image_tokens
|
||||
return image_tokens
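# --- Illustrative sketch (not part of the original file) ---
# The adaLN projection above maps the timestep/pooled embedding into six per-channel
# vectors: shift/scale/gate for attention and shift/scale/gate for the feed-forward.
# A toy version of that chunking, with dim=8 made up for the example:
import torch

dim = 8
mod = torch.randn(2, 6 * dim)             # stands in for self.adaLN_modulation(adaln_input)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod[:, None].chunk(6, dim=-1)
print(shift_msa.shape)                    # torch.Size([2, 1, 8]); broadcasts over the token axis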
|
||||
|
||||
|
||||
class HiDreamImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
# nn.init.zeros_(self.adaLN_modulation[1].weight)
|
||||
# nn.init.zeros_(self.adaLN_modulation[1].bias)
|
||||
|
||||
# 1. Attention
|
||||
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
self.attn1 = HiDreamAttention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor = HiDreamAttnProcessor_flashattn(),
|
||||
single = False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
|
||||
if num_routed_experts > 0:
|
||||
self.ff_i = MOEFeedForwardSwiGLU(
|
||||
dim = dim,
|
||||
hidden_dim = 4 * dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
|
||||
self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
|
||||
self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
|
||||
shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
|
||||
self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)
|
||||
|
||||
# 1. MM-Attention
|
||||
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
|
||||
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
|
||||
norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
|
||||
norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
|
||||
|
||||
attn_output_i, attn_output_t = self.attn1(
|
||||
norm_image_tokens,
|
||||
image_tokens_masks,
|
||||
norm_text_tokens,
|
||||
rope = rope,
|
||||
)
|
||||
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
text_tokens = gate_msa_t * attn_output_t + text_tokens
|
||||
|
||||
# 2. Feed-forward
|
||||
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
|
||||
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
|
||||
norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
|
||||
norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
|
||||
|
||||
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
|
||||
ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
|
||||
image_tokens = ff_output_i + image_tokens
|
||||
text_tokens = ff_output_t + text_tokens
|
||||
return image_tokens, text_tokens
|
||||
|
||||
|
||||
class HiDreamImageBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
block_type: BlockType = BlockType.TransformerBlock,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
block_classes = {
|
||||
BlockType.TransformerBlock: HiDreamImageTransformerBlock,
|
||||
BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
|
||||
}
|
||||
self.block = block_classes[block_type](
|
||||
dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
num_routed_experts,
|
||||
num_activated_experts,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_tokens: torch.FloatTensor,
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
) -> torch.FloatTensor:
|
||||
return self.block(
|
||||
image_tokens,
|
||||
image_tokens_masks,
|
||||
text_tokens,
|
||||
adaln_input,
|
||||
rope,
|
||||
)
|
||||
|
||||
|
||||
class HiDreamImageTransformer2DModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Optional[int] = None,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 16,
|
||||
num_single_layers: int = 32,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 20,
|
||||
caption_channels: List[int] = None,
|
||||
text_emb_dim: int = 2048,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
axes_dims_rope: Tuple[int, int] = (32, 32),
|
||||
max_resolution: Tuple[int, int] = (128, 128),
|
||||
llama_layers: List[int] = None,
|
||||
image_model=None,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
self.patch_size = patch_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.num_layers = num_layers
|
||||
self.num_single_layers = num_single_layers
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = self.num_attention_heads * self.attention_head_dim
|
||||
self.llama_layers = llama_layers
|
||||
|
||||
self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.x_embedder = PatchEmbed(
|
||||
patch_size = patch_size,
|
||||
in_channels = in_channels,
|
||||
out_channels = self.inner_dim,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)
|
||||
|
||||
self.double_stream_blocks = nn.ModuleList(
|
||||
[
|
||||
HiDreamImageBlock(
|
||||
dim = self.inner_dim,
|
||||
num_attention_heads = self.num_attention_heads,
|
||||
attention_head_dim = self.attention_head_dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
block_type = BlockType.TransformerBlock,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_stream_blocks = nn.ModuleList(
|
||||
[
|
||||
HiDreamImageBlock(
|
||||
dim = self.inner_dim,
|
||||
num_attention_heads = self.num_attention_heads,
|
||||
attention_head_dim = self.attention_head_dim,
|
||||
num_routed_experts = num_routed_experts,
|
||||
num_activated_experts = num_activated_experts,
|
||||
block_type = BlockType.SingleTransformerBlock,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for i in range(self.num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
|
||||
caption_projection = []
|
||||
for caption_channel in caption_channels:
|
||||
caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations))
|
||||
self.caption_projection = nn.ModuleList(caption_projection)
|
||||
self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
|
||||
|
||||
def expand_timesteps(self, timesteps, batch_size, device):
|
||||
if not torch.is_tensor(timesteps):
|
||||
is_mps = device.type == "mps"
|
||||
if isinstance(timesteps, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(batch_size)
|
||||
return timesteps
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]:
|
||||
x_arr = []
|
||||
for i, img_size in enumerate(img_sizes):
|
||||
pH, pW = img_size
|
||||
x_arr.append(
|
||||
einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
|
||||
p1=self.patch_size, p2=self.patch_size)
|
||||
)
|
||||
x = torch.cat(x_arr, dim=0)
|
||||
return x
|
||||
|
||||
def patchify(self, x, max_seq, img_sizes=None):
|
||||
pz2 = self.patch_size * self.patch_size
|
||||
if isinstance(x, torch.Tensor):
|
||||
B = x.shape[0]
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
else:
|
||||
B = len(x)
|
||||
device = x[0].device
|
||||
dtype = x[0].dtype
|
||||
x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
|
||||
|
||||
if img_sizes is not None:
|
||||
for i, img_size in enumerate(img_sizes):
|
||||
x_masks[i, 0:img_size[0] * img_size[1]] = 1
|
||||
x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
|
||||
elif isinstance(x, torch.Tensor):
|
||||
pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
|
||||
x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size)
|
||||
img_sizes = [[pH, pW]] * B
|
||||
x_masks = None
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return x, x_masks, img_sizes
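# --- Illustrative sketch (not part of the original file) ---
# What the tensor branch of patchify() does, with patch_size=2 made up for the example:
# a (B, C, H, W) latent becomes (B, H/2 * W/2, 2*2*C) patch tokens, and the rearrange
# used in unpatchify() inverts it exactly.
import torch
import einops

x = torch.randn(1, 16, 8, 8)   # B C H W
tokens = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=2, p2=2)
assert tokens.shape == (1, 16, 64)   # 4*4 tokens, each 2*2*16 channels
back = einops.rearrange(tokens.reshape(1, 4, 4, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)', p1=2, p2=2)
assert torch.equal(back, x)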
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_llama3=None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
) -> torch.Tensor:
|
||||
hidden_states = x
|
||||
timesteps = t
|
||||
pooled_embeds = y
|
||||
T5_encoder_hidden_states = context
|
||||
|
||||
img_sizes = None
|
||||
|
||||
# spatial forward
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states_type = hidden_states.dtype
|
||||
|
||||
# 0. time
|
||||
timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
|
||||
timesteps = self.t_embedder(timesteps, hidden_states_type)
|
||||
p_embedder = self.p_embedder(pooled_embeds)
|
||||
adaln_input = timesteps + p_embedder
|
||||
|
||||
hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
|
||||
if image_tokens_masks is None:
|
||||
pH, pW = img_sizes[0]
|
||||
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
# T5_encoder_hidden_states = encoder_hidden_states[0]
|
||||
encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0)
|
||||
encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
|
||||
|
||||
if self.caption_projection is not None:
|
||||
new_encoder_hidden_states = []
|
||||
for i, enc_hidden_state in enumerate(encoder_hidden_states):
|
||||
enc_hidden_state = self.caption_projection[i](enc_hidden_state)
|
||||
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
|
||||
new_encoder_hidden_states.append(enc_hidden_state)
|
||||
encoder_hidden_states = new_encoder_hidden_states
|
||||
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
|
||||
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
encoder_hidden_states.append(T5_encoder_hidden_states)
|
||||
|
||||
txt_ids = torch.zeros(
|
||||
batch_size,
|
||||
encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
|
||||
3,
|
||||
device=img_ids.device, dtype=img_ids.dtype
|
||||
)
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
rope = self.pe_embedder(ids)
|
||||
|
||||
# 2. Blocks
|
||||
block_id = 0
|
||||
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
|
||||
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
|
||||
for bid, block in enumerate(self.double_stream_blocks):
|
||||
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
|
||||
cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
|
||||
hidden_states, initial_encoder_hidden_states = block(
|
||||
image_tokens = hidden_states,
|
||||
image_tokens_masks = image_tokens_masks,
|
||||
text_tokens = cur_encoder_hidden_states,
|
||||
adaln_input = adaln_input,
|
||||
rope = rope,
|
||||
)
|
||||
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
|
||||
block_id += 1
|
||||
|
||||
image_tokens_seq_len = hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
|
||||
hidden_states_seq_len = hidden_states.shape[1]
|
||||
if image_tokens_masks is not None:
|
||||
encoder_attention_mask_ones = torch.ones(
|
||||
(batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
|
||||
device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
|
||||
)
|
||||
image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
|
||||
|
||||
for bid, block in enumerate(self.single_stream_blocks):
|
||||
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
|
||||
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
|
||||
hidden_states = block(
|
||||
image_tokens=hidden_states,
|
||||
image_tokens_masks=image_tokens_masks,
|
||||
text_tokens=None,
|
||||
adaln_input=adaln_input,
|
||||
rope=rope,
|
||||
)
|
||||
hidden_states = hidden_states[:, :hidden_states_seq_len]
|
||||
block_id += 1
|
||||
|
||||
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
|
||||
output = self.final_layer(hidden_states, adaln_input)
|
||||
output = self.unpatchify(output, img_sizes)
|
||||
return -output
|
@ -1,135 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from comfy.ldm.flux.layers import (
|
||||
DoubleStreamBlock,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
|
||||
class Hunyuan3Dv2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=64,
|
||||
context_in_dim=1536,
|
||||
hidden_size=1024,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=16,
|
||||
depth=16,
|
||||
depth_single_blocks=32,
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
image_model=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
if hidden_size % num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
|
||||
)
|
||||
|
||||
self.max_period = 1000 # While reimplementing the model we noticed an upstream mistake: this 1000 value was meant to be the time_factor, but it was set as the max_period instead.
|
||||
self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None
|
||||
)
|
||||
self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device)
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(depth_single_blocks)
|
||||
]
|
||||
)
|
||||
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
x = x.movedim(-1, -2)
|
||||
timestep = 1.0 - timestep
|
||||
txt = context
|
||||
img = self.latent_in(x)
|
||||
|
||||
vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype))
|
||||
if self.guidance_in is not None:
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype))
|
||||
|
||||
txt = self.cond_in(txt)
|
||||
pe = None
|
||||
attn_mask = None
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"],
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img,
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
img = img[:, txt.shape[1]:, ...]
|
||||
img = self.final_layer(img, vec)
|
||||
return img.movedim(-2, -1) * (-1.0)
|
@ -1,587 +0,0 @@
|
||||
# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
|
||||
# Since the header on their VAE source file was a bit confusing, we asked Tencent for permission to use this code under the GPL license used in ComfyUI.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from typing import Union, Tuple, List, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from einops import repeat, rearrange
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def generate_dense_grid_points(
|
||||
bbox_min: np.ndarray,
|
||||
bbox_max: np.ndarray,
|
||||
octree_resolution: int,
|
||||
indexing: str = "ij",
|
||||
):
|
||||
length = bbox_max - bbox_min
|
||||
num_cells = octree_resolution
|
||||
|
||||
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
||||
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
||||
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
||||
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
||||
xyz = np.stack((xs, ys, zs), axis=-1)
|
||||
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
||||
|
||||
return xyz, grid_size, length
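# --- Illustrative usage sketch (not part of the original file) ---
# Query grid for a 64^3 octree over symmetric bounds: the helper returns (res+1)^3
# xyz samples along with the grid size and the bounding-box extent.
import numpy as np

xyz, grid_size, length = generate_dense_grid_points(
    bbox_min=np.array([-1.01, -1.01, -1.01]),
    bbox_max=np.array([1.01, 1.01, 1.01]),
    octree_resolution=64,
)
print(xyz.shape, grid_size, length)   # (65, 65, 65, 3) [65, 65, 65] [2.02 2.02 2.02]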
|
||||
|
||||
|
||||
class VanillaVolumeDecoder:
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
geo_decoder: Callable,
|
||||
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||
num_chunks: int = 10000,
|
||||
octree_resolution: int = None,
|
||||
enable_pbar: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
device = latents.device
|
||||
dtype = latents.dtype
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
# 1. generate query points
|
||||
if isinstance(bounds, float):
|
||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||
|
||||
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
||||
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||
bbox_min=bbox_min,
|
||||
bbox_max=bbox_max,
|
||||
octree_resolution=octree_resolution,
|
||||
indexing="ij"
|
||||
)
|
||||
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
||||
|
||||
# 2. latents to 3d volume
|
||||
batch_logits = []
|
||||
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
|
||||
disable=not enable_pbar):
|
||||
chunk_queries = xyz_samples[start: start + num_chunks, :]
|
||||
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
|
||||
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
||||
batch_logits.append(logits)
|
||||
|
||||
grid_logits = torch.cat(batch_logits, dim=1)
|
||||
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
|
||||
|
||||
return grid_logits
|
||||
|
||||
|
||||
class FourierEmbedder(nn.Module):
|
||||
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||
each feature dimension of `x[..., i]` into:
|
||||
[
|
||||
sin(x[..., i]),
|
||||
sin(f_1*x[..., i]),
|
||||
sin(f_2*x[..., i]),
|
||||
...
|
||||
sin(f_N * x[..., i]),
|
||||
cos(x[..., i]),
|
||||
cos(f_1*x[..., i]),
|
||||
cos(f_2*x[..., i]),
|
||||
...
|
||||
cos(f_N * x[..., i]),
|
||||
x[..., i] # only present if include_input is True.
|
||||
], where f_i is the frequency.
|
||||
|
||||
Denote the exponent space as [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, ..., (num_freqs - 1) / num_freqs].
|
||||
If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
|
||||
Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
|
||||
|
||||
Args:
|
||||
num_freqs (int): the number of frequencies, default is 6;
|
||||
logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
|
||||
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
|
||||
input_dim (int): the input dimension, default is 3;
|
||||
include_input (bool): include the input tensor or not, default is True.
|
||||
|
||||
Attributes:
|
||||
frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
|
||||
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
|
||||
|
||||
out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
|
||||
otherwise, it is input_dim * num_freqs * 2.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_freqs: int = 6,
|
||||
logspace: bool = True,
|
||||
input_dim: int = 3,
|
||||
include_input: bool = True,
|
||||
include_pi: bool = True) -> None:
|
||||
|
||||
"""The initialization"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
if logspace:
|
||||
frequencies = 2.0 ** torch.arange(
|
||||
num_freqs,
|
||||
dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
frequencies = torch.linspace(
|
||||
1.0,
|
||||
2.0 ** (num_freqs - 1),
|
||||
num_freqs,
|
||||
dtype=torch.float32
|
||||
)
|
||||
|
||||
if include_pi:
|
||||
frequencies *= torch.pi
|
||||
|
||||
self.register_buffer("frequencies", frequencies, persistent=False)
|
||||
self.include_input = include_input
|
||||
self.num_freqs = num_freqs
|
||||
|
||||
self.out_dim = self.get_dims(input_dim)
|
||||
|
||||
def get_dims(self, input_dim):
|
||||
temp = 1 if self.include_input or self.num_freqs == 0 else 0
|
||||
out_dim = input_dim * (self.num_freqs * 2 + temp)
|
||||
|
||||
return out_dim
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
""" Forward process.
|
||||
|
||||
Args:
|
||||
x: tensor of shape [..., dim]
|
||||
|
||||
Returns:
|
||||
embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
|
||||
where temp is 1 if include_input is True and 0 otherwise.
|
||||
"""
|
||||
|
||||
if self.num_freqs > 0:
|
||||
embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
|
||||
if self.include_input:
|
||||
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
|
||||
else:
|
||||
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
||||
else:
|
||||
return x
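# --- Illustrative usage sketch (not part of the original file) ---
# With the defaults ShapeVAE uses below (num_freqs=8, include_input=True, input_dim=3),
# each xyz query expands to 3 * (8 * 2 + 1) = 51 features.
import torch

embedder = FourierEmbedder(num_freqs=8, include_pi=True)
pts = torch.rand(2, 100, 3)
emb = embedder(pts)
assert embedder.out_dim == 51 and emb.shape == (2, 100, 51)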
|
||||
|
||||
|
||||
class CrossAttentionProcessor:
|
||||
def __call__(self, attn, q, k, v):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
self.scale_by_keep = scale_by_keep
|
||||
|
||||
def forward(self, x):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
|
||||
"""
|
||||
if self.drop_prob == 0. or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0 and self.scale_by_keep:
|
||||
random_tensor.div_(keep_prob)
|
||||
return x * random_tensor
|
||||
|
||||
def extra_repr(self):
|
||||
return f'drop_prob={round(self.drop_prob, 3):0.3f}'
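# --- Illustrative usage sketch (not part of the original file) ---
# DropPath only acts in training mode: each sample in the batch is either zeroed or
# rescaled by 1/keep_prob; in eval mode it is a no-op.
import torch

dp = DropPath(drop_prob=0.5)
x = torch.ones(4, 3)
dp.train()
print(dp(x))                      # each row is either all 0.0 or all 2.0
dp.eval()
assert torch.equal(dp(x), x)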
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self, *,
|
||||
width: int,
|
||||
expand_ratio: int = 4,
|
||||
output_width: int = None,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.c_fc = ops.Linear(width, width * expand_ratio)
|
||||
self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
|
||||
self.gelu = nn.GELU()
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||
|
||||
|
||||
class QKVMultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=ops.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
self.attn_processor = CrossAttentionProcessor()
|
||||
|
||||
def forward(self, q, kv):
|
||||
_, n_ctx, _ = q.shape
|
||||
bs, n_data, width = kv.shape
|
||||
attn_ch = width // self.heads // 2
|
||||
q = q.view(bs, n_ctx, self.heads, -1)
|
||||
kv = kv.view(bs, n_data, self.heads, -1)
|
||||
k, v = torch.split(kv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = self.attn_processor(self, q, k, v)
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
|
||||
class MultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
data_width: Optional[int] = None,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
kv_cache: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
self.data_width = width if data_width is None else data_width
|
||||
self.c_q = ops.Linear(width, width, bias=qkv_bias)
|
||||
self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
|
||||
self.c_proj = ops.Linear(width, width)
|
||||
self.attention = QKVMultiheadCrossAttention(
|
||||
heads=heads,
|
||||
width=width,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.kv_cache = kv_cache
|
||||
self.data = None
|
||||
|
||||
def forward(self, x, data):
|
||||
x = self.c_q(x)
|
||||
if self.kv_cache:
|
||||
if self.data is None:
|
||||
self.data = self.c_kv(data)
|
||||
logging.info('Saving kv cache, this should be called only once per mesh')
|
||||
data = self.data
|
||||
else:
|
||||
data = self.c_kv(data)
|
||||
x = self.attention(x, data)
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCrossAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
mlp_expand_ratio: int = 4,
|
||||
data_width: Optional[int] = None,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if data_width is None:
|
||||
data_width = width
|
||||
|
||||
self.attn = MultiheadCrossAttention(
|
||||
width=width,
|
||||
heads=heads,
|
||||
data_width=data_width,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
|
||||
self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
|
||||
|
||||
def forward(self, x: torch.Tensor, data: torch.Tensor):
|
||||
x = x + self.attn(self.ln_1(x), self.ln_2(data))
|
||||
x = x + self.mlp(self.ln_3(x))
|
||||
return x
|
||||
|
||||
|
||||
class QKVMultiheadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=ops.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, qkv):
|
||||
bs, n_ctx, width = qkv.shape
|
||||
attn_ch = width // self.heads // 3
|
||||
qkv = qkv.view(bs, n_ctx, self.heads, -1)
|
||||
q, k, v = torch.split(qkv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
|
||||
self.c_proj = ops.Linear(width, width)
|
||||
self.attention = QKVMultiheadAttention(
|
||||
heads=heads,
|
||||
width=width,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.c_qkv(x)
|
||||
x = self.attention(x)
|
||||
x = self.drop_path(self.c_proj(x))
|
||||
return x
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn = MultiheadAttention(
|
||||
width=width,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
|
||||
self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attn(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.ModuleList(
|
||||
[
|
||||
ResidualAttentionBlock(
|
||||
width=width,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
for _ in range(layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
for block in self.resblocks:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
out_channels: int,
|
||||
fourier_embedder: FourierEmbedder,
|
||||
width: int,
|
||||
heads: int,
|
||||
mlp_expand_ratio: int = 4,
|
||||
downsample_ratio: int = 1,
|
||||
enable_ln_post: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary"
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.enable_ln_post = enable_ln_post
|
||||
self.fourier_embedder = fourier_embedder
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
|
||||
if self.downsample_ratio != 1:
|
||||
self.latents_proj = ops.Linear(width * downsample_ratio, width)
|
||||
if not self.enable_ln_post:
|
||||
qk_norm = False
|
||||
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
||||
width=width,
|
||||
mlp_expand_ratio=mlp_expand_ratio,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
|
||||
if self.enable_ln_post:
|
||||
self.ln_post = ops.LayerNorm(width)
|
||||
self.output_proj = ops.Linear(width, out_channels)
|
||||
self.label_type = label_type
|
||||
self.count = 0
|
||||
|
||||
def forward(self, queries=None, query_embeddings=None, latents=None):
|
||||
if query_embeddings is None:
|
||||
query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
|
||||
self.count += query_embeddings.shape[1]
|
||||
if self.downsample_ratio != 1:
|
||||
latents = self.latents_proj(latents)
|
||||
x = self.cross_attn_decoder(query_embeddings, latents)
|
||||
if self.enable_ln_post:
|
||||
x = self.ln_post(x)
|
||||
occ = self.output_proj(x)
|
||||
return occ
|
||||
|
||||
|
||||
class ShapeVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
embed_dim: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
num_decoder_layers: int,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
include_pi: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary",
|
||||
drop_path_rate: float = 0.0,
|
||||
scale_factor: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.geo_decoder_ln_post = geo_decoder_ln_post
|
||||
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
||||
|
||||
self.post_kl = ops.Linear(embed_dim, width)
|
||||
|
||||
self.transformer = Transformer(
|
||||
width=width,
|
||||
layers=num_decoder_layers,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
|
||||
self.geo_decoder = CrossAttentionDecoder(
|
||||
fourier_embedder=self.fourier_embedder,
|
||||
out_channels=1,
|
||||
mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
|
||||
downsample_ratio=geo_decoder_downsample_ratio,
|
||||
enable_ln_post=self.geo_decoder_ln_post,
|
||||
width=width // geo_decoder_downsample_ratio,
|
||||
heads=heads // geo_decoder_downsample_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
label_type=label_type,
|
||||
)
|
||||
|
||||
self.volume_decoder = VanillaVolumeDecoder()
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def decode(self, latents, **kwargs):
|
||||
latents = self.post_kl(latents.movedim(-2, -1))
|
||||
latents = self.transformer(latents)
|
||||
|
||||
bounds = kwargs.get("bounds", 1.01)
|
||||
num_chunks = kwargs.get("num_chunks", 8000)
|
||||
octree_resolution = kwargs.get("octree_resolution", 256)
|
||||
enable_pbar = kwargs.get("enable_pbar", True)
|
||||
|
||||
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
|
||||
return grid_logits.movedim(-2, -1)
|
||||
|
||||
def encode(self, x):
|
||||
return None
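# --- Illustrative usage sketch (not part of the original file) ---
# Tiny end-to-end decode with made-up hyperparameters: latents of shape
# (batch, embed_dim, tokens) are projected by post_kl, refined by the transformer and
# turned into a dense occupancy grid of shape (batch, res+1, res+1, res+1) by the
# vanilla volume decoder.
import torch

vae = ShapeVAE(embed_dim=8, width=64, heads=4, num_decoder_layers=1, num_freqs=2)
latents = torch.randn(1, 8, 32)   # (B, embed_dim, tokens)
grid = vae.decode(latents, octree_resolution=16, num_chunks=2048, enable_pbar=False)
print(grid.shape)                 # torch.Size([1, 17, 17, 17])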
|
@ -244,11 +244,9 @@ class HunyuanVideo(nn.Module):
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
|
||||
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
|
||||
modulation_dims_txt = [(0, None, 1)]
|
||||
else:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
modulation_dims = None
|
||||
modulation_dims_txt = None
|
||||
|
||||
if self.params.guidance_embed:
|
||||
if guidance is not None:
|
||||
@ -275,14 +273,14 @@ class HunyuanVideo(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@ -297,10 +295,10 @@ class HunyuanVideo(nn.Module):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
||||
|
@ -24,13 +24,6 @@ if model_management.sage_attention_enabled():
|
||||
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
||||
exit(-1)
|
||||
|
||||
if model_management.flash_attention_enabled():
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
except ModuleNotFoundError:
|
||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||
exit(-1)
|
||||
|
||||
from comfy.cli_args import args
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@ -471,7 +464,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout = "HND"
|
||||
tensor_layout="HND"
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
@ -479,7 +472,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
lambda t: t.view(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
tensor_layout = "NHD"
|
||||
tensor_layout="NHD"
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
@ -489,17 +482,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
try:
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
except Exception as e:
|
||||
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
|
||||
if tensor_layout == "NHD":
|
||||
q, k, v = map(
|
||||
lambda t: t.transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
|
||||
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
if tensor_layout == "HND":
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
@ -513,63 +496,6 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
return out
|
||||
|
||||
|
||||
try:
|
||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
|
||||
|
||||
|
||||
@flash_attn_wrapper.register_fake
|
||||
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
|
||||
# Output shape is the same as q
|
||||
return q.new_empty(q.shape)
|
||||
except AttributeError as error:
|
||||
FLASH_ATTN_ERROR = error
|
||||
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
||||
|
||||
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
# add a heads dimension if there isn't already one
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
try:
|
||||
assert mask is None
|
||||
out = flash_attn_wrapper(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
v.transpose(1, 2),
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
).transpose(1, 2)
|
||||
except Exception as e:
|
||||
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
optimized_attention = attention_basic
|
||||
|
||||
if model_management.sage_attention_enabled():
|
||||
@ -578,9 +504,6 @@ if model_management.sage_attention_enabled():
|
||||
elif model_management.xformers_enabled():
|
||||
logging.info("Using xformers attention")
|
||||
optimized_attention = attention_xformers
|
||||
elif model_management.flash_attention_enabled():
|
||||
logging.info("Using Flash Attention")
|
||||
optimized_attention = attention_flash
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logging.info("Using pytorch attention")
|
||||
optimized_attention = attention_pytorch
|
||||
@ -847,7 +770,6 @@ class SpatialTransformer(nn.Module):
|
||||
if not isinstance(context, list):
|
||||
context = [context] * len(self.transformer_blocks)
|
||||
b, c, h, w = x.shape
|
||||
transformer_options["activations_shape"] = list(x.shape)
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
@ -963,7 +885,6 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
transformer_options={}
|
||||
) -> torch.Tensor:
|
||||
_, _, h, w = x.shape
|
||||
transformer_options["activations_shape"] = list(x.shape)
|
||||
x_in = x
|
||||
spatial_context = None
|
||||
if exists(context):
|
||||
|
@@ -384,7 +384,6 @@ class WanModel(torch.nn.Module):
        context,
        clip_fea=None,
        freqs=None,
        transformer_options={},
    ):
        r"""
        Forward pass through the diffusion model
@@ -424,18 +423,14 @@ class WanModel(torch.nn.Module):
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = torch.concat([context_clip, context], dim=1)

        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
                    return out
                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(x, e=e0, freqs=freqs, context=context)
        # arguments
        kwargs = dict(
            e=e0,
            freqs=freqs,
            context=context)

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)
@@ -444,7 +439,7 @@ class WanModel(torch.nn.Module):
        x = self.unpatchify(x, grid_sizes)
        return x

    def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
    def forward(self, x, timestep, context, clip_fea=None, **kwargs):
        bs, c, t, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
        patch_size = self.patch_size
@@ -458,7 +453,7 @@ class WanModel(torch.nn.Module):
        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

        freqs = self.rope_embedder(img_ids).movedim(1, 2)
        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]

    def unpatchify(self, x, grid_sizes):
        r"""
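A hedged sketch of the "patches_replace" contract used in the hunk above: a callable registered under ("double_block", i) receives the block inputs as a dict plus an extra_args dict containing "original_block", and must return a dict with the "img" key. The function name and the multiplier are illustrative; the surrounding ModelPatcher wiring that carries transformer_options is assumed.

# hedged sketch of registering a per-block patch for the Wan DiT blocks
def my_double_block_patch(args, extra_args):
    out = extra_args["original_block"](args)  # run the unmodified block first
    out["img"] = out["img"] * 1.0              # then post-process its output
    return out

transformer_options = {"patches_replace": {"dit": {("double_block", 0): my_double_block_patch}}}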
@@ -1,5 +1,4 @@
import torch
import comfy.utils


def convert_lora_bfl_control(sd): #BFL loras for Flux
@@ -12,13 +11,7 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
    return sd_out


def convert_lora_wan_fun(sd): #Wan Fun loras
    return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})


def convert_lora(sd):
    if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
        return convert_lora_bfl_control(sd)
    if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
        return convert_lora_wan_fun(sd)
    return sd
@@ -36,8 +36,6 @@ import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model

import comfy.model_management
import comfy.patcher_extension
@@ -60,7 +58,6 @@ class ModelType(Enum):
    FLOW = 6
    V_PREDICTION_CONTINUOUS = 7
    FLUX = 8
    IMG_TO_IMG = 9


from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
@@ -91,8 +88,6 @@ def model_sampling(model_config, model_type):
    elif model_type == ModelType.FLUX:
        c = comfy.model_sampling.CONST
        s = comfy.model_sampling.ModelSamplingFlux
    elif model_type == ModelType.IMG_TO_IMG:
        c = comfy.model_sampling.IMG_TO_IMG

    class ModelSampling(s, c):
        pass
@@ -144,7 +139,6 @@ class BaseModel(torch.nn.Module):
    def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)

        if c_concat is not None:
            xc = torch.cat([xc] + [c_concat], dim=1)

@@ -606,19 +600,6 @@ class SDXL_instructpix2pix(IP2P, SDXL):
        else:
            self.process_ip2p_image_in = lambda image: image #diffusers ip2p

class Lotus(BaseModel):
    def extra_conds(self, **kwargs):
        out = {}
        cross_attn = kwargs.get("cross_attn", None)
        out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
        device = kwargs["device"]
        task_emb = torch.tensor([1, 0]).float().to(device)
        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)]).unsqueeze(0)
        out['y'] = comfy.conds.CONDRegular(task_emb)
        return out

    def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None):
        super().__init__(model_config, model_type, device=device)

class StableCascade_C(BaseModel):
    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
@@ -992,42 +973,29 @@ class WAN21(BaseModel):
        self.image_to_video = image_to_video

    def concat_cond(self, **kwargs):
        noise = kwargs.get("noise", None)
        extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
        if extra_channels == 0:
        if not self.image_to_video:
            return None

        image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        if image is None:
            shape_image = list(noise.shape)
            shape_image[1] = extra_channels
            image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
        else:
            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
            for i in range(0, image.shape[1], 16):
                image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
            image = utils.resize_to_batch_size(image, noise.shape[0])
            image = torch.zeros_like(noise)

        if not self.image_to_video or extra_channels == image.shape[1]:
            return image

        if image.shape[1] > (extra_channels - 4):
            image = image[:, :(extra_channels - 4)]
        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
        image = self.process_latent_in(image)
        image = utils.resize_to_batch_size(image, noise.shape[0])

        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
        if mask is None:
            mask = torch.zeros_like(noise)[:, :4]
        else:
            if mask.shape[1] != 4:
                mask = torch.mean(mask, dim=1, keepdim=True)
            mask = 1.0 - mask
            mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
            mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
            if mask.shape[-3] < noise.shape[-3]:
                mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
            if mask.shape[1] == 1:
                mask = mask.repeat(1, 4, 1, 1, 1)
            mask = mask.repeat(1, 4, 1, 1, 1)
        mask = utils.resize_to_batch_size(mask, noise.shape[0])

        return torch.cat((mask, image), dim=1)
@@ -1042,35 +1010,3 @@ class WAN21(BaseModel):
        if clip_vision_output is not None:
            out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
        return out

class Hunyuan3Dv2(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

        guidance = kwargs.get("guidance", 5.0)
        if guidance is not None:
            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
        return out

class HiDream(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)

    def encode_adm(self, **kwargs):
        return kwargs["pooled_output"]

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
        conditioning_llama3 = kwargs.get("conditioning_llama3", None)
        if conditioning_llama3 is not None:
            out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
        return out
@@ -154,7 +154,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["guidance_embed"] = len(guidance_keys) > 0
        return dit_config

    if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
    if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
        dit_config = {}
        dit_config["image_model"] = "flux"
        dit_config["in_channels"] = 16
@@ -323,40 +323,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["model_type"] = "t2v"
        return dit_config

    if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
        in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
        dit_config = {}
        dit_config["image_model"] = "hunyuan3d2"
        dit_config["in_channels"] = in_shape[1]
        dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
        dit_config["hidden_size"] = in_shape[0]
        dit_config["mlp_ratio"] = 4.0
        dit_config["num_heads"] = 16
        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
        dit_config["qkv_bias"] = True
        dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
        return dit_config

    if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
        dit_config = {}
        dit_config["image_model"] = "hidream"
        dit_config["attention_head_dim"] = 128
        dit_config["axes_dims_rope"] = [64, 32, 32]
        dit_config["caption_channels"] = [4096, 4096]
        dit_config["max_resolution"] = [128, 128]
        dit_config["in_channels"] = 16
        dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]
        dit_config["num_attention_heads"] = 20
        dit_config["num_routed_experts"] = 4
        dit_config["num_activated_experts"] = 2
        dit_config["num_layers"] = 16
        dit_config["num_single_layers"] = 32
        dit_config["out_channels"] = 16
        dit_config["patch_size"] = 2
        dit_config["text_emb_dim"] = 2048
        return dit_config

    if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
        return None

@@ -701,13 +667,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                   'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                   'use_temporal_attention': False, 'use_temporal_resblock': False}

    LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
              'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
              'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
              'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
              'use_temporal_attention': False, 'use_temporal_resblock': False}

    supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]

    for unet_config in supported_models:
        matches = True
@@ -46,32 +46,6 @@ cpu_state = CPUState.GPU

total_vram = 0

def get_supported_float8_types():
    float8_types = []
    try:
        float8_types.append(torch.float8_e4m3fn)
    except:
        pass
    try:
        float8_types.append(torch.float8_e4m3fnuz)
    except:
        pass
    try:
        float8_types.append(torch.float8_e5m2)
    except:
        pass
    try:
        float8_types.append(torch.float8_e5m2fnuz)
    except:
        pass
    try:
        float8_types.append(torch.float8_e8m0fnu)
    except:
        pass
    return float8_types

FLOAT8_TYPES = get_supported_float8_types()

xpu_available = False
torch_version = ""
try:
@@ -212,21 +186,12 @@ def get_total_memory(dev=None, torch_total_too=False):
    else:
        return mem_total

def mac_version():
    try:
        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
    except:
        return None

total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))

try:
    logging.info("pytorch version: {}".format(torch_version))
    mac_ver = mac_version()
    if mac_ver is not None:
        logging.info("Mac Version {}".format(mac_ver))
except:
    pass

@@ -727,8 +692,11 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
            return torch.float8_e5m2

    fp8_dtype = None
    if weight_dtype in FLOAT8_TYPES:
        fp8_dtype = weight_dtype
    try:
        if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
            fp8_dtype = weight_dtype
    except:
        pass

    if fp8_dtype is not None:
        if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
@@ -823,8 +791,6 @@ def text_encoder_dtype(device=None):
        return torch.float8_e5m2
    elif args.fp16_text_enc:
        return torch.float16
    elif args.bf16_text_enc:
        return torch.bfloat16
    elif args.fp32_text_enc:
        return torch.float32

@@ -955,9 +921,6 @@ def cast_to_device(tensor, device, dtype, copy=False):
def sage_attention_enabled():
    return args.use_sage_attention

def flash_attention_enabled():
    return args.use_flash_attention

def xformers_enabled():
    global directml_enabled
    global cpu_state
@@ -1006,6 +969,12 @@ def pytorch_attention_flash_attention():
            return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
    return False

def mac_version():
    try:
        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
    except:
        return None

def force_upcast_attention_dtype():
    upcast = args.force_upcast_attention

@@ -1237,8 +1206,6 @@ def soft_empty_cache(force=False):
        torch.xpu.empty_cache()
    elif is_ascend_npu():
        torch.npu.empty_cache()
    elif is_mlu():
        torch.mlu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
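A hedged sketch of the probing idea behind get_supported_float8_types() above: each torch.float8_* attribute is looked up defensively so the module still imports on PyTorch builds that lack a given dtype. The hasattr form below is an equivalent illustration, not the code in the diff.

# hedged sketch: collect whichever float8 dtypes this PyTorch build exposes
import torch

candidates = ["float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz", "float8_e8m0fnu"]
available = [getattr(torch, name) for name in candidates if hasattr(torch, name)]
print(available)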
@@ -747,7 +747,6 @@ class ModelPatcher:

    def partially_unload(self, device_to, memory_to_free=0):
        with self.use_ejected():
            hooks_unpatched = False
            memory_freed = 0
            patch_counter = 0
            unload_list = self._load_list()
@@ -771,10 +770,6 @@ class ModelPatcher:
                        move_weight = False
                        break

                if not hooks_unpatched:
                    self.unpatch_hooks()
                    hooks_unpatched = True

                if bk.inplace_update:
                    comfy.utils.copy_to_param(self.model, key, bk.weight)
                else:
@@ -1094,6 +1089,7 @@ class ModelPatcher:

    def patch_hooks(self, hooks: comfy.hooks.HookGroup):
        with self.use_ejected():
            self.unpatch_hooks()
            if hooks is not None:
                model_sd_keys = list(self.model_state_dict().keys())
                memory_counter = None
@@ -1104,16 +1100,12 @@ class ModelPatcher:
                # if have cached weights for hooks, use it
                cached_weights = self.cached_hook_patches.get(hooks, None)
                if cached_weights is not None:
                    model_sd_keys_set = set(model_sd_keys)
                    for key in cached_weights:
                        if key not in model_sd_keys:
                            logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
                            continue
                        self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
                        model_sd_keys_set.remove(key)
                    self.unpatch_hooks(model_sd_keys_set)
                else:
                    self.unpatch_hooks()
                    relevant_patches = self.get_combined_hook_patches(hooks=hooks)
                    original_weights = None
                    if len(relevant_patches) > 0:
@@ -1124,8 +1116,6 @@ class ModelPatcher:
                            continue
                        self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
                                                         memory_counter=memory_counter)
            else:
                self.unpatch_hooks()
            self.current_hooks = hooks

    def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
@@ -1182,23 +1172,17 @@ class ModelPatcher:
        del out_weight
        del weight

    def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
    def unpatch_hooks(self) -> None:
        with self.use_ejected():
            if len(self.hook_backup) == 0:
                self.current_hooks = None
                return
            keys = list(self.hook_backup.keys())
            if whitelist_keys_set:
                for k in keys:
                    if k in whitelist_keys_set:
                        comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
                        self.hook_backup.pop(k)
            else:
                for k in keys:
                    comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
            for k in keys:
                comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))

                self.hook_backup.clear()
                self.current_hooks = None
            self.hook_backup.clear()
            self.current_hooks = None

    def clean_hooks(self):
        self.unpatch_hooks()
@@ -69,15 +69,6 @@ class CONST:
        sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
        return latent / (1.0 - sigma)

class X0(EPS):
    def calculate_denoised(self, sigma, model_output, model_input):
        return model_output

class IMG_TO_IMG(X0):
    def calculate_input(self, sigma, noise):
        return noise


class ModelSamplingDiscrete(torch.nn.Module):
    def __init__(self, model_config=None, zsnr=None):
        super().__init__()
comfy/ops.py
@@ -21,7 +21,6 @@ import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm

cast_to = comfy.model_management.cast_to #TODO: remove once no more references

@@ -147,25 +146,6 @@ class disable_weight_init:
            else:
                return super().forward(*args, **kwargs)

    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
        def reset_parameters(self):
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input):
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
            return comfy.rmsnorm.rms_norm(input, weight, self.eps)  # TODO: switch to commented out line when old torch is deprecated
            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None
@@ -263,9 +243,6 @@ class manual_cast(disable_weight_init):
    class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
        comfy_cast_weights = True

    class RMSNorm(disable_weight_init.RMSNorm):
        comfy_cast_weights = True

    class Embedding(disable_weight_init.Embedding):
        comfy_cast_weights = True

@@ -380,25 +357,6 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None

    return scaled_fp8_op

CUBLAS_IS_AVAILABLE = False
try:
    from cublas_ops import CublasLinear
    CUBLAS_IS_AVAILABLE = True
except ImportError:
    pass

if CUBLAS_IS_AVAILABLE:
    class cublas_ops(disable_weight_init):
        class Linear(CublasLinear, disable_weight_init.Linear):
            def reset_parameters(self):
                return None

            def forward_comfy_cast_weights(self, input):
                return super().forward(input)

            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)

def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
    if scaled_fp8 is not None:
@@ -411,15 +369,6 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
    ):
        return fp8_ops

    if (
        PerformanceFeature.CublasOps in args.fast and
        CUBLAS_IS_AVAILABLE and
        weight_dtype == torch.float16 and
        (compute_dtype == torch.float16 or compute_dtype is None)
    ):
        logging.info("Using cublas ops")
        return cublas_ops

    if compute_dtype is None or weight_dtype == compute_dtype:
        return disable_weight_init
@@ -48,7 +48,6 @@ def get_all_callbacks(call_type: str, transformer_options: dict, is_model_option

class WrappersMP:
    OUTER_SAMPLE = "outer_sample"
    PREPARE_SAMPLING = "prepare_sampling"
    SAMPLER_SAMPLE = "sampler_sample"
    CALC_COND_BATCH = "calc_cond_batch"
    APPLY_MODEL = "apply_model"
@@ -1,54 +0,0 @@
import torch
import comfy.model_management
import numbers

RMSNorm = None

try:
    rms_norm_torch = torch.nn.functional.rms_norm
    RMSNorm = torch.nn.RMSNorm
except:
    rms_norm_torch = None


def rms_norm(x, weight=None, eps=1e-6):
    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
        if weight is None:
            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
        else:
            return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
    else:
        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
        if weight is None:
            return r
        else:
            return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)


if RMSNorm is None:
    class RMSNorm(torch.nn.Module):
        def __init__(
            self,
            normalized_shape,
            eps=None,
            elementwise_affine=True,
            device=None,
            dtype=None,
        ):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            if isinstance(normalized_shape, numbers.Integral):
                # mypy error: incompatible types in assignment
                normalized_shape = (normalized_shape,)  # type: ignore[assignment]
            self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
            self.eps = eps
            self.elementwise_affine = elementwise_affine
            if self.elementwise_affine:
                self.weight = torch.nn.Parameter(
                    torch.empty(self.normalized_shape, **factory_kwargs)
                )
            else:
                self.register_parameter("weight", None)

        def forward(self, x):
            return rms_norm(x, self.weight, self.eps)
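A hedged check of the fallback math in rms_norm() above: for float32 inputs, x * rsqrt(mean(x**2, dim=-1) + eps) scaled by the weight should agree with torch.nn.functional.rms_norm on PyTorch builds that provide it. The shapes and tolerance below are illustrative.

# hedged sketch: manual RMSNorm vs the built-in, where the built-in exists
import torch

x = torch.randn(2, 8, dtype=torch.float32)
w = torch.randn(8, dtype=torch.float32)
eps = 1e-6
manual = x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * w
if hasattr(torch.nn.functional, "rms_norm"):
    builtin = torch.nn.functional.rms_norm(x, (8,), weight=w, eps=eps)
    assert torch.allclose(manual, builtin, atol=1e-5)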
@@ -106,13 +106,6 @@ def cleanup_additional_models(models):


def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
    executor = comfy.patcher_extension.WrapperExecutor.new_executor(
        _prepare_sampling,
        comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
    )
    return executor.execute(model, noise_shape, conds, model_options=model_options)

def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
    real_model: BaseModel = None
    models, inference_memory = get_additional_models(conds, model.model_dtype())
    models += get_additional_models_from_model_options(model_options)
@@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                  "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                  "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
                  "gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
                  "gradient_estimation"]

class KSAMPLER(Sampler):
    def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
comfy/sd.py
@@ -14,7 +14,6 @@ import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.hunyuan3d.vae
import yaml
import math

@@ -41,7 +40,6 @@ import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream

import comfy.model_patcher
import comfy.lora
@@ -266,7 +264,6 @@ class VAE:
        self.process_input = lambda image: image * 2.0 - 1.0
        self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
        self.working_dtypes = [torch.bfloat16, torch.float32]
        self.disable_offload = False

        self.downscale_index_formula = None
        self.upscale_index_formula = None
@@ -339,7 +336,6 @@ class VAE:
            self.process_output = lambda audio: audio
            self.process_input = lambda audio: audio
            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
            self.disable_offload = True
        elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
            if "blocks.2.blocks.3.stack.5.weight" in sd:
                sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
@@ -416,17 +412,6 @@ class VAE:
            self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
            self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
            self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
        elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
            self.latent_dim = 1
            ln_post = "geo_decoder.ln_post.weight" in sd
            inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
            downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
            mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
            self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
            self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
            ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
            self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        else:
            logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
            self.first_stage_model = None
@@ -455,10 +440,6 @@ class VAE:
        self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
        logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

    def throw_exception_if_invalid(self):
        if self.first_stage_model is None:
            raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")

    def vae_encode_crop_pixels(self, pixels):
        downscale_ratio = self.spacial_compression_encode()

@@ -513,19 +494,18 @@ class VAE:
        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)

    def decode(self, samples_in, vae_options={}):
        self.throw_exception_if_invalid()
    def decode(self, samples_in):
        pixel_samples = None
        try:
            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = int(free_memory / memory_used)
            batch_number = max(1, batch_number)

            for x in range(0, samples_in.shape[0], batch_number):
                samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
                out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
                if pixel_samples is None:
                    pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                pixel_samples[x:x+batch_number] = out
@@ -545,9 +525,8 @@ class VAE:
        return pixel_samples

    def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
        self.throw_exception_if_invalid()
        memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
        model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
        dims = samples.ndim - 2
        args = {}
        if tile_x is not None:
@@ -574,14 +553,13 @@ class VAE:
        return output.movedim(1, -1)

    def encode(self, pixel_samples):
        self.throw_exception_if_invalid()
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
        pixel_samples = pixel_samples.movedim(-1, 1)
        if self.latent_dim == 3 and pixel_samples.ndim < 5:
            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
        try:
            memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = int(free_memory / max(1, memory_used))
            batch_number = max(1, batch_number)
@@ -607,7 +585,6 @@ class VAE:
        return samples

    def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
        self.throw_exception_if_invalid()
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
        dims = self.latent_dim
        pixel_samples = pixel_samples.movedim(-1, 1)
@@ -615,7 +592,7 @@ class VAE:
            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)

        memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
        model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
        model_management.load_models_gpu([self.patcher], memory_required=memory_used)

        args = {}
        if tile_x is not None:
@@ -854,9 +831,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
        elif len(clip_data) == 3:
            clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
        elif len(clip_data) == 4:
            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), **llama_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer

    parameters = 0
    for c in clip_data:
@@ -925,12 +899,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c

    model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
    if model_config is None:
        logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
        diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
        if diffusion_model is None:
            return None
        return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'

        return None

    unet_weight_dtype = list(model_config.supported_inference_dtypes)
    if model_config.scaled_fp8 is not None:
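A hedged sketch of the batching rule in VAE.decode() above: the number of latents decoded per pass is the free device memory divided by the estimated cost of one decode, clamped to at least 1. The numbers below are purely illustrative.

# hedged sketch: how batch_number is derived from free memory
free_memory = 8 * 1024 ** 3           # e.g. 8 GiB reported free
memory_used = 1.5 * 1024 ** 3         # estimated cost of decoding one batch item
batch_number = max(1, int(free_memory / memory_used))  # -> 5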
@@ -82,8 +82,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    LAYERS = [
        "last",
        "pooled",
        "hidden",
        "all"
        "hidden"
    ]
    def __init__(self, device="cpu", max_length=77,
                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
@@ -94,8 +93,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):

        if textmodel_json_config is None:
            textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
        if "model_name" not in model_options:
            model_options = {**model_options, "model_name": "clip_l"}

        if isinstance(textmodel_json_config, dict):
            config = textmodel_json_config
@@ -103,10 +100,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
            with open(textmodel_json_config) as f:
                config = json.load(f)

        te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {})
        for k, v in te_model_options.items():
            config[k] = v

        operations = model_options.get("custom_operations", None)
        scaled_fp8 = None

@@ -154,9 +147,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    def set_clip_options(self, options):
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
        if self.layer == "all":
            pass
        elif layer_idx is None or abs(layer_idx) > self.num_layers:
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
@@ -253,12 +244,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
        if self.enable_attention_masks:
            attention_mask_model = attention_mask

        if self.layer == "all":
            intermediate_output = "all"
        else:
            intermediate_output = self.layer_idx

        outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
        outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)

        if self.layer == "last":
            z = outputs[0].float()
@@ -461,7 +447,7 @@ class SDTokenizer:
        if tokenizer_path is None:
            tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
        self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
        self.max_length = max_length
        self.min_length = min_length
        self.end_token = None

@@ -659,7 +645,6 @@ class SD1ClipModel(torch.nn.Module):
        self.clip = "clip_{}".format(self.clip_name)

        clip_model = model_options.get("{}_class".format(self.clip), clip_model)
        model_options = {**model_options, "model_name": self.clip}
        setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))

        self.dtypes = set()
@@ -9,7 +9,6 @@ class SDXLClipG(sd1_clip.SDClipModel):
            layer_idx=-2

        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
        model_options = {**model_options, "model_name": "clip_g"}
        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
                         special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)

@@ -18,13 +17,14 @@ class SDXLClipG(sd1_clip.SDClipModel):

class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)
        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')


class SDXLTokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
        self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
@@ -41,7 +41,8 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__()
        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
        self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
        self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
        self.dtypes = set([dtype])

@@ -74,7 +75,7 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):

class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
        super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)
        super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')

class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -83,7 +84,6 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
class StableCascadeClipG(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
        model_options = {**model_options, "model_name": "clip_g"}
        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
                         special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)

@@ -506,22 +506,6 @@ class SDXL_instructpix2pix(SDXL):
    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)

class LotusD(SD20):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "use_temporal_attention": False,
        "adm_in_channels": 4,
        "in_channels": 4,
    }

    unet_extra_config = {
        "num_classes": 'sequential'
    }

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.Lotus(self, device=device)

class SD3(supported_models_base.BASE):
    unet_config = {
        "in_channels": 16,
@@ -969,92 +953,12 @@ class WAN21_I2V(WAN21_T2V):
    unet_config = {
        "image_model": "wan2.1",
        "model_type": "i2v",
        "in_dim": 36,
    }

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.WAN21(self, image_to_video=True, device=device)
        return out

class WAN21_FunControl2V(WAN21_T2V):
    unet_config = {
        "image_model": "wan2.1",
        "model_type": "i2v",
        "in_dim": 48,
    }

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.WAN21(self, image_to_video=False, device=device)
        return out

class Hunyuan3Dv2(supported_models_base.BASE):
    unet_config = {
        "image_model": "hunyuan3d2",
    }

    unet_extra_config = {}

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 1.0,
    }

    memory_usage_factor = 3.5

    clip_vision_prefix = "conditioner.main_image_encoder.model."
    vae_key_prefix = ["vae."]

    latent_format = latent_formats.Hunyuan3Dv2

    def process_unet_state_dict_for_saving(self, state_dict):
        replace_prefix = {"": "model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.Hunyuan3Dv2(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        return None

class Hunyuan3Dv2mini(Hunyuan3Dv2):
    unet_config = {
        "image_model": "hunyuan3d2",
        "depth": 8,
    }

    latent_format = latent_formats.Hunyuan3Dv2mini

class HiDream(supported_models_base.BASE):
    unet_config = {
        "image_model": "hidream",
    }

    sampling_settings = {
        "shift": 3.0,
    }

    sampling_settings = {
    }

    # memory_usage_factor = 1.2 # TODO

    unet_extra_config = {}
    latent_format = latent_formats.Flux

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.HiDream(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        return None # TODO


models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]

models += [SVD_img2vid]
@@ -11,7 +11,7 @@ class PT5XlModel(sd1_clip.SDClipModel):
class PT5XlTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data)
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)

class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -22,7 +22,7 @@ class CosmosT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, tokenizer_data=tokenizer_data)
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)


class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
@@ -9,13 +9,14 @@ import os
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)


class FluxTokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
@@ -34,7 +35,8 @@ class FluxClipModel(torch.nn.Module):
    def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
        super().__init__()
        dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
        self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
        self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
        self.dtypes = set([dtype, dtype_t5])

@@ -18,7 +18,7 @@ class MochiT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)


class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
@@ -1,150 +0,0 @@
from . import hunyuan_video
from . import sd3_clip
from comfy import sd1_clip
from comfy import sdxl_clip
import comfy.model_management
import torch
import logging


class HiDreamTokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, tokenizer_data=tokenizer_data)
        self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
        return out

    def untokenize(self, token_weight_pair):
        return self.clip_g.untokenize(token_weight_pair)

    def state_dict(self):
        return {}


class HiDreamTEModel(torch.nn.Module):
    def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}):
        super().__init__()
        self.dtypes = set()
        if clip_l:
            self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options)
            self.dtypes.add(dtype)
        else:
            self.clip_l = None

        if clip_g:
            self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
            self.dtypes.add(dtype)
        else:
            self.clip_g = None

        if t5:
            dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
            self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True)
            self.dtypes.add(dtype_t5)
        else:
            self.t5xxl = None

        if llama:
            dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
            if "vocab_size" not in model_options:
                model_options["vocab_size"] = 128256
            self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009})
            self.dtypes.add(dtype_llama)
        else:
            self.llama = None

        logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))

    def set_clip_options(self, options):
        if self.clip_l is not None:
            self.clip_l.set_clip_options(options)
        if self.clip_g is not None:
            self.clip_g.set_clip_options(options)
        if self.t5xxl is not None:
            self.t5xxl.set_clip_options(options)
        if self.llama is not None:
            self.llama.set_clip_options(options)

    def reset_clip_options(self):
        if self.clip_l is not None:
            self.clip_l.reset_clip_options()
        if self.clip_g is not None:
            self.clip_g.reset_clip_options()
        if self.t5xxl is not None:
            self.t5xxl.reset_clip_options()
        if self.llama is not None:
            self.llama.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_l = token_weight_pairs["l"]
        token_weight_pairs_g = token_weight_pairs["g"]
        token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
        token_weight_pairs_llama = token_weight_pairs["llama"]
        lg_out = None
        pooled = None
        extra = {}

        if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
            if self.clip_l is not None:
                lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
            else:
                l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())

            if self.clip_g is not None:
                g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
            else:
                g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())

            pooled = torch.cat((l_pooled, g_pooled), dim=-1)

        if self.t5xxl is not None:
            t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
            t5_out, t5_pooled = t5_output[:2]

        if self.llama is not None:
            ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
            ll_out, ll_pooled = ll_output[:2]
            ll_out = ll_out[:, 1:]

        if t5_out is None:
            t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device())

        if ll_out is None:
            ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())

        if pooled is None:
            pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())

        extra["conditioning_llama3"] = ll_out
        return t5_out, pooled, extra

    def load_sd(self, sd):
        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
            return self.clip_g.load_sd(sd)
        elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            return self.clip_l.load_sd(sd)
        elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
            return self.t5xxl.load_sd(sd)
        else:
            return self.llama.load_sd(sd)


def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
    class HiDreamTEModel_(HiDreamTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
            if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["llama_scaled_fp8"] = llama_scaled_fp8
            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
    return HiDreamTEModel_
@ -21,31 +21,26 @@ def llama_detect(state_dict, prefix=""):
|
||||
|
||||
|
||||
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=pad_token, min_length=min_length, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length)
|
||||
|
||||
class LLAMAModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
|
||||
textmodel_json_config = {}
|
||||
vocab_size = model_options.get("vocab_size", None)
|
||||
if vocab_size is not None:
|
||||
textmodel_json_config["vocab_size"] = vocab_size
|
||||
|
||||
model_options = {**model_options, "model_name": "llama"}
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class HunyuanVideoTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
|
||||
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data)
|
||||
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
|
||||
out = {}
|
||||
@ -77,7 +72,8 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
|
||||
self.dtypes = set([dtype, dtype_llama])
|
||||
|
||||
|
@ -9,26 +9,24 @@ import torch
|
||||
class HyditBertModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
|
||||
model_options = {**model_options, "model_name": "hydit_clip"}
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||
|
||||
class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77)
|
||||
|
||||
|
||||
class MT5XLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
|
||||
model_options = {**model_options, "model_name": "mt5xl"}
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||
|
||||
class MT5XLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
#tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
@ -37,7 +35,7 @@ class HyditTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
|
||||
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
|
||||
self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
|
||||
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
|
@ -268,17 +268,11 @@ class Llama2_(nn.Module):
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
intermediate = None
|
||||
all_intermediate = None
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output == "all":
|
||||
all_intermediate = []
|
||||
intermediate_output = None
|
||||
elif intermediate_output < 0:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
x = layer(
|
||||
x=x,
|
||||
attention_mask=mask,
|
||||
@ -289,12 +283,6 @@ class Llama2_(nn.Module):
|
||||
intermediate = x.clone()
|
||||
|
||||
x = self.norm(x)
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
|
||||
if all_intermediate is not None:
|
||||
intermediate = torch.cat(all_intermediate, dim=1)
|
||||
|
||||
if intermediate is not None and final_layer_norm_intermediate:
|
||||
intermediate = self.norm(intermediate)
|
||||
|
||||
|
@ -1,27 +1,30 @@
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
|
||||
class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
class LongClipModel_(sd1_clip.SDClipModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
||||
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
|
||||
|
||||
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
|
||||
|
||||
class LongClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
||||
|
||||
def model_options_long_clip(sd, tokenizer_data, model_options):
|
||||
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
|
||||
if w is None:
|
||||
w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None)
|
||||
else:
|
||||
model_name = "clip_g"
|
||||
|
||||
if w is None:
|
||||
w = sd.get("text_model.embeddings.position_embedding.weight", None)
|
||||
if w is not None:
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
model_name = "clip_g"
|
||||
elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||
model_name = "clip_l"
|
||||
else:
|
||||
model_name = "clip_l"
|
||||
|
||||
if w is not None:
|
||||
if w is not None and w.shape[0] == 248:
|
||||
tokenizer_data = tokenizer_data.copy()
|
||||
model_options = model_options.copy()
|
||||
model_config = model_options.get("model_config", {})
|
||||
model_config["max_position_embeddings"] = w.shape[0]
|
||||
model_options["{}_model_config".format(model_name)] = model_config
|
||||
tokenizer_data["{}_max_length".format(model_name)] = w.shape[0]
|
||||
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
|
||||
model_options["clip_l_class"] = LongClipModel_
|
||||
return tokenizer_data, model_options
|
||||
|
@ -6,7 +6,7 @@ import comfy.text_encoders.genmo
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128?
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
|
||||
|
||||
|
||||
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
@ -6,7 +6,7 @@ import comfy.text_encoders.llama
|
||||
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False})
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
@ -24,7 +24,7 @@ class PixArtT5XXL(sd1_clip.SD1ClipModel):
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding
|
||||
|
||||
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
|
@ -11,7 +11,7 @@ class T5BaseModel(sd1_clip.SDClipModel):
|
||||
class T5BaseTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
|
||||
|
||||
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
|
@ -12,7 +12,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
|
||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='clip_h', tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
|
||||
|
||||
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
|
@ -15,7 +15,6 @@ class T5XXLModel(sd1_clip.SDClipModel):
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
|
||||
model_options = {**model_options, "model_name": "t5xxl"}
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
@ -32,16 +31,17 @@ def t5_xxl_detect(state_dict, prefix=""):
|
||||
return out
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=min_length, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
@ -61,7 +61,8 @@ class SD3ClipModel(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.dtypes = set()
|
||||
if clip_l:
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||
self.dtypes.add(dtype)
|
||||
else:
|
||||
self.clip_l = None
|
||||
|
@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel):
|
||||
class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
@ -316,156 +316,3 @@ class LRUCache(BasicCache):
|
||||
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||
return self
|
||||
|
||||
|
||||
class DependencyAwareCache(BasicCache):
|
||||
"""
|
||||
A cache implementation that tracks dependencies between nodes and manages
|
||||
their execution and caching accordingly. It extends the BasicCache class.
|
||||
Nodes are removed from this cache once all of their descendants have been
|
||||
executed.
|
||||
"""
|
||||
|
||||
def __init__(self, key_class):
|
||||
"""
|
||||
Initialize the DependencyAwareCache.
|
||||
|
||||
Args:
|
||||
key_class: The class used for generating cache keys.
|
||||
"""
|
||||
super().__init__(key_class)
|
||||
self.descendants = {} # Maps node_id -> set of descendant node_ids
|
||||
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
||||
self.executed_nodes = set() # Tracks nodes that have been executed
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""
|
||||
Clear the entire cache and rebuild the dependency graph.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to initialize the cache for.
|
||||
is_changed_cache: Cache of per-node is_changed results consulted while rebuilding the prompt.
|
||||
"""
|
||||
# Clear all existing cache data
|
||||
self.cache.clear()
|
||||
self.subcaches.clear()
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
self.executed_nodes.clear()
|
||||
|
||||
# Call the parent method to initialize the cache with the new prompt
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
# Rebuild the dependency graph
|
||||
self._build_dependency_graph(dynprompt, node_ids)
|
||||
|
||||
def _build_dependency_graph(self, dynprompt, node_ids):
|
||||
"""
|
||||
Build the dependency graph for all nodes.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to build the graph for.
|
||||
"""
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
for node_id in node_ids:
|
||||
self.descendants[node_id] = set()
|
||||
self.ancestors[node_id] = set()
|
||||
|
||||
for node_id in node_ids:
|
||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||
for input_data in inputs.values():
|
||||
if is_link(input_data): # Check if the input is a link to another node
|
||||
ancestor_id = input_data[0]
|
||||
self.descendants[ancestor_id].add(node_id)
|
||||
self.ancestors[node_id].add(ancestor_id)
|
||||
|
||||
def set(self, node_id, value):
|
||||
"""
|
||||
Mark a node as executed and store its value in the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to store.
|
||||
value: The value to store for the node.
|
||||
"""
|
||||
self._set_immediate(node_id, value)
|
||||
self.executed_nodes.add(node_id)
|
||||
self._cleanup_ancestors(node_id)
|
||||
|
||||
def get(self, node_id):
|
||||
"""
|
||||
Retrieve the cached value for a node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to retrieve.
|
||||
|
||||
Returns:
|
||||
The cached value for the node.
|
||||
"""
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
"""
|
||||
Ensure a subcache exists for a node and update dependencies.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the parent node.
|
||||
children_ids: List of child node IDs to associate with the parent node.
|
||||
|
||||
Returns:
|
||||
The subcache object for the node.
|
||||
"""
|
||||
subcache = super()._ensure_subcache(node_id, children_ids)
|
||||
for child_id in children_ids:
|
||||
self.descendants[node_id].add(child_id)
|
||||
self.ancestors[child_id].add(node_id)
|
||||
return subcache
|
||||
|
||||
def _cleanup_ancestors(self, node_id):
|
||||
"""
|
||||
Check if ancestors of a node can be removed from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node whose ancestors are to be checked.
|
||||
"""
|
||||
for ancestor_id in self.ancestors.get(node_id, []):
|
||||
if ancestor_id in self.executed_nodes:
|
||||
# Remove ancestor if all its descendants have been executed
|
||||
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
|
||||
self._remove_node(ancestor_id)
|
||||
|
||||
def _remove_node(self, node_id):
|
||||
"""
|
||||
Remove a node from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to remove.
|
||||
"""
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
del self.cache[cache_key]
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
if subcache_key in self.subcaches:
|
||||
del self.subcaches[subcache_key]
|
||||
|
||||
def clean_unused(self):
|
||||
"""
|
||||
Clean up unused nodes. This is a no-op for this cache implementation.
|
||||
"""
|
||||
pass
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
"""
|
||||
Dump the cache and dependency graph for debugging.
|
||||
|
||||
Returns:
|
||||
A list containing the cache state and dependency graph.
|
||||
"""
|
||||
result = super().recursive_debug_dump()
|
||||
result.append({
|
||||
"descendants": self.descendants,
|
||||
"ancestors": self.ancestors,
|
||||
"executed_nodes": list(self.executed_nodes),
|
||||
})
|
||||
return result
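# Standalone sketch (hypothetical names, not part of the original file) of the
# eviction rule DependencyAwareCache implements: a producer's cached value may
# be dropped once every node that consumes it has executed.
def _can_evict(ancestor_id, descendants, executed_nodes):
    return all(d in executed_nodes for d in descendants.get(ancestor_id, set()))

# Node "1" feeds nodes "2" and "3"; it stays cached until both have run.
_descendants = {"1": {"2", "3"}, "2": set(), "3": set()}
assert not _can_evict("1", _descendants, executed_nodes={"2"})
assert _can_evict("1", _descendants, executed_nodes={"2", "3"})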
|
||||
|
@ -1,45 +0,0 @@
|
||||
import torch
|
||||
|
||||
# https://github.com/WeichenFan/CFG-Zero-star
|
||||
def optimized_scale(positive, negative):
|
||||
positive_flat = positive.reshape(positive.shape[0], -1)
|
||||
negative_flat = negative.reshape(negative.shape[0], -1)
|
||||
|
||||
# Calculate the dot product
|
||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||
|
||||
# Squared norm of the unconditional prediction
|
||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||
|
||||
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
||||
st_star = dot_product / squared_norm
|
||||
|
||||
return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
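# Minimal sanity-check sketch (added for illustration, not part of the original
# file; assumes the torch import at the top of this module): when the conditional
# prediction is exactly twice the unconditional one, the projection coefficient
# v_cond^T v_uncond / ||v_uncond||^2 collapses to 2 for every batch element.
if __name__ == "__main__":
    _neg = torch.randn(2, 4, 8, 8)
    _pos = 2.0 * _neg
    _scale = optimized_scale(_pos, _neg)
    assert _scale.shape == (2, 1, 1, 1)
    assert torch.allclose(_scale, torch.full_like(_scale, 2.0), atol=1e-4)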
|
||||
|
||||
class CFGZeroStar:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL",),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("patched_model",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "advanced/guidance"
|
||||
|
||||
def patch(self, model):
|
||||
m = model.clone()
|
||||
def cfg_zero_star(args):
|
||||
guidance_scale = args['cond_scale']
|
||||
x = args['input']
|
||||
cond_p = args['cond_denoised']
|
||||
uncond_p = args['uncond_denoised']
|
||||
out = args["denoised"]
|
||||
alpha = optimized_scale(x - cond_p, x - uncond_p)
|
||||
|
||||
return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
|
||||
m.set_model_sampler_post_cfg_function(cfg_zero_star)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CFGZeroStar": CFGZeroStar
|
||||
}
|
@ -1,32 +0,0 @@
|
||||
import folder_paths
|
||||
import comfy.sd
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class QuadrupleCLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name3": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name4": (folder_paths.get_filename_list("text_encoders"), )
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
|
||||
clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return (clip,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
|
||||
}
|
@ -1,634 +0,0 @@
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
import struct
|
||||
import numpy as np
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch
|
||||
import folder_paths
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args
|
||||
|
||||
|
||||
class EmptyLatentHunyuan3Dv2:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/3d"
|
||||
|
||||
def generate(self, resolution, batch_size):
|
||||
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples": latent, "type": "hunyuan3dv2"}, )
|
||||
|
||||
|
||||
class Hunyuan3Dv2Conditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, clip_vision_output):
|
||||
embeds = clip_vision_output.last_hidden_state
|
||||
positive = [[embeds, {}]]
|
||||
negative = [[torch.zeros_like(embeds), {}]]
|
||||
return (positive, negative)
|
||||
|
||||
|
||||
class Hunyuan3Dv2ConditioningMultiView:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {},
|
||||
"optional": {"front": ("CLIP_VISION_OUTPUT",),
|
||||
"left": ("CLIP_VISION_OUTPUT",),
|
||||
"back": ("CLIP_VISION_OUTPUT",),
|
||||
"right": ("CLIP_VISION_OUTPUT",), }}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, front=None, left=None, back=None, right=None):
|
||||
all_embeds = [front, left, back, right]
|
||||
out = []
|
||||
pos_embeds = None
|
||||
for i, e in enumerate(all_embeds):
|
||||
if e is not None:
|
||||
if pos_embeds is None:
|
||||
pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4))
|
||||
out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1))
|
||||
|
||||
embeds = torch.cat(out, dim=1)
|
||||
positive = [[embeds, {}]]
|
||||
negative = [[torch.zeros_like(embeds), {}]]
|
||||
return (positive, negative)
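# Note added for clarity (not in the original file): each supplied view keeps its
# slot in the fixed order front/left/back/right; slot i receives the i-th 1D
# sincos positional embedding before the views are concatenated along the
# sequence dimension, so the model can distinguish the camera angles even when
# some views are missing.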
|
||||
|
||||
|
||||
class VOXEL:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
|
||||
class VAEDecodeHunyuan3D:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"samples": ("LATENT", ),
|
||||
"vae": ("VAE", ),
|
||||
"num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}),
|
||||
"octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}),
|
||||
}}
|
||||
RETURN_TYPES = ("VOXEL",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "latent/3d"
|
||||
|
||||
def decode(self, vae, samples, num_chunks, octree_resolution):
|
||||
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
|
||||
return (voxels, )
|
||||
|
||||
|
||||
def voxel_to_mesh(voxels, threshold=0.5, device=None):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
voxels = voxels.to(device)
|
||||
|
||||
binary = (voxels > threshold).float()
|
||||
padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0)
|
||||
|
||||
D, H, W = binary.shape
|
||||
|
||||
neighbors = torch.tensor([
|
||||
[0, 0, 1],
|
||||
[0, 0, -1],
|
||||
[0, 1, 0],
|
||||
[0, -1, 0],
|
||||
[1, 0, 0],
|
||||
[-1, 0, 0]
|
||||
], device=device)
|
||||
|
||||
z, y, x = torch.meshgrid(
|
||||
torch.arange(D, device=device),
|
||||
torch.arange(H, device=device),
|
||||
torch.arange(W, device=device),
|
||||
indexing='ij'
|
||||
)
|
||||
voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
|
||||
|
||||
solid_mask = binary.flatten() > 0
|
||||
solid_indices = voxel_indices[solid_mask]
|
||||
|
||||
corner_offsets = [
|
||||
torch.tensor([
|
||||
[0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0]
|
||||
], device=device)
|
||||
]
|
||||
|
||||
all_vertices = []
|
||||
all_indices = []
|
||||
|
||||
vertex_count = 0
|
||||
|
||||
for face_idx, offset in enumerate(neighbors):
|
||||
neighbor_indices = solid_indices + offset
|
||||
|
||||
padded_indices = neighbor_indices + 1
|
||||
|
||||
is_exposed = padded[
|
||||
padded_indices[:, 0],
|
||||
padded_indices[:, 1],
|
||||
padded_indices[:, 2]
|
||||
] == 0
|
||||
|
||||
if not is_exposed.any():
|
||||
continue
|
||||
|
||||
exposed_indices = solid_indices[is_exposed]
|
||||
|
||||
corners = corner_offsets[face_idx].unsqueeze(0)
|
||||
|
||||
face_vertices = exposed_indices.unsqueeze(1) + corners
|
||||
|
||||
all_vertices.append(face_vertices.reshape(-1, 3))
|
||||
|
||||
num_faces = exposed_indices.shape[0]
|
||||
face_indices = torch.arange(
|
||||
vertex_count,
|
||||
vertex_count + 4 * num_faces,
|
||||
device=device
|
||||
).reshape(-1, 4)
|
||||
|
||||
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1))
|
||||
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1))
|
||||
|
||||
vertex_count += 4 * num_faces
|
||||
|
||||
if len(all_vertices) > 0:
|
||||
vertices = torch.cat(all_vertices, dim=0)
|
||||
faces = torch.cat(all_indices, dim=0)
|
||||
else:
|
||||
vertices = torch.zeros((1, 3))
|
||||
faces = torch.zeros((1, 3))
|
||||
|
||||
v_min = 0
|
||||
v_max = max(voxels.shape)
|
||||
|
||||
vertices = vertices - (v_min + v_max) / 2
|
||||
|
||||
scale = (v_max - v_min) / 2
|
||||
if scale > 0:
|
||||
vertices = vertices / scale
|
||||
|
||||
vertices = torch.fliplr(vertices)
|
||||
return vertices, faces
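# Hedged usage sketch (added for illustration, not part of the original file;
# assumes the torch import at the top of this module): a single solid voxel
# should surface as a unit cube, i.e. 6 exposed faces * 4 corners = 24 vertices
# and 6 * 2 = 12 triangles.
if __name__ == "__main__":
    _voxels = torch.zeros(3, 3, 3)
    _voxels[1, 1, 1] = 1.0
    _verts, _tris = voxel_to_mesh(_voxels, threshold=0.5)
    assert _verts.shape == (24, 3) and _tris.shape == (12, 3)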
|
||||
|
||||
def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
voxels = voxels.to(device)
|
||||
|
||||
D, H, W = voxels.shape
|
||||
|
||||
padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0)
|
||||
z, y, x = torch.meshgrid(
|
||||
torch.arange(D, device=device),
|
||||
torch.arange(H, device=device),
|
||||
torch.arange(W, device=device),
|
||||
indexing='ij'
|
||||
)
|
||||
cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
|
||||
|
||||
corner_offsets = torch.tensor([
|
||||
[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
|
||||
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
|
||||
], device=device)
|
||||
|
||||
corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
|
||||
for c, (dz, dy, dx) in enumerate(corner_offsets):
|
||||
corner_values[:, c] = padded[
|
||||
cell_positions[:, 0] + dz,
|
||||
cell_positions[:, 1] + dy,
|
||||
cell_positions[:, 2] + dx
|
||||
]
|
||||
|
||||
corner_signs = corner_values > threshold
|
||||
has_inside = torch.any(corner_signs, dim=1)
|
||||
has_outside = torch.any(~corner_signs, dim=1)
|
||||
contains_surface = has_inside & has_outside
|
||||
|
||||
active_cells = cell_positions[contains_surface]
|
||||
active_signs = corner_signs[contains_surface]
|
||||
active_values = corner_values[contains_surface]
|
||||
|
||||
if active_cells.shape[0] == 0:
|
||||
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
|
||||
|
||||
edges = torch.tensor([
|
||||
[0, 1], [0, 2], [0, 4], [1, 3],
|
||||
[1, 5], [2, 3], [2, 6], [3, 7],
|
||||
[4, 5], [4, 6], [5, 7], [6, 7]
|
||||
], device=device)
|
||||
|
||||
cell_vertices = {}
|
||||
progress = comfy.utils.ProgressBar(100)
|
||||
|
||||
for edge_idx, (e1, e2) in enumerate(edges):
|
||||
progress.update(1)
|
||||
crossing = active_signs[:, e1] != active_signs[:, e2]
|
||||
if not crossing.any():
|
||||
continue
|
||||
|
||||
cell_indices = torch.nonzero(crossing, as_tuple=True)[0]
|
||||
|
||||
v1 = active_values[cell_indices, e1]
|
||||
v2 = active_values[cell_indices, e2]
|
||||
|
||||
t = torch.zeros_like(v1, device=device)
|
||||
denom = v2 - v1
|
||||
valid = denom != 0
|
||||
t[valid] = (threshold - v1[valid]) / denom[valid]
|
||||
t[~valid] = 0.5
|
||||
|
||||
p1 = corner_offsets[e1].float()
|
||||
p2 = corner_offsets[e2].float()
|
||||
|
||||
intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0))
|
||||
|
||||
for i, point in zip(cell_indices.tolist(), intersection):
|
||||
if i not in cell_vertices:
|
||||
cell_vertices[i] = []
|
||||
cell_vertices[i].append(point)
|
||||
|
||||
# Calculate the final vertices as the average of intersection points for each cell
|
||||
vertices = []
|
||||
vertex_lookup = {}
|
||||
|
||||
vert_progress_mod = max(1, round(len(cell_vertices)/50))  # guard against modulo by zero for small meshes
|
||||
|
||||
for i, points in cell_vertices.items():
|
||||
if not i % vert_progress_mod:
|
||||
progress.update(1)
|
||||
|
||||
if points:
|
||||
vertex = torch.stack(points).mean(dim=0)
|
||||
vertex = vertex + active_cells[i].float()
|
||||
vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices)
|
||||
vertices.append(vertex)
|
||||
|
||||
if not vertices:
|
||||
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
|
||||
|
||||
final_vertices = torch.stack(vertices)
|
||||
|
||||
inside_corners_mask = active_signs
|
||||
outside_corners_mask = ~active_signs
|
||||
|
||||
inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float()
|
||||
outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float()
|
||||
|
||||
inside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
|
||||
outside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
|
||||
|
||||
for i in range(8):
|
||||
mask_inside = inside_corners_mask[:, i].unsqueeze(1)
|
||||
mask_outside = outside_corners_mask[:, i].unsqueeze(1)
|
||||
inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside
|
||||
outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside
|
||||
|
||||
inside_pos /= inside_counts
|
||||
outside_pos /= outside_counts
|
||||
gradients = inside_pos - outside_pos
|
||||
|
||||
pos_dirs = torch.tensor([
|
||||
[1, 0, 0],
|
||||
[0, 1, 0],
|
||||
[0, 0, 1]
|
||||
], device=device)
|
||||
|
||||
cross_products = [
|
||||
torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float())
|
||||
for i in range(3) for j in range(i+1, 3)
|
||||
]
|
||||
|
||||
faces = []
|
||||
all_keys = set(vertex_lookup.keys())
|
||||
|
||||
face_progress_mod = max(1, round(len(active_cells)/38*3))  # guard against modulo by zero for small meshes
|
||||
|
||||
for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]):
|
||||
dir_i = pos_dirs[i]
|
||||
dir_j = pos_dirs[j]
|
||||
cross_product = cross_products[pair_idx]
|
||||
|
||||
ni_positions = active_cells + dir_i
|
||||
nj_positions = active_cells + dir_j
|
||||
diag_positions = active_cells + dir_i + dir_j
|
||||
|
||||
alignments = torch.matmul(gradients, cross_product)
|
||||
|
||||
valid_quads = []
|
||||
quad_indices = []
|
||||
|
||||
for idx, active_cell in enumerate(active_cells):
|
||||
if not idx % face_progress_mod:
|
||||
progress.update(1)
|
||||
cell_key = tuple(active_cell.tolist())
|
||||
ni_key = tuple(ni_positions[idx].tolist())
|
||||
nj_key = tuple(nj_positions[idx].tolist())
|
||||
diag_key = tuple(diag_positions[idx].tolist())
|
||||
|
||||
if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys:
|
||||
v0 = vertex_lookup[cell_key]
|
||||
v1 = vertex_lookup[ni_key]
|
||||
v2 = vertex_lookup[nj_key]
|
||||
v3 = vertex_lookup[diag_key]
|
||||
|
||||
valid_quads.append((v0, v1, v2, v3))
|
||||
quad_indices.append(idx)
|
||||
|
||||
for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads):
|
||||
cell_idx = quad_indices[q_idx]
|
||||
if alignments[cell_idx] > 0:
|
||||
faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long))
|
||||
faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long))
|
||||
else:
|
||||
faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long))
|
||||
faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long))
|
||||
|
||||
if faces:
|
||||
faces = torch.stack(faces)
|
||||
else:
|
||||
faces = torch.zeros((0, 3), dtype=torch.long, device=device)
|
||||
|
||||
v_min = 0
|
||||
v_max = max(D, H, W)
|
||||
|
||||
final_vertices = final_vertices - (v_min + v_max) / 2
|
||||
|
||||
scale = (v_max - v_min) / 2
|
||||
if scale > 0:
|
||||
final_vertices = final_vertices / scale
|
||||
|
||||
final_vertices = torch.fliplr(final_vertices)
|
||||
|
||||
return final_vertices, faces
|
||||
|
||||
class MESH:
|
||||
def __init__(self, vertices, faces):
|
||||
self.vertices = vertices
|
||||
self.faces = faces
|
||||
|
||||
|
||||
class VoxelToMeshBasic:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"voxel": ("VOXEL", ),
|
||||
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MESH",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def decode(self, voxel, threshold):
|
||||
vertices = []
|
||||
faces = []
|
||||
for x in voxel.data:
|
||||
v, f = voxel_to_mesh(x, threshold=threshold, device=None)
|
||||
vertices.append(v)
|
||||
faces.append(f)
|
||||
|
||||
return (MESH(torch.stack(vertices), torch.stack(faces)), )
|
||||
|
||||
class VoxelToMesh:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"voxel": ("VOXEL", ),
|
||||
"algorithm": (["surface net", "basic"], ),
|
||||
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MESH",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def decode(self, voxel, algorithm, threshold):
|
||||
vertices = []
|
||||
faces = []
|
||||
|
||||
if algorithm == "basic":
|
||||
mesh_function = voxel_to_mesh
|
||||
elif algorithm == "surface net":
|
||||
mesh_function = voxel_to_mesh_surfnet
|
||||
|
||||
for x in voxel.data:
|
||||
v, f = mesh_function(x, threshold=threshold, device=None)
|
||||
vertices.append(v)
|
||||
faces.append(f)
|
||||
|
||||
return (MESH(torch.stack(vertices), torch.stack(faces)), )
|
||||
|
||||
|
||||
def save_glb(vertices, faces, filepath, metadata=None):
|
||||
"""
|
||||
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
|
||||
|
||||
Parameters:
|
||||
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
|
||||
faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
|
||||
filepath: str - Output filepath (should end with .glb)
metadata: dict or None - Optional key/value pairs stored in the glTF "asset" extras field
|
||||
"""
|
||||
|
||||
# Convert tensors to numpy arrays
|
||||
vertices_np = vertices.cpu().numpy().astype(np.float32)
|
||||
faces_np = faces.cpu().numpy().astype(np.uint32)
|
||||
|
||||
vertices_buffer = vertices_np.tobytes()
|
||||
indices_buffer = faces_np.tobytes()
|
||||
|
||||
def pad_to_4_bytes(buffer):
|
||||
padding_length = (4 - (len(buffer) % 4)) % 4
|
||||
return buffer + b'\x00' * padding_length
|
||||
|
||||
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
|
||||
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
|
||||
|
||||
buffer_data = vertices_buffer_padded + indices_buffer_padded
|
||||
|
||||
vertices_byte_length = len(vertices_buffer)
|
||||
vertices_byte_offset = 0
|
||||
indices_byte_length = len(indices_buffer)
|
||||
indices_byte_offset = len(vertices_buffer_padded)
|
||||
|
||||
gltf = {
|
||||
"asset": {"version": "2.0", "generator": "ComfyUI"},
|
||||
"buffers": [
|
||||
{
|
||||
"byteLength": len(buffer_data)
|
||||
}
|
||||
],
|
||||
"bufferViews": [
|
||||
{
|
||||
"buffer": 0,
|
||||
"byteOffset": vertices_byte_offset,
|
||||
"byteLength": vertices_byte_length,
|
||||
"target": 34962 # ARRAY_BUFFER
|
||||
},
|
||||
{
|
||||
"buffer": 0,
|
||||
"byteOffset": indices_byte_offset,
|
||||
"byteLength": indices_byte_length,
|
||||
"target": 34963 # ELEMENT_ARRAY_BUFFER
|
||||
}
|
||||
],
|
||||
"accessors": [
|
||||
{
|
||||
"bufferView": 0,
|
||||
"byteOffset": 0,
|
||||
"componentType": 5126, # FLOAT
|
||||
"count": len(vertices_np),
|
||||
"type": "VEC3",
|
||||
"max": vertices_np.max(axis=0).tolist(),
|
||||
"min": vertices_np.min(axis=0).tolist()
|
||||
},
|
||||
{
|
||||
"bufferView": 1,
|
||||
"byteOffset": 0,
|
||||
"componentType": 5125, # UNSIGNED_INT
|
||||
"count": faces_np.size,
|
||||
"type": "SCALAR"
|
||||
}
|
||||
],
|
||||
"meshes": [
|
||||
{
|
||||
"primitives": [
|
||||
{
|
||||
"attributes": {
|
||||
"POSITION": 0
|
||||
},
|
||||
"indices": 1,
|
||||
"mode": 4 # TRIANGLES
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"mesh": 0
|
||||
}
|
||||
],
|
||||
"scenes": [
|
||||
{
|
||||
"nodes": [0]
|
||||
}
|
||||
],
|
||||
"scene": 0
|
||||
}
|
||||
|
||||
if metadata is not None:
|
||||
gltf["asset"]["extras"] = metadata
|
||||
|
||||
# Convert the JSON to bytes
|
||||
gltf_json = json.dumps(gltf).encode('utf8')
|
||||
|
||||
def pad_json_to_4_bytes(buffer):
|
||||
padding_length = (4 - (len(buffer) % 4)) % 4
|
||||
return buffer + b' ' * padding_length
|
||||
|
||||
gltf_json_padded = pad_json_to_4_bytes(gltf_json)
|
||||
|
||||
# Create the GLB header
|
||||
# Magic glTF
|
||||
glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
|
||||
|
||||
# Create JSON chunk header (chunk type 0)
|
||||
json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A) # "JSON" in little endian
|
||||
|
||||
# Create BIN chunk header (chunk type 1)
|
||||
bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942) # "BIN\0" in little endian
|
||||
|
||||
# Write the GLB file
|
||||
with open(filepath, 'wb') as f:
|
||||
f.write(glb_header)
|
||||
f.write(json_chunk_header)
|
||||
f.write(gltf_json_padded)
|
||||
f.write(bin_chunk_header)
|
||||
f.write(buffer_data)
|
||||
|
||||
return filepath
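# Hedged usage sketch (added for illustration, not part of the original file;
# assumes the torch/os imports at the top of this module): write a single
# triangle to a temporary .glb and check the 4-byte GLB magic.
if __name__ == "__main__":
    import tempfile
    _v = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    _f = torch.tensor([[0, 1, 2]], dtype=torch.long)
    _path = os.path.join(tempfile.gettempdir(), "triangle_example.glb")
    save_glb(_v, _f, _path, metadata={"note": "example"})
    with open(_path, "rb") as fh:
        assert fh.read(4) == b"glTF"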
|
||||
|
||||
|
||||
class SaveGLB:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"mesh": ("MESH", ),
|
||||
"filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), },
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, }
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def save(self, mesh, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||
results = []
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = json.dumps(prompt)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
return {"ui": {"3d": results}}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2,
|
||||
"Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning,
|
||||
"Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
|
||||
"VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
|
||||
"VoxelToMeshBasic": VoxelToMeshBasic,
|
||||
"VoxelToMesh": VoxelToMesh,
|
||||
"SaveGLB": SaveGLB,
|
||||
}
|
@ -19,10 +19,12 @@ class Load3D():
|
||||
"image": ("LOAD_3D", {}),
|
||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
@ -32,16 +34,12 @@ class Load3D():
|
||||
def process(self, model_file, image, **kwargs):
|
||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
||||
normal_path = folder_paths.get_annotated_filepath(image['normal'])
|
||||
lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
|
||||
|
||||
load_image_node = nodes.LoadImage()
|
||||
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, lineart_image
|
||||
return output_image, output_mask, model_file,
|
||||
|
||||
class Load3DAnimation():
|
||||
@classmethod
|
||||
@ -57,10 +55,12 @@ class Load3DAnimation():
|
||||
"image": ("LOAD_3D_ANIMATION", {}),
|
||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
@ -70,20 +70,20 @@ class Load3DAnimation():
|
||||
def process(self, model_file, image, **kwargs):
|
||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
||||
normal_path = folder_paths.get_annotated_filepath(image['normal'])
|
||||
|
||||
load_image_node = nodes.LoadImage()
|
||||
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image
|
||||
return output_image, output_mask, model_file,
|
||||
|
||||
class Preview3D():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
}}
|
||||
|
||||
OUTPUT_NODE = True
|
||||
@ -102,6 +102,8 @@ class Preview3DAnimation():
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
}}
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
File diff suppressed because one or more lines are too long
@ -99,13 +99,12 @@ class LTXVAddGuide:
"negative": ("CONDITIONING", ),
"vae": ("VAE",),
"latent": ("LATENT",),
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames."
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
"tooltip": "Frame index to start the conditioning at. For single-frame images or "
"videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
"frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
"the nearest multiple of 8. Negative values are counted from the end of the video."}),
"tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \
"If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \
"Negative values are counted from the end of the video."}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
@ -128,13 +127,12 @@ class LTXVAddGuide:
t = vae.encode(encode_pixels)
return encode_pixels, t

def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
def get_latent_index(self, cond, latent_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
if guide_length > 1:
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0)
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8

latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor

@ -193,7 +191,7 @@ class LTXVAddGuide:
_, _, latent_length, latent_height, latent_width = latent_image.shape
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)

frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."

num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
@ -446,9 +444,10 @@ class LTXVPreprocess:
CATEGORY = "image"

def preprocess(self, image, img_compression):
output_images = []
for i in range(image.shape[0]):
output_images.append(preprocess(image[i], img_compression))
if img_compression > 0:
output_images = []
for i in range(image.shape[0]):
output_images.append(preprocess(image[i], img_compression))
return (torch.stack(output_images),)
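A note on the frame_idx handling in the hunks above: on the master side, negative indices are counted back from the end of the video and multi-frame guides are snapped down to a multiple of the temporal scale factor before being mapped to a latent index. A minimal standalone sketch of that arithmetic (hypothetical helper name; time_scale_factor assumed to be 8, as the tooltip states):

def to_latent_index(frame_idx, latent_count, guide_length, time_scale_factor=8):
    # Negative frame indices count from the end of the video.
    if frame_idx < 0:
        frame_idx = max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
    # Multi-frame guides must start on a multiple of the temporal scale factor.
    if guide_length > 1:
        frame_idx = frame_idx // time_scale_factor * time_scale_factor
    # Ceiling division maps the pixel-space frame index to a latent index.
    latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
    return frame_idx, latent_idx

# e.g. to_latent_index(13, latent_count=20, guide_length=9) == (8, 1)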
@ -2,7 +2,6 @@ import numpy as np
import scipy.ndimage
import torch
import comfy.utils
import node_helpers

from nodes import MAX_RESOLUTION

@ -88,7 +87,6 @@ class ImageCompositeMasked:
CATEGORY = "image"

def composite(self, destination, source, x, y, resize_source, mask = None):
destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
return (output,)
@ -20,6 +20,10 @@ class LCM(comfy.model_sampling.EPS):

return c_out * x0 + c_skip * model_input

class X0(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output

class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50

@ -52,7 +56,7 @@ class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],),
"sampling": (["eps", "v_prediction", "lcm", "x0"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}

@ -73,9 +77,7 @@ class ModelSamplingDiscrete:
sampling_type = LCM
sampling_base = ModelSamplingDiscreteDistilled
elif sampling == "x0":
sampling_type = comfy.model_sampling.X0
elif sampling == "img_to_img":
sampling_type = comfy.model_sampling.IMG_TO_IMG
sampling_type = X0

class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
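For context, the patch method above builds the final sampling object by mixing a schedule base class with a prediction-type class. A hedged sketch of that composition pattern with stand-in classes (the real ones live in comfy.model_sampling):

class SamplingBase:
    # Stand-in for ModelSamplingDiscrete / ModelSamplingDiscreteDistilled.
    def timestep(self, sigma):
        return sigma

class X0Prediction:
    # "x0"-style models predict the clean image directly.
    def calculate_denoised(self, sigma, model_output, model_input):
        return model_output

def make_model_sampling(sampling_base=SamplingBase, sampling_type=X0Prediction):
    class ModelSamplingAdvanced(sampling_base, sampling_type):
        pass
    return ModelSamplingAdvanced()

ms = make_model_sampling()
assert ms.calculate_denoised(1.0, "denoised", "noisy") == "denoised"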
@ -244,30 +244,6 @@ class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):

return {"required": arg_dict}

class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."

@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}

argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

arg_dict["patch_embedding."] = argument
arg_dict["time_embedding."] = argument
arg_dict["time_projection."] = argument
arg_dict["text_embedding."] = argument
arg_dict["img_emb."] = argument

for i in range(40):
arg_dict["blocks.{}.".format(i)] = argument

arg_dict["head."] = argument

return {"required": arg_dict}

NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@ -280,5 +256,4 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeLTXV": ModelMergeLTXV,
"ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B,
"ModelMergeWAN2_1": ModelMergeWAN2_1,
}
@ -2,7 +2,6 @@ import torch
import comfy.model_management

from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
import kornia.color


class Morphology:
@ -41,45 +40,8 @@ class Morphology:
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
return (img_out,)


class ImageRGBToYUV:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
}}

RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
RETURN_NAMES = ("Y", "U", "V")
FUNCTION = "execute"

CATEGORY = "image/batch"

def execute(self, image):
out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1)
return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))

class ImageYUVToRGB:
@classmethod
def INPUT_TYPES(s):
return {"required": {"Y": ("IMAGE",),
"U": ("IMAGE",),
"V": ("IMAGE",),
}}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"

CATEGORY = "image/batch"

def execute(self, Y, U, V):
image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1)
out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1)
return (out,)

NODE_CLASS_MAPPINGS = {
"Morphology": Morphology,
"ImageRGBToYUV": ImageRGBToYUV,
"ImageYUVToRGB": ImageYUVToRGB,
}

NODE_DISPLAY_NAME_MAPPINGS = {
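A brief hedged check of the kornia conversion used by the two nodes above; ComfyUI IMAGE tensors are channel-last, so they are moved to channel-first for kornia and back:

import torch
import kornia.color

img = torch.rand(1, 32, 32, 3)  # NHWC in [0, 1]
ycbcr = kornia.color.rgb_to_ycbcr(img.movedim(-1, 1)).movedim(1, -1)
rgb = kornia.color.ycbcr_to_rgb(ycbcr.movedim(-1, 1)).movedim(1, -1)
assert torch.allclose(rgb, img, atol=1e-4)  # the round trip is numerically lossless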
@ -1,56 +0,0 @@
# from https://github.com/bebebe666/OptimalSteps


import numpy as np
import torch

def loglinear_interp(t_steps, num_steps):
"""
Performs log-linear interpolation of a given array of decreasing numbers.
"""
xs = np.linspace(0, 1, len(t_steps))
ys = np.log(t_steps[::-1])

new_xs = np.linspace(0, 1, num_steps)
new_ys = np.interp(new_xs, xs, ys)

interped_ys = np.exp(new_ys)[::-1].copy()
return interped_ys


NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
"Wan":[1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
}

class OptimalStepsScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model_type": (["FLUX", "Wan"], ),
"steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"

FUNCTION = "get_sigmas"

def get_sigmas(self, model_type, steps, denoise):
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return (torch.FloatTensor([]),)
total_steps = round(steps * denoise)

sigmas = NOISE_LEVELS[model_type][:]
if (steps + 1) != len(sigmas):
sigmas = loglinear_interp(sigmas, steps + 1)

sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0
return (torch.FloatTensor(sigmas), )

NODE_CLASS_MAPPINGS = {
"OptimalStepsScheduler": OptimalStepsScheduler,
}
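For reference, a small self-contained example of how the loglinear_interp helper from the deleted file behaves (the function body is copied from the hunk above; the values are its FLUX noise levels):

import numpy as np

def loglinear_interp(t_steps, num_steps):
    # Copied verbatim from the nodes_optimalsteps.py hunk above.
    xs = np.linspace(0, 1, len(t_steps))
    ys = np.log(t_steps[::-1])
    new_xs = np.linspace(0, 1, num_steps)
    new_ys = np.interp(new_xs, xs, ys)
    return np.exp(new_ys)[::-1].copy()

flux_levels = [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001]
# Resample the 11-entry table to the 21 sigmas needed for a 20-step schedule,
# exactly as OptimalStepsScheduler.get_sigmas does when (steps + 1) != len(sigmas).
sigmas = loglinear_interp(flux_levels, 21)
assert len(sigmas) == 21 and sigmas[0] > sigmas[-1]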
@ -6,7 +6,7 @@ import math

import comfy.utils
import comfy.model_management
import node_helpers


class Blend:
def __init__(self):
@ -34,7 +34,6 @@ class Blend:
CATEGORY = "image/postprocessing"

def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
image1, image2 = node_helpers.image_alpha_fix(image1, image2)
image2 = image2.to(image1.device)
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)
@ -1,79 +0,0 @@
# Primitive nodes that are evaluated at backend.
from __future__ import annotations

from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO


class String(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.STRING, {})},
}

RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"

def execute(self, value: str) -> tuple[str]:
return (value,)


class Int(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.INT, {"control_after_generate": True})},
}

RETURN_TYPES = (IO.INT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"

def execute(self, value: int) -> tuple[int]:
return (value,)


class Float(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.FLOAT, {})},
}

RETURN_TYPES = (IO.FLOAT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"

def execute(self, value: float) -> tuple[float]:
return (value,)


class Boolean(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.BOOLEAN, {})},
}

RETURN_TYPES = (IO.BOOLEAN,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"

def execute(self, value: bool) -> tuple[bool]:
return (value,)


NODE_CLASS_MAPPINGS = {
"PrimitiveString": String,
"PrimitiveInt": Int,
"PrimitiveFloat": Float,
"PrimitiveBoolean": Boolean,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"PrimitiveString": "String",
"PrimitiveInt": "Int",
"PrimitiveFloat": "Float",
"PrimitiveBoolean": "Boolean",
}
@ -3,7 +3,6 @@ import node_helpers
import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats


class WanImageToVideo:
@ -50,110 +49,6 @@ class WanImageToVideo:
return (positive, negative, out_latent)


class WanFunControlToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"control_video": ("IMAGE", ),
}}

RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"

CATEGORY = "conditioning/video_models"

def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)

if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]

if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(control_video[:, :, :, :3])
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]

positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})

if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)

class WanFunInpaintToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}

RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"

CATEGORY = "conditioning/video_models"

def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)

image = torch.ones((length, height, width, 3)) * 0.5
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))

if start_image is not None:
image[:start_image.shape[0]] = start_image
mask[:, :, :start_image.shape[0] + 3] = 0.0

if end_image is not None:
image[-end_image.shape[0]:] = end_image
mask[:, :, -end_image.shape[0]:] = 0.0

concat_latent_image = vae.encode(image[:, :, :, :3])
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})

if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)

NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo,
}
@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.28"
__version__ = "0.3.25"
62 execution.py
@ -15,7 +15,7 @@ import nodes
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input

class ExecutionResult(Enum):
@ -59,45 +59,27 @@ class IsChangedCache:
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]


class CacheType(Enum):
CLASSIC = 0
LRU = 1
DEPENDENCY_AWARE = 2


class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
if cache_type == CacheType.DEPENDENCY_AWARE:
self.init_dependency_aware_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
cache_size = 0
self.init_lru_cache(cache_size)
logging.info("Using LRU cache")
else:
def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0:
self.init_classic_cache()

else:
self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects]

# Useful for those with ample RAM/VRAM -- allows experimenting without
# blowing away the cache every time
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)

# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)

def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)

# only hold cached items while the decendents have not executed
def init_dependency_aware_cache(self):
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
self.objects = DependencyAwareCache(CacheKeySetID)

def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
@ -432,14 +414,13 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
return (ExecutionResult.SUCCESS, None, None)

class PromptExecutor:
def __init__(self, server, cache_type=False, cache_size=None):
self.cache_size = cache_size
self.cache_type = cache_type
def __init__(self, server, lru_size=None):
self.lru_size = lru_size
self.server = server
self.reset()

def reset(self):
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
self.caches = CacheSet(self.lru_size)
self.status_messages = []
self.success = True

@ -653,13 +634,6 @@ def validate_inputs(prompt, item, validated):
continue
else:
try:
# Unwraps values wrapped in __value__ key. This is used to pass
# list widget value to execution, as by default list value is
# reserved to represent the connection between nodes.
if isinstance(val, dict) and "__value__" in val:
val = val["__value__"]
inputs[x] = val

if type_input == "INT":
val = int(val)
inputs[x] = val
@ -794,7 +768,7 @@ def validate_prompt(prompt):
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], {})
return (False, error, [], [])

class_type = prompt[x]['class_type']
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
@ -805,7 +779,7 @@ def validate_prompt(prompt):
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], {})
return (False, error, [], [])

if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
outputs.add(x)
@ -817,7 +791,7 @@ def validate_prompt(prompt):
"details": "",
"extra_info": {}
}
return (False, error, [], {})
return (False, error, [], [])

good_outputs = set()
errors = []
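The hunk above replaces the lru_size-only CacheSet constructor with an explicit CacheType enum. A hedged sketch of how the new constructor is driven, mirroring the prompt_worker change shown in the main.py hunk further down:

import execution

def make_executor(server_instance, cache_lru=0, cache_none=False):
    # Sketch only: same selection logic as prompt_worker() on master.
    cache_type = execution.CacheType.CLASSIC
    if cache_lru > 0:
        cache_type = execution.CacheType.LRU  # keep up to cache_lru cached results
    elif cache_none:
        cache_type = execution.CacheType.DEPENDENCY_AWARE  # drop results once dependents ran
    return execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=cache_lru)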
@ -85,7 +85,6 @@ cache_helper = CacheHelper()

extension_mimetypes_cache = {
"webp" : "image",
"fbx" : "model",
}

def map_legacy(folder_name: str) -> str:
@ -141,14 +140,11 @@ def get_directory_by_type(type_name: str) -> str | None:
return get_input_directory()
return None

def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio", "model"]) -> list[str]:
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
"""
Example:
files = os.listdir(folder_paths.get_input_directory())
videos = filter_files_content_types(files, ["video"])

Note:
- 'model' in MIME context refers to 3D models, not files containing trained weights and parameters
filter_files_content_types(files, ["image", "audio", "video"])
"""
global extension_mimetypes_cache
result = []
28 main.py
@ -10,7 +10,6 @@ from app.logger import setup_logger
import itertools
import utils.extra_config
import logging
import sys

if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes.
@ -140,7 +139,7 @@ from server import BinaryEventTypes
import nodes
import comfy.model_management
import comfyui_version
import app.logger
import app.frontend_management


def cuda_malloc_warning():
@ -157,13 +156,7 @@ def cuda_malloc_warning():

def prompt_worker(q, server_instance):
current_time: float = 0.0
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_none:
cache_type = execution.CacheType.DEPENDENCY_AWARE

e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
@ -300,15 +293,28 @@ def start_comfyui(asyncio_loop=None):
return asyncio_loop, prompt_server, start_all


def warn_frontend_version(frontend_version):
try:
required_frontend = (0,)
req_path = os.path.join(os.path.dirname(__file__), 'requirements.txt')
with open(req_path, 'r') as f:
required_frontend = tuple(map(int, f.readline().split('=')[-1].split('.')))
if frontend_version < required_frontend:
logging.warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), app.frontend_management.frontend_install_warning_message()))
except:
pass


if __name__ == "__main__":
# Running directly, just start ComfyUI.
logging.info("Python version: {}".format(sys.version))
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
frontend_version = app.frontend_management.frontend_version
logging.info("ComfyUI frontend version: {}".format('.'.join(map(str, frontend_version))))

event_loop, _, start_all_func = start_comfyui()
try:
x = start_all_func()
app.logger.print_startup_warnings()
warn_frontend_version(frontend_version)
event_loop.run_until_complete(x)
except KeyboardInterrupt:
logging.info("\nStopped server")
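As a worked example of the parsing in warn_frontend_version above, using the pin shown in the requirements.txt hunk below:

line = "comfyui-frontend-package==1.15.13"  # first line of requirements.txt on master
required_frontend = tuple(map(int, line.split('=')[-1].split('.')))
assert required_frontend == (1, 15, 13)
# Version tuples compare element-wise, so an older install such as (1, 11, 8) triggers the warning.
assert (1, 11, 8) < required_frontend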
@ -44,11 +44,3 @@ def string_to_torch_dtype(string):
return torch.float16
if string == "bf16":
return torch.bfloat16

def image_alpha_fix(destination, source):
if destination.shape[-1] < source.shape[-1]:
source = source[...,:destination.shape[-1]]
elif destination.shape[-1] > source.shape[-1]:
destination = torch.nn.functional.pad(destination, (0, 1))
destination[..., -1] = 1.0
return destination, source
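A small hedged sketch of the image_alpha_fix helper added above, showing how a mismatched channel count is reconciled before compositing or blending (the function body is copied from the hunk):

import torch

def image_alpha_fix(destination, source):
    if destination.shape[-1] < source.shape[-1]:
        source = source[...,:destination.shape[-1]]
    elif destination.shape[-1] > source.shape[-1]:
        destination = torch.nn.functional.pad(destination, (0, 1))
        destination[..., -1] = 1.0
    return destination, source

rgb = torch.rand(1, 64, 64, 3)   # destination without an alpha channel
rgba = torch.rand(1, 64, 64, 4)  # source with an alpha channel
d, s = image_alpha_fix(rgb, rgba)
assert d.shape[-1] == s.shape[-1] == 3  # the source's extra alpha channel is dropped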
38 nodes.py
@ -489,7 +489,7 @@ class SaveLatent:
file = os.path.join(full_output_folder, file)

output = {}
output["latent_tensor"] = samples["samples"].contiguous()
output["latent_tensor"] = samples["samples"]
output["latent_format_version_0"] = torch.tensor([])

comfy.utils.save_torch_file(output, file, metadata=metadata)
@ -770,7 +770,6 @@ class VAELoader:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
vae.throw_exception_if_invalid()
return (vae,)

class ControlNetLoader:
@ -786,8 +785,6 @@ class ControlNetLoader:
def load_controlnet(self, control_net_name):
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
if controlnet is None:
raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.")
return (controlnet,)

class DiffControlNetLoader:
@ -1008,8 +1005,6 @@ class CLIPVisionLoader:
def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
clip_vision = comfy.clip_vision.load(clip_path)
if clip_vision is None:
raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.")
return (clip_vision,)

class CLIPVisionEncode:
@ -1654,7 +1649,6 @@ class LoadImage:
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
files = folder_paths.filter_files_content_types(files, ["image"])
return {"required":
{"image": (sorted(files), {"image_upload": True})},
}
@ -1693,9 +1687,6 @@ class LoadImage:
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
@ -1794,7 +1785,14 @@ class LoadImageOutput(LoadImage):

DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
EXPERIMENTAL = True
FUNCTION = "load_image"
FUNCTION = "load_image_output"

def load_image_output(self, image):
return self.load_image(f"{image} [output]")

@classmethod
def VALIDATE_INPUTS(s, image):
return True


class ImageScale:
@ -2131,25 +2129,21 @@ def get_module_name(module_path: str) -> str:


def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
module_name = get_module_name(module_path)
module_name = os.path.basename(module_path)
if os.path.isfile(module_path):
sp = os.path.splitext(module_path)
module_name = sp[0]
sys_module_name = module_name
elif os.path.isdir(module_path):
sys_module_name = module_path.replace(".", "_x_")

try:
logging.debug("Trying to load custom node {}".format(module_path))
if os.path.isfile(module_path):
module_spec = importlib.util.spec_from_file_location(sys_module_name, module_path)
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
module_dir = os.path.split(module_path)[0]
else:
module_spec = importlib.util.spec_from_file_location(sys_module_name, os.path.join(module_path, "__init__.py"))
module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
module_dir = module_path

module = importlib.util.module_from_spec(module_spec)
sys.modules[sys_module_name] = module
sys.modules[module_name] = module
module_spec.loader.exec_module(module)

LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
@ -2276,12 +2270,6 @@ def init_builtin_extra_nodes():
"nodes_video.py",
"nodes_lumina2.py",
"nodes_wan.py",
"nodes_lotus.py",
"nodes_hunyuan3d.py",
"nodes_primitive.py",
"nodes_cfg.py",
"nodes_optimalsteps.py",
"nodes_hidream.py"
]

import_failed = []
@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.28"
version = "0.3.25"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
@ -1,4 +1,4 @@
comfyui-frontend-package==1.15.13
comfyui-frontend-package==1.11.8
torch
torchsde
torchvision
10 server.py
@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message):
@web.middleware
async def cache_control(request: web.Request, handler):
response: web.Response = await handler(request)
if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
if request.path.endswith('.js') or request.path.endswith('.css'):
response.headers.setdefault('Cache-Control', 'no-cache')
return response

@ -657,13 +657,7 @@ class PromptServer():
logging.warning("invalid prompt: {}".format(valid[1]))
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
else:
error = {
"type": "no_prompt",
"message": "No prompt provided",
"details": "No prompt provided",
"extra_info": {}
}
return web.json_response({"error": error, "node_errors": {}}, status=400)
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)

@routes.post("/queue")
async def post_queue(request):
@ -70,7 +70,7 @@ def test_get_release_invalid_version(mock_provider):
def test_init_frontend_default():
version_string = DEFAULT_VERSION_STRING
frontend_path = FrontendManager.init_frontend(version_string)
assert frontend_path == FrontendManager.default_frontend_path()
assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH


def test_init_frontend_invalid_version():
@ -84,29 +84,24 @@ def test_init_frontend_invalid_provider():
with pytest.raises(HTTPError):
FrontendManager.init_frontend_unsafe(version_string)


@pytest.fixture
def mock_os_functions():
with (
patch("app.frontend_management.os.makedirs") as mock_makedirs,
patch("app.frontend_management.os.listdir") as mock_listdir,
patch("app.frontend_management.os.rmdir") as mock_rmdir,
):
with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
patch('app.frontend_management.os.listdir') as mock_listdir, \
patch('app.frontend_management.os.rmdir') as mock_rmdir:
mock_listdir.return_value = [] # Simulate empty directory
yield mock_makedirs, mock_listdir, mock_rmdir


@pytest.fixture
def mock_download():
with patch("app.frontend_management.download_release_asset_zip") as mock:
with patch('app.frontend_management.download_release_asset_zip') as mock:
mock.side_effect = Exception("Download failed") # Simulate download failure
yield mock


def test_finally_block(mock_os_functions, mock_download, mock_provider):
# Arrange
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
version_string = "test-owner/test-repo@1.0.0"
version_string = 'test-owner/test-repo@1.0.0'

# Act & Assert
with pytest.raises(Exception):
@ -133,42 +128,3 @@ def test_parse_version_string_invalid():
version_string = "invalid"
with pytest.raises(argparse.ArgumentTypeError):
FrontendManager.parse_version_string(version_string)


def test_init_frontend_default_with_mocks():
# Arrange
version_string = DEFAULT_VERSION_STRING

# Act
with (
patch("app.frontend_management.check_frontend_version") as mock_check,
patch.object(
FrontendManager, "default_frontend_path", return_value="/mocked/path"
),
):
frontend_path = FrontendManager.init_frontend(version_string)

# Assert
assert frontend_path == "/mocked/path"
mock_check.assert_called_once()


def test_init_frontend_fallback_on_error():
# Arrange
version_string = "test-owner/test-repo@1.0.0"

# Act
with (
patch.object(
FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error")
),
patch("app.frontend_management.check_frontend_version") as mock_check,
patch.object(
FrontendManager, "default_frontend_path", return_value="/default/path"
),
):
frontend_path = FrontendManager.init_frontend(version_string)

# Assert
assert frontend_path == "/default/path"
mock_check.assert_called_once()

@ -1,17 +1,14 @@
import pytest
import os
import tempfile
from folder_paths import filter_files_content_types, extension_mimetypes_cache
from unittest.mock import patch

from folder_paths import filter_files_content_types

@pytest.fixture(scope="module")
def file_extensions():
return {
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv'],
'model': ['gltf', 'glb', 'obj', 'fbx', 'stl']
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
}


@ -25,18 +22,7 @@ def mock_dir(file_extensions):
yield directory


@pytest.fixture
def patched_mimetype_cache(file_extensions):
# Mock model file extensions since they may not be in the test-runner system's mimetype cache
new_cache = extension_mimetypes_cache.copy()
for extension in file_extensions["model"]:
new_cache[extension] = "model"

with patch("folder_paths.extension_mimetypes_cache", new_cache):
yield


def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_cache):
def test_categorizes_all_correctly(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
@ -44,7 +30,7 @@ def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_c
assert f"sample_{content_type}.{extension}" in filtered_files


def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_cache):
def test_categorizes_all_uniquely(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])