Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-04-20 03:13:30 +00:00

Compare commits (28 commits):
fd27494441
f43e1d7f41
4486b0d0ff
636d4bfb89
dc300a4569
f3b09b9f2d
7ecd5e9614
2383a39e3b
34e06bf7ec
55822faa05
880c205df1
3dc240d089
19373aee75
93292bc450
05d5a75cdc
eba7a25e7a
dbcfd092a2
c14429940f
0d720e4367
1fc00ba4b6
9899d187b1
f00f340a56
cce1d9145e
b4dc03ad76
9ad792f927
6fc5dbd52a
3e8155f7a3
8a438115fb

CODEOWNERS (26 changed lines)

@@ -5,20 +5,20 @@
 # Inlined the team members for now.

 # Maintainers
-*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne

 # Python web server
-/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
+/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
-/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
+/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
-/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
+/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne

 # Node developers
-/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
+/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
-/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
+/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne

@@ -62,6 +62,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 - [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
 - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
 - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
+- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
 - Video Models
 - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
 - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)

@@ -184,6 +184,27 @@ comfyui-frontend-package is not installed.
             )
             sys.exit(-1)

+    @classmethod
+    def templates_path(cls) -> str:
+        try:
+            import comfyui_workflow_templates
+
+            return str(
+                importlib.resources.files(comfyui_workflow_templates) / "templates"
+            )
+        except ImportError:
+            logging.error(
+                f"""
+********** ERROR ***********
+
+comfyui-workflow-templates is not installed.
+
+{frontend_install_warning_message()}
+
+********** ERROR ***********
+""".strip()
+            )
+
     @classmethod
     def parse_version_string(cls, value: str) -> tuple[str, str, str]:
         """

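The added classmethod resolves where the bundled workflow templates live on disk via importlib.resources. A minimal usage sketch, assuming the method belongs to the frontend manager class whose error handling appears above; the variable name and the static-route comment are illustrative and not part of the diff:

    # Hedged sketch: resolves the packaged template directory the same way the new classmethod does.
    import importlib.resources
    import comfyui_workflow_templates  # raises ImportError when the package is missing

    templates_dir = str(importlib.resources.files(comfyui_workflow_templates) / "templates")
    # A server could then expose it, e.g. as a static route (illustrative only):
    # app.add_routes([web.static("/templates", templates_dir)])
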
@@ -99,59 +99,59 @@ class InputTypeOptions(TypedDict):
     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
     """

-    default: bool | str | float | int | list | tuple
+    default: NotRequired[bool | str | float | int | list | tuple]
     """The default value of the widget"""
-    defaultInput: bool
+    defaultInput: NotRequired[bool]
     """@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
     - defaultInput on required inputs should be dropped.
     - defaultInput on optional inputs should be replaced with forceInput.
     Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
     """
-    forceInput: bool
+    forceInput: NotRequired[bool]
     """Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
-    lazy: bool
+    lazy: NotRequired[bool]
     """Declares that this input uses lazy evaluation"""
-    rawLink: bool
+    rawLink: NotRequired[bool]
     """When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
-    tooltip: str
+    tooltip: NotRequired[str]
     """Tooltip for the input (or widget), shown on pointer hover"""
     # class InputTypeNumber(InputTypeOptions):
     # default: float | int
-    min: float
+    min: NotRequired[float]
     """The minimum value of a number (``FLOAT`` | ``INT``)"""
-    max: float
+    max: NotRequired[float]
     """The maximum value of a number (``FLOAT`` | ``INT``)"""
-    step: float
+    step: NotRequired[float]
     """The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
-    round: float
+    round: NotRequired[float]
     """Floats are rounded by this value (``FLOAT``)"""
     # class InputTypeBoolean(InputTypeOptions):
    # default: bool
-    label_on: str
+    label_on: NotRequired[str]
     """The label to use in the UI when the bool is True (``BOOLEAN``)"""
-    label_off: str
+    label_off: NotRequired[str]
     """The label to use in the UI when the bool is False (``BOOLEAN``)"""
     # class InputTypeString(InputTypeOptions):
     # default: str
-    multiline: bool
+    multiline: NotRequired[bool]
     """Use a multiline text box (``STRING``)"""
-    placeholder: str
+    placeholder: NotRequired[str]
     """Placeholder text to display in the UI when empty (``STRING``)"""
     # Deprecated:
     # defaultVal: str
-    dynamicPrompts: bool
+    dynamicPrompts: NotRequired[bool]
     """Causes the front-end to evaluate dynamic prompts (``STRING``)"""
     # class InputTypeCombo(InputTypeOptions):
-    image_upload: bool
+    image_upload: NotRequired[bool]
     """Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
-    image_folder: Literal["input", "output", "temp"]
+    image_folder: NotRequired[Literal["input", "output", "temp"]]
     """Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
     """
-    remote: RemoteInputOptions
+    remote: NotRequired[RemoteInputOptions]
     """Specifies the configuration for a remote input.
     Available after ComfyUI frontend v1.9.7
     https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
-    control_after_generate: bool
+    control_after_generate: NotRequired[bool]
     """Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
     options: NotRequired[list[str | int | float]]
     """COMBO type only. Specifies the selectable options for the combo widget.

@@ -169,15 +169,15 @@ class InputTypeOptions(TypedDict):
 class HiddenInputTypeDict(TypedDict):
     """Provides type hinting for the hidden entry of node INPUT_TYPES."""

-    node_id: Literal["UNIQUE_ID"]
+    node_id: NotRequired[Literal["UNIQUE_ID"]]
     """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
-    unique_id: Literal["UNIQUE_ID"]
+    unique_id: NotRequired[Literal["UNIQUE_ID"]]
     """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
-    prompt: Literal["PROMPT"]
+    prompt: NotRequired[Literal["PROMPT"]]
     """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
-    extra_pnginfo: Literal["EXTRA_PNGINFO"]
+    extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]]
     """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
-    dynprompt: Literal["DYNPROMPT"]
+    dynprompt: NotRequired[Literal["DYNPROMPT"]]
     """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""

@@ -187,11 +187,11 @@ class InputTypeDict(TypedDict):
     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
     """

-    required: dict[str, tuple[IO, InputTypeOptions]]
+    required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
     """Describes all inputs that must be connected for the node to execute."""
-    optional: dict[str, tuple[IO, InputTypeOptions]]
+    optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
     """Describes inputs which do not need to be connected."""
-    hidden: HiddenInputTypeDict
+    hidden: NotRequired[HiddenInputTypeDict]
    """Offers advanced functionality and server-client communication.

     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs

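Because every option key is now NotRequired, a node author only supplies the options that matter for a given input. A hedged illustration of an INPUT_TYPES definition exercising a few of the documented keys; the node and field names are made up for this example and are not part of the diff:

    class ExampleNode:
        @classmethod
        def INPUT_TYPES(cls) -> InputTypeDict:
            return {
                "required": {
                    "text": ("STRING", {"multiline": True, "placeholder": "prompt...", "dynamicPrompts": True}),
                    "steps": ("INT", {"default": 20, "min": 1, "max": 100, "step": 1, "tooltip": "Sampling steps"}),
                },
                "optional": {
                    "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "round": 0.01, "forceInput": True}),
                },
                "hidden": {"unique_id": "UNIQUE_ID", "extra_pnginfo": "EXTRA_PNGINFO"},
            }
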
@@ -1,5 +1,6 @@
 import torch
-import comfy.ops
+import comfy.rmsnorm
+

 def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):

@@ -11,20 +12,5 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):

     return torch.nn.functional.pad(img, pad, mode=padding_mode)

-try:
-    rms_norm_torch = torch.nn.functional.rms_norm
-except:
-    rms_norm_torch = None
-
-def rms_norm(x, weight=None, eps=1e-6):
-    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
-        if weight is None:
-            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
-        else:
-            return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
-    else:
-        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
-        if weight is None:
-            return r
-        else:
-            return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
+rms_norm = comfy.rmsnorm.rms_norm

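comfy.rmsnorm itself is not part of this compare view, so the relocated helper is presumably equivalent to the code removed above. A hedged reference sketch of the fallback branch's math (pure RMS normalization over the last dimension):

    import torch

    def rms_norm_reference(x, weight=None, eps=1e-6):
        # x / sqrt(mean(x^2) + eps), optionally scaled by a weight, as in the deleted fallback branch.
        r = x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
        return r if weight is None else r * weight.to(dtype=x.dtype, device=x.device)
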
comfy/ldm/hidream/model.py (new file, 799 lines)

@@ -0,0 +1,799 @@
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import einops
from einops import repeat

from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
import torch.nn.functional as F

from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.flux.layers import LastLayer

from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import comfy.ldm.common_dit


# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        return emb.unsqueeze(2)


class PatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size=2,
        in_channels=4,
        out_channels=1024,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = operations.Linear(in_channels * patch_size * patch_size, out_channels, bias=True, dtype=dtype, device=device)

    def forward(self, latent):
        latent = self.proj(latent)
        return latent


class PooledEmbed(nn.Module):
    def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None):
        super().__init__()
        self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)

    def forward(self, pooled_embed):
        return self.pooled_embedder(pooled_embed)


class TimestepEmbed(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)

    def forward(self, timesteps, wdtype):
        t_emb = self.time_proj(timesteps).to(dtype=wdtype)
        t_emb = self.timestep_embedder(t_emb)
        return t_emb

def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])


class HiDreamAttnProcessor_flashattn:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __call__(
        self,
        attn,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        dtype = image_tokens.dtype
        batch_size = image_tokens.shape[0]

        query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
        key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
        value_i = attn.to_v(image_tokens)

        inner_dim = key_i.shape[-1]
        head_dim = inner_dim // attn.heads

        query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
        key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
        value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
        if image_tokens_masks is not None:
            key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)

        if not attn.single:
            query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
            key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
            value_t = attn.to_v_t(text_tokens)

            query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
            key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
            value_t = value_t.view(batch_size, -1, attn.heads, head_dim)

            num_image_tokens = query_i.shape[1]
            num_text_tokens = query_t.shape[1]
            query = torch.cat([query_i, query_t], dim=1)
            key = torch.cat([key_i, key_t], dim=1)
            value = torch.cat([value_i, value_t], dim=1)
        else:
            query = query_i
            key = key_i
            value = value_i

        if query.shape[-1] == rope.shape[-3] * 2:
            query, key = apply_rope(query, key, rope)
        else:
            query_1, query_2 = query.chunk(2, dim=-1)
            key_1, key_2 = key.chunk(2, dim=-1)
            query_1, key_1 = apply_rope(query_1, key_1, rope)
            query = torch.cat([query_1, query_2], dim=-1)
            key = torch.cat([key_1, key_2], dim=-1)

        hidden_states = attention(query, key, value)

        if not attn.single:
            hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
            hidden_states_i = attn.to_out(hidden_states_i)
            hidden_states_t = attn.to_out_t(hidden_states_t)
            return hidden_states_i, hidden_states_t
        else:
            hidden_states = attn.to_out(hidden_states)
            return hidden_states

class HiDreamAttention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        scale_qk: bool = True,
        eps: float = 1e-5,
        processor = None,
        out_dim: int = None,
        single: bool = False,
        dtype=None, device=None, operations=None
    ):
        # super(Attention, self).__init__()
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.out_dim = out_dim if out_dim is not None else query_dim

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.sliceable_head_dim = heads
        self.single = single

        linear_cls = operations.Linear
        self.linear_cls = linear_cls
        self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
        self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
        self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
        self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
        self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
        self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)

        if not single:
            self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
            self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
            self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
            self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
            self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
            self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)

        self.processor = processor

    def forward(
        self,
        norm_image_tokens: torch.FloatTensor,
        image_tokens_masks: torch.FloatTensor = None,
        norm_text_tokens: torch.FloatTensor = None,
        rope: torch.FloatTensor = None,
    ) -> torch.Tensor:
        return self.processor(
            self,
            image_tokens = norm_image_tokens,
            image_tokens_masks = image_tokens_masks,
            text_tokens = norm_text_tokens,
            rope = rope,
        )

class FeedForwardSwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * (
            (hidden_dim + multiple_of - 1) // multiple_of
        )

        self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
        self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
        self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x):
        return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))

# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MoEGate(nn.Module):
    def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None):
        super().__init__()
        self.top_k = num_activated_experts
        self.n_routed_experts = num_routed_experts

        self.scoring_func = 'softmax'
        self.alpha = aux_loss_alpha
        self.seq_aux = False

        # topk selection algorithm
        self.norm_topk_prob = False
        self.gating_dim = embed_dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        pass
        # import torch.nn.init as init
        # init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape

        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        ### select top-k experts
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        aux_loss = None
        return topk_idx, topk_weight, aux_loss

# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MOEFeedForwardSwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_routed_experts: int,
        num_activated_experts: int,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations)
        self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)])
        self.gate = MoEGate(
            embed_dim = dim,
            num_routed_experts = num_routed_experts,
            num_activated_experts = num_activated_experts,
            dtype=dtype, device=device, operations=operations
        )
        self.num_activated_experts = num_activated_experts

    def forward(self, x):
        wtype = x.dtype
        identity = x
        orig_shape = x.shape
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if True: # self.training: # TODO: check which branch performs faster
            x = x.repeat_interleave(self.num_activated_experts, dim=0)
            y = torch.empty_like(x, dtype=wtype)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape).to(dtype=wtype)
            #y = AddAuxiliaryLoss.apply(y, aux_loss)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.num_activated_experts
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i-1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])

            # for fp16 and other dtype
            expert_cache = expert_cache.to(expert_out.dtype)
            expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
        return expert_cache

class TextProjection(nn.Module):
    def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None):
        super().__init__()
        self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device)

    def forward(self, caption):
        hidden_states = self.linear(caption)
        return hidden_states


class BlockType:
    TransformerBlock = 1
    SingleTransformerBlock = 2

class HiDreamImageSingleTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device)
        )

        # 1. Attention
        self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor = HiDreamAttnProcessor_flashattn(),
            single = True,
            dtype=dtype, device=device, operations=operations
        )

        # 3. Feed-forward
        self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim = dim,
                hidden_dim = 4 * dim,
                num_routed_experts = num_routed_experts,
                num_activated_experts = num_activated_experts,
                dtype=dtype, device=device, operations=operations
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,

    ) -> torch.FloatTensor:
        wtype = image_tokens.dtype
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
            self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)

        # 1. MM-Attention
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        attn_output_i = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            rope = rope,
        )
        image_tokens = gate_msa_i * attn_output_i + image_tokens

        # 2. Feed-forward
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
        image_tokens = ff_output_i + image_tokens
        return image_tokens

class HiDreamImageTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device)
        )
        # nn.init.zeros_(self.adaLN_modulation[1].weight)
        # nn.init.zeros_(self.adaLN_modulation[1].bias)

        # 1. Attention
        self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor = HiDreamAttnProcessor_flashattn(),
            single = False,
            dtype=dtype, device=device, operations=operations
        )

        # 3. Feed-forward
        self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim = dim,
                hidden_dim = 4 * dim,
                num_routed_experts = num_routed_experts,
                num_activated_experts = num_activated_experts,
                dtype=dtype, device=device, operations=operations
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
        self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
        self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        wtype = image_tokens.dtype
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
        shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
            self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)

        # 1. MM-Attention
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t

        attn_output_i, attn_output_t = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            norm_text_tokens,
            rope = rope,
        )

        image_tokens = gate_msa_i * attn_output_i + image_tokens
        text_tokens = gate_msa_t * attn_output_t + text_tokens

        # 2. Feed-forward
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t

        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
        ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
        image_tokens = ff_output_i + image_tokens
        text_tokens = ff_output_t + text_tokens
        return image_tokens, text_tokens

class HiDreamImageBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        block_type: BlockType = BlockType.TransformerBlock,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        block_classes = {
            BlockType.TransformerBlock: HiDreamImageTransformerBlock,
            BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
        }
        self.block = block_classes[block_type](
            dim,
            num_attention_heads,
            attention_head_dim,
            num_routed_experts,
            num_activated_experts,
            dtype=dtype, device=device, operations=operations
        )

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: torch.FloatTensor = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        return self.block(
            image_tokens,
            image_tokens_masks,
            text_tokens,
            adaln_input,
            rope,
        )

class HiDreamImageTransformer2DModel(nn.Module):
    def __init__(
        self,
        patch_size: Optional[int] = None,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 16,
        num_single_layers: int = 32,
        attention_head_dim: int = 128,
        num_attention_heads: int = 20,
        caption_channels: List[int] = None,
        text_emb_dim: int = 2048,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        axes_dims_rope: Tuple[int, int] = (32, 32),
        max_resolution: Tuple[int, int] = (128, 128),
        llama_layers: List[int] = None,
        image_model=None,
        dtype=None, device=None, operations=None
    ):
        self.patch_size = patch_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.num_layers = num_layers
        self.num_single_layers = num_single_layers

        self.gradient_checkpointing = False

        super().__init__()
        self.dtype = dtype
        self.out_channels = out_channels or in_channels
        self.inner_dim = self.num_attention_heads * self.attention_head_dim
        self.llama_layers = llama_layers

        self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations)
        self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
        self.x_embedder = PatchEmbed(
            patch_size = patch_size,
            in_channels = in_channels,
            out_channels = self.inner_dim,
            dtype=dtype, device=device, operations=operations
        )
        self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)

        self.double_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim = self.inner_dim,
                    num_attention_heads = self.num_attention_heads,
                    attention_head_dim = self.attention_head_dim,
                    num_routed_experts = num_routed_experts,
                    num_activated_experts = num_activated_experts,
                    block_type = BlockType.TransformerBlock,
                    dtype=dtype, device=device, operations=operations
                )
                for i in range(self.num_layers)
            ]
        )

        self.single_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim = self.inner_dim,
                    num_attention_heads = self.num_attention_heads,
                    attention_head_dim = self.attention_head_dim,
                    num_routed_experts = num_routed_experts,
                    num_activated_experts = num_activated_experts,
                    block_type = BlockType.SingleTransformerBlock,
                    dtype=dtype, device=device, operations=operations
                )
                for i in range(self.num_single_layers)
            ]
        )

        self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)

        caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
        caption_projection = []
        for caption_channel in caption_channels:
            caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations))
        self.caption_projection = nn.ModuleList(caption_projection)
        self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)

    def expand_timesteps(self, timesteps, batch_size, device):
        if not torch.is_tensor(timesteps):
            is_mps = device.type == "mps"
            if isinstance(timesteps, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(batch_size)
        return timesteps

    def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]:
        x_arr = []
        for i, img_size in enumerate(img_sizes):
            pH, pW = img_size
            x_arr.append(
                einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
                    p1=self.patch_size, p2=self.patch_size)
            )
        x = torch.cat(x_arr, dim=0)
        return x

    def patchify(self, x, max_seq, img_sizes=None):
        pz2 = self.patch_size * self.patch_size
        if isinstance(x, torch.Tensor):
            B = x.shape[0]
            device = x.device
            dtype = x.dtype
        else:
            B = len(x)
            device = x[0].device
            dtype = x[0].dtype
        x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)

        if img_sizes is not None:
            for i, img_size in enumerate(img_sizes):
                x_masks[i, 0:img_size[0] * img_size[1]] = 1
            x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
        elif isinstance(x, torch.Tensor):
            pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
            x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size)
            img_sizes = [[pH, pW]] * B
            x_masks = None
        else:
            raise NotImplementedError
        return x, x_masks, img_sizes

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        encoder_hidden_states_llama3=None,
        control = None,
        transformer_options = {},
    ) -> torch.Tensor:
        bs, c, h, w = x.shape
        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        timesteps = t
        pooled_embeds = y
        T5_encoder_hidden_states = context

        img_sizes = None

        # spatial forward
        batch_size = hidden_states.shape[0]
        hidden_states_type = hidden_states.dtype

        # 0. time
        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
        timesteps = self.t_embedder(timesteps, hidden_states_type)
        p_embedder = self.p_embedder(pooled_embeds)
        adaln_input = timesteps + p_embedder

        hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
        if image_tokens_masks is None:
            pH, pW = img_sizes[0]
            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
            img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
        hidden_states = self.x_embedder(hidden_states)

        # T5_encoder_hidden_states = encoder_hidden_states[0]
        encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0)
        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]

        if self.caption_projection is not None:
            new_encoder_hidden_states = []
            for i, enc_hidden_state in enumerate(encoder_hidden_states):
                enc_hidden_state = self.caption_projection[i](enc_hidden_state)
                enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                new_encoder_hidden_states.append(enc_hidden_state)
            encoder_hidden_states = new_encoder_hidden_states
            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
            encoder_hidden_states.append(T5_encoder_hidden_states)

        txt_ids = torch.zeros(
            batch_size,
            encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
            3,
            device=img_ids.device, dtype=img_ids.dtype
        )
        ids = torch.cat((img_ids, txt_ids), dim=1)
        rope = self.pe_embedder(ids)

        # 2. Blocks
        block_id = 0
        initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
        initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
        for bid, block in enumerate(self.double_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            hidden_states, initial_encoder_hidden_states = block(
                image_tokens = hidden_states,
                image_tokens_masks = image_tokens_masks,
                text_tokens = cur_encoder_hidden_states,
                adaln_input = adaln_input,
                rope = rope,
            )
            initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
            block_id += 1

        image_tokens_seq_len = hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
        hidden_states_seq_len = hidden_states.shape[1]
        if image_tokens_masks is not None:
            encoder_attention_mask_ones = torch.ones(
                (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
                device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
            )
            image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)

        for bid, block in enumerate(self.single_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            hidden_states = block(
                image_tokens=hidden_states,
                image_tokens_masks=image_tokens_masks,
                text_tokens=None,
                adaln_input=adaln_input,
                rope=rope,
            )
            hidden_states = hidden_states[:, :hidden_states_seq_len]
            block_id += 1

        hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
        output = self.final_layer(hidden_states, adaln_input)
        output = self.unpatchify(output, img_sizes)
        return -output[:, :, :h, :w]

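As a quick sanity check on the routing used by MoEGate and MOEFeedForwardSwiGLU above, here is a hedged standalone sketch of the same softmax plus top-k dispatch on dummy tensors; the shapes are illustrative and not taken from a HiDream checkpoint:

    import torch
    import torch.nn.functional as F

    tokens = torch.randn(6, 32)                               # 6 tokens, hidden size 32
    gate_w = torch.randn(4, 32)                               # 4 routed experts
    scores = F.linear(tokens, gate_w).softmax(dim=-1)         # [6, 4] gating scores
    topk_weight, topk_idx = torch.topk(scores, k=2, dim=-1)   # each token picks its top-2 experts
    # MOEFeedForwardSwiGLU.forward then repeats each token once per selected expert,
    # runs the matching expert, and sums the expert outputs weighted by topk_weight.
    expanded = tokens.repeat_interleave(2, dim=0)             # [12, 32]
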
@@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module):

 class WanT2VCrossAttention(WanSelfAttention):

-    def forward(self, x, context):
+    def forward(self, x, context, **kwargs):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]

@@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention):
         # self.alpha = nn.Parameter(torch.zeros((1, )))
         self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()

-    def forward(self, x, context):
+    def forward(self, x, context, context_img_len):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
             context(Tensor): Shape [B, L2, C]
         """
-        context_img = context[:, :257]
-        context = context[:, 257:]
+        context_img = context[:, :context_img_len]
+        context = context[:, context_img_len:]

         # compute query, key, value
         q = self.norm_q(self.q(x))

@@ -193,6 +193,7 @@ class WanAttentionBlock(nn.Module):
         e,
         freqs,
         context,
+        context_img_len=257,
     ):
         r"""
         Args:

@@ -213,7 +214,7 @@ class WanAttentionBlock(nn.Module):
         x = x + y * e[2]

         # cross-attention & ffn
-        x = x + self.cross_attn(self.norm3(x), context)
+        x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
         y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
         x = x + y * e[5]
         return x

@@ -250,7 +251,7 @@ class Head(nn.Module):

 class MLPProj(torch.nn.Module):

-    def __init__(self, in_dim, out_dim, operation_settings={}):
+    def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}):
         super().__init__()

         self.proj = torch.nn.Sequential(

@@ -258,7 +259,15 @@ class MLPProj(torch.nn.Module):
             torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
             operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

+        if flf_pos_embed_token_number is not None:
+            self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
+        else:
+            self.emb_pos = None
+
     def forward(self, image_embeds):
+        if self.emb_pos is not None:
+            image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device)
+
         clip_extra_context_tokens = self.proj(image_embeds)
         return clip_extra_context_tokens

@ -284,6 +293,7 @@ class WanModel(torch.nn.Module):
|
|||||||
qk_norm=True,
|
qk_norm=True,
|
||||||
cross_attn_norm=True,
|
cross_attn_norm=True,
|
||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
|
flf_pos_embed_token_number=None,
|
||||||
image_model=None,
|
image_model=None,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
@ -373,7 +383,7 @@ class WanModel(torch.nn.Module):
|
|||||||
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
|
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
|
||||||
|
|
||||||
if model_type == 'i2v':
|
if model_type == 'i2v':
|
||||||
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
|
self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings)
|
||||||
else:
|
else:
|
||||||
self.img_emb = None
|
self.img_emb = None
|
||||||
|
|
||||||
@ -420,9 +430,12 @@ class WanModel(torch.nn.Module):
|
|||||||
# context
|
# context
|
||||||
context = self.text_embedding(context)
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
if clip_fea is not None and self.img_emb is not None:
|
context_img_len = None
|
||||||
|
if clip_fea is not None:
|
||||||
|
if self.img_emb is not None:
|
||||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||||
context = torch.concat([context_clip, context], dim=1)
|
context = torch.concat([context_clip, context], dim=1)
|
||||||
|
context_img_len = clip_fea.shape[-2]
|
||||||
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
@ -430,12 +443,12 @@ class WanModel(torch.nn.Module):
|
|||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||||
return out
|
return out
|
||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||||
x = out["img"]
|
x = out["img"]
|
||||||
else:
|
else:
|
||||||
x = block(x, e=e0, freqs=freqs, context=context)
|
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||||
|
|
||||||
# head
|
# head
|
||||||
x = self.head(x, e)
|
x = self.head(x, e)
|
||||||
|
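
The net effect of the wan/model.py hunks above is that the image-token prefix is no longer assumed to be exactly 257 tokens: MLPProj can prepend an optional first/last-frame positional embedding, and the cross-attention layers slice the context by a context_img_len taken from the CLIP feature tensor. A minimal illustrative sketch of that slicing (standalone Python; tensor sizes are made up and this is not the ComfyUI API):

import torch

def split_image_text_context(context: torch.Tensor, context_img_len: int):
    # context: [B, L, C]; the first context_img_len tokens come from the image
    # embedder, the remaining tokens from the text encoder.
    context_img = context[:, :context_img_len]
    context_txt = context[:, context_img_len:]
    return context_img, context_txt

clip_fea = torch.randn(1, 257, 1280)       # example CLIP vision features (bs x tokens x dim)
context = torch.randn(1, 257 + 512, 1280)  # image tokens followed by text tokens
img_part, txt_part = split_image_text_context(context, context_img_len=clip_fea.shape[-2])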

comfy/model_base.py

@@ -37,6 +37,7 @@ import comfy.ldm.cosmos.model
 import comfy.ldm.lumina.model
 import comfy.ldm.wan.model
 import comfy.ldm.hunyuan3d.model
+import comfy.ldm.hidream.model

 import comfy.model_management
 import comfy.patcher_extension

@@ -1056,3 +1057,20 @@ class Hunyuan3Dv2(BaseModel):
         if guidance is not None:
             out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
         return out
+
+class HiDream(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
+
+    def encode_adm(self, **kwargs):
+        return kwargs["pooled_output"]
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        conditioning_llama3 = kwargs.get("conditioning_llama3", None)
+        if conditioning_llama3 is not None:
+            out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
+        return out

comfy/model_detection.py

@@ -321,6 +321,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["model_type"] = "i2v"
         else:
             dit_config["model_type"] = "t2v"
+        flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
+        if flf_weight is not None:
+            dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
         return dit_config

     if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D

@@ -338,6 +341,25 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         return dit_config

+    if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
+        dit_config = {}
+        dit_config["image_model"] = "hidream"
+        dit_config["attention_head_dim"] = 128
+        dit_config["axes_dims_rope"] = [64, 32, 32]
+        dit_config["caption_channels"] = [4096, 4096]
+        dit_config["max_resolution"] = [128, 128]
+        dit_config["in_channels"] = 16
+        dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]
+        dit_config["num_attention_heads"] = 20
+        dit_config["num_routed_experts"] = 4
+        dit_config["num_activated_experts"] = 2
+        dit_config["num_layers"] = 16
+        dit_config["num_single_layers"] = 32
+        dit_config["out_channels"] = 16
+        dit_config["patch_size"] = 2
+        dit_config["text_emb_dim"] = 2048
+        return dit_config
+
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
         return None


comfy/ops.py (23 changed lines)

@@ -21,6 +21,7 @@ import logging
 import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
+import comfy.rmsnorm

 cast_to = comfy.model_management.cast_to #TODO: remove once no more references

@@ -146,6 +147,25 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)

+    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
+        def reset_parameters(self):
+            self.bias = None
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            if self.weight is not None:
+                weight, bias = cast_bias_weight(self, input)
+            else:
+                weight = None
+            return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
+            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
         def reset_parameters(self):
             return None

@@ -243,6 +263,9 @@ class manual_cast(disable_weight_init):
     class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
         comfy_cast_weights = True

+    class RMSNorm(disable_weight_init.RMSNorm):
+        comfy_cast_weights = True
+
     class Embedding(disable_weight_init.Embedding):
         comfy_cast_weights = True

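
The new ops-level RMSNorm follows the same pattern as ComfyUI's other cast-weight ops: disable_weight_init.RMSNorm skips initialization, and manual_cast.RMSNorm sets comfy_cast_weights = True so the stored weight is cast to the activation's dtype and device only at call time. A rough sketch of that cast-on-forward idea under assumed names (not the actual ComfyUI classes; it assumes a torch build that ships torch.nn.RMSNorm and torch.nn.functional.rms_norm):

import torch

class CastOnForwardRMSNorm(torch.nn.RMSNorm):
    # Keep the weight in its storage dtype/device and cast it to match the
    # incoming activation only when the layer is actually called.
    comfy_cast_weights = True

    def forward(self, x):
        w = None
        if self.weight is not None:
            w = self.weight.to(dtype=x.dtype, device=x.device)
        return torch.nn.functional.rms_norm(x, self.normalized_shape, w, self.eps)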

comfy/rmsnorm.py (new file, 55 lines)

@@ -0,0 +1,55 @@
import torch
import comfy.model_management
import numbers

RMSNorm = None

try:
    rms_norm_torch = torch.nn.functional.rms_norm
    RMSNorm = torch.nn.RMSNorm
except:
    rms_norm_torch = None


def rms_norm(x, weight=None, eps=1e-6):
    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
        if weight is None:
            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
        else:
            return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
    else:
        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
        if weight is None:
            return r
        else:
            return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)


if RMSNorm is None:
    class RMSNorm(torch.nn.Module):
        def __init__(
            self,
            normalized_shape,
            eps=None,
            elementwise_affine=True,
            device=None,
            dtype=None,
        ):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            if isinstance(normalized_shape, numbers.Integral):
                # mypy error: incompatible types in assignment
                normalized_shape = (normalized_shape,)  # type: ignore[assignment]
            self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
            self.eps = eps
            self.elementwise_affine = elementwise_affine
            if self.elementwise_affine:
                self.weight = torch.nn.Parameter(
                    torch.empty(self.normalized_shape, **factory_kwargs)
                )
            else:
                self.register_parameter("weight", None)
            self.bias = None

        def forward(self, x):
            return rms_norm(x, self.weight, self.eps)
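
For reference, the pure-PyTorch fallback branch in rms_norm computes the same normalization that torch.nn.functional.rms_norm performs: the input is scaled by the reciprocal root mean square of its last dimension and then, if present, multiplied by the weight. A small self-contained sanity check of that equivalence (assumes a torch version that already provides torch.nn.functional.rms_norm):

import torch

def rms_norm_fallback(x, weight=None, eps=1e-6):
    # Same math as the else-branch above, without the dtype/device casting.
    r = x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    return r if weight is None else r * weight

x = torch.randn(2, 8, 64)
w = torch.randn(64)
reference = torch.nn.functional.rms_norm(x, (64,), w, 1e-6)
assert torch.allclose(rms_norm_fallback(x, w), reference, atol=1e-5)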

comfy/sd.py (38 changed lines)

@@ -41,6 +41,7 @@ import comfy.text_encoders.hunyuan_video
 import comfy.text_encoders.cosmos
 import comfy.text_encoders.lumina2
 import comfy.text_encoders.wan
+import comfy.text_encoders.hidream

 import comfy.model_patcher
 import comfy.lora

@@ -702,6 +703,7 @@ class CLIPType(Enum):
     COSMOS = 11
     LUMINA2 = 12
     WAN = 13
+    HIDREAM = 14


 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):

@@ -790,6 +792,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             elif clip_type == CLIPType.SD3:
                 clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
                 clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
+            elif clip_type == CLIPType.HIDREAM:
+                clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+                clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
             else:
                 clip_target.clip = sdxl_clip.SDXLRefinerClipModel
                 clip_target.tokenizer = sdxl_clip.SDXLTokenizer

@@ -810,6 +815,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
                 clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
                 tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+            elif clip_type == CLIPType.HIDREAM:
+                clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
+                                                                            clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
+                clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
             else: #CLIPType.MOCHI
                 clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer

@@ -826,10 +835,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+        elif te_model == TEModel.LLAMA3_8:
+            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
+                                                                        clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
+            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else:
+            # clip_l
             if clip_type == CLIPType.SD3:
                 clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
                 clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
+            elif clip_type == CLIPType.HIDREAM:
+                clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+                clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
             else:
                 clip_target.clip = sd1_clip.SD1ClipModel
                 clip_target.tokenizer = sd1_clip.SD1Tokenizer

@@ -847,12 +864,33 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         elif clip_type == CLIPType.HUNYUAN_VIDEO:
             clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
+        elif clip_type == CLIPType.HIDREAM:
+            # Detect
+            hidream_dualclip_classes = []
+            for hidream_te in clip_data:
+                te_model = detect_te_model(hidream_te)
+                hidream_dualclip_classes.append(te_model)
+
+            clip_l = TEModel.CLIP_L in hidream_dualclip_classes
+            clip_g = TEModel.CLIP_G in hidream_dualclip_classes
+            t5 = TEModel.T5_XXL in hidream_dualclip_classes
+            llama = TEModel.LLAMA3_8 in hidream_dualclip_classes
+
+            # Initialize t5xxl_detect and llama_detect kwargs if needed
+            t5_kwargs = t5xxl_detect(clip_data) if t5 else {}
+            llama_kwargs = llama_detect(clip_data) if llama else {}
+
+            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
+            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else:
             clip_target.clip = sdxl_clip.SDXLClipModel
             clip_target.tokenizer = sdxl_clip.SDXLTokenizer
     elif len(clip_data) == 3:
         clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
         clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
+    elif len(clip_data) == 4:
+        clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), **llama_detect(clip_data))
+        clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer

     parameters = 0
     for c in clip_data:
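
All of the HIDREAM branches above funnel into the same factory call: hidream_clip is invoked with boolean flags for whichever sub-encoders were actually supplied, plus the dtype/scaled-fp8 kwargs returned by the matching detect helpers, so one-, two- and four-checkpoint setups share a single code path. A hedged sketch of what the detection-driven call looks like for a T5-XXL plus LLaMA pair (the empty kwargs dicts stand in for whatever t5xxl_detect/llama_detect would return; ComfyUI must be importable):

import comfy.text_encoders.hidream

# Suppose detection found a T5-XXL file and a LLaMA 3 8B file, but no CLIP-L/G.
clip_l, clip_g, t5, llama = False, False, True, True
t5_kwargs = {}     # placeholder for t5xxl_detect(clip_data)
llama_kwargs = {}  # placeholder for llama_detect(clip_data)

te_class = comfy.text_encoders.hidream.hidream_clip(
    clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
tokenizer_class = comfy.text_encoders.hidream.HiDreamTokenizer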

comfy/sd1_clip.py

@@ -82,7 +82,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     LAYERS = [
         "last",
         "pooled",
-        "hidden"
+        "hidden",
+        "all"
     ]
     def __init__(self, device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,

@@ -93,6 +94,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):

         if textmodel_json_config is None:
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
+            if "model_name" not in model_options:
+                model_options = {**model_options, "model_name": "clip_l"}

         if isinstance(textmodel_json_config, dict):
             config = textmodel_json_config

@@ -100,6 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             with open(textmodel_json_config) as f:
                 config = json.load(f)

+        te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {})
+        for k, v in te_model_options.items():
+            config[k] = v
+
         operations = model_options.get("custom_operations", None)
         scaled_fp8 = None

@@ -147,7 +154,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def set_clip_options(self, options):
         layer_idx = options.get("layer", self.layer_idx)
         self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
-        if layer_idx is None or abs(layer_idx) > self.num_layers:
+        if self.layer == "all":
+            pass
+        elif layer_idx is None or abs(layer_idx) > self.num_layers:
             self.layer = "last"
         else:
             self.layer = "hidden"

@@ -244,7 +253,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if self.enable_attention_masks:
             attention_mask_model = attention_mask

-        outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
+        if self.layer == "all":
+            intermediate_output = "all"
+        else:
+            intermediate_output = self.layer_idx
+
+        outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)

         if self.layer == "last":
             z = outputs[0].float()

@@ -447,7 +461,7 @@ class SDTokenizer:
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
-        self.max_length = max_length
+        self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
         self.min_length = min_length
         self.end_token = None

@@ -645,6 +659,7 @@ class SD1ClipModel(torch.nn.Module):
         self.clip = "clip_{}".format(self.clip_name)

         clip_model = model_options.get("{}_class".format(self.clip), clip_model)
+        model_options = {**model_options, "model_name": self.clip}
         setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))

         self.dtypes = set()
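
The layer == "all" additions above give SDClipModel a mode where the wrapped transformer is asked for every layer's hidden states instead of a single intermediate layer; the LLaMA path in comfy/text_encoders/llama.py (last hunk in this listing) implements the collection side. A generic sketch of that collect-every-layer convention (illustrative only, not the actual model loop):

import torch

def run_layers(x, layers, intermediate_output=None):
    # intermediate_output: None, a layer index, or the string "all".
    all_intermediate = [] if intermediate_output == "all" else None
    intermediate = None
    for i, layer in enumerate(layers):
        x = layer(x)
        if all_intermediate is not None:
            all_intermediate.append(x)
        elif intermediate_output is not None and i == intermediate_output:
            intermediate = x
    if all_intermediate is not None:
        intermediate = torch.stack(all_intermediate, dim=1)  # one entry per layer
    return x, intermediate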

comfy/sdxl_clip.py

@@ -9,6 +9,7 @@ class SDXLClipG(sd1_clip.SDClipModel):
             layer_idx=-2

         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
+        model_options = {**model_options, "model_name": "clip_g"}
         super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
                          special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)

@@ -17,14 +18,13 @@ class SDXLClipG(sd1_clip.SDClipModel):
 
 class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)


 class SDXLTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
-        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
-        self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
+        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}

@@ -41,8 +41,7 @@ class SDXLTokenizer:
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None, model_options={}):
         super().__init__()
-        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
-        self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
+        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
         self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
         self.dtypes = set([dtype])

@@ -75,7 +74,7 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
 
 class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
-        super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
+        super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)

 class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):

@@ -84,6 +83,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
 class StableCascadeClipG(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
+        model_options = {**model_options, "model_name": "clip_g"}
         super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
                          special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)


comfy/supported_models.py

@@ -1025,6 +1025,36 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):

     latent_format = latent_formats.Hunyuan3Dv2mini

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
+class HiDream(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "hidream",
+    }
+
+    sampling_settings = {
+        "shift": 3.0,
+    }
+
+    sampling_settings = {
+    }
+
+    # memory_usage_factor = 1.2 # TODO
+
+    unet_extra_config = {}
+    latent_format = latent_formats.Flux
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.HiDream(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        return None # TODO
+
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]

 models += [SVD_img2vid]

comfy/text_encoders/aura_t5.py

@@ -11,7 +11,7 @@ class PT5XlModel(sd1_clip.SDClipModel):
 class PT5XlTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data)

 class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):

comfy/text_encoders/cosmos.py

@@ -22,7 +22,7 @@ class CosmosT5XXL(sd1_clip.SD1ClipModel):
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, tokenizer_data=tokenizer_data)


 class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):

comfy/text_encoders/flux.py

@@ -9,14 +9,13 @@ import os
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)


 class FluxTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
-        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
-        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
+        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}

@@ -35,8 +34,7 @@ class FluxClipModel(torch.nn.Module):
     def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
         super().__init__()
         dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
-        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
-        self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
         self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
         self.dtypes = set([dtype, dtype_t5])


comfy/text_encoders/genmo.py

@@ -18,7 +18,7 @@ class MochiT5XXL(sd1_clip.SD1ClipModel):
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)


 class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):

comfy/text_encoders/hidream.py (new file, 155 lines)

@@ -0,0 +1,155 @@
from . import hunyuan_video
from . import sd3_clip
from comfy import sd1_clip
from comfy import sdxl_clip
import comfy.model_management
import torch
import logging


class HiDreamTokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data)
        self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
        t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
        out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
        return out

    def untokenize(self, token_weight_pair):
        return self.clip_g.untokenize(token_weight_pair)

    def state_dict(self):
        return {}


class HiDreamTEModel(torch.nn.Module):
    def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}):
        super().__init__()
        self.dtypes = set()
        if clip_l:
            self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options)
            self.dtypes.add(dtype)
        else:
            self.clip_l = None

        if clip_g:
            self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
            self.dtypes.add(dtype)
        else:
            self.clip_g = None

        if t5:
            dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
            self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True)
            self.dtypes.add(dtype_t5)
        else:
            self.t5xxl = None

        if llama:
            dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
            if "vocab_size" not in model_options:
                model_options["vocab_size"] = 128256
            self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009})
            self.dtypes.add(dtype_llama)
        else:
            self.llama = None

        logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))

    def set_clip_options(self, options):
        if self.clip_l is not None:
            self.clip_l.set_clip_options(options)
        if self.clip_g is not None:
            self.clip_g.set_clip_options(options)
        if self.t5xxl is not None:
            self.t5xxl.set_clip_options(options)
        if self.llama is not None:
            self.llama.set_clip_options(options)

    def reset_clip_options(self):
        if self.clip_l is not None:
            self.clip_l.reset_clip_options()
        if self.clip_g is not None:
            self.clip_g.reset_clip_options()
        if self.t5xxl is not None:
            self.t5xxl.reset_clip_options()
        if self.llama is not None:
            self.llama.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_l = token_weight_pairs["l"]
        token_weight_pairs_g = token_weight_pairs["g"]
        token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
        token_weight_pairs_llama = token_weight_pairs["llama"]
        lg_out = None
        pooled = None
        extra = {}

        if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
            if self.clip_l is not None:
                lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
            else:
                l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())

            if self.clip_g is not None:
                g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
            else:
                g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())

            pooled = torch.cat((l_pooled, g_pooled), dim=-1)

        if self.t5xxl is not None:
            t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
            t5_out, t5_pooled = t5_output[:2]
        else:
            t5_out = None

        if self.llama is not None:
            ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
            ll_out, ll_pooled = ll_output[:2]
            ll_out = ll_out[:, 1:]
        else:
            ll_out = None

        if t5_out is None:
            t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())

        if ll_out is None:
            ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())

        if pooled is None:
            pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())

        extra["conditioning_llama3"] = ll_out
        return t5_out, pooled, extra

    def load_sd(self, sd):
        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
            return self.clip_g.load_sd(sd)
        elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            return self.clip_l.load_sd(sd)
        elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
            return self.t5xxl.load_sd(sd)
        else:
            return self.llama.load_sd(sd)


def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
    class HiDreamTEModel_(HiDreamTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
            if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["llama_scaled_fp8"] = llama_scaled_fp8
            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
    return HiDreamTEModel_

comfy/text_encoders/hunyuan_video.py

@@ -21,26 +21,31 @@ def llama_detect(state_dict, prefix=""):


 class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
+    def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=pad_token, min_length=min_length, tokenizer_data=tokenizer_data)

 class LLAMAModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
+    def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
         llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
         if llama_scaled_fp8 is not None:
             model_options = model_options.copy()
             model_options["scaled_fp8"] = llama_scaled_fp8

-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+        textmodel_json_config = {}
+        vocab_size = model_options.get("vocab_size", None)
+        if vocab_size is not None:
+            textmodel_json_config["vocab_size"] = vocab_size
+
+        model_options = {**model_options, "model_name": "llama"}
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


 class HunyuanVideoTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
-        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
+        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
-        self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
+        self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data)

     def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
         out = {}

@@ -72,8 +77,7 @@ class HunyuanVideoClipModel(torch.nn.Module):
     def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
         super().__init__()
         dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
-        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
-        self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
         self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
         self.dtypes = set([dtype, dtype_llama])

@@ -9,24 +9,26 @@ import torch
 class HyditBertModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
+        model_options = {**model_options, "model_name": "hydit_clip"}
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)

 class HyditBertTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data)


 class MT5XLModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
+        model_options = {**model_options, "model_name": "mt5xl"}
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)

 class MT5XLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         #tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)

     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
@@ -35,7 +37,7 @@ class HyditTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
         self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
-        self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
+        self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
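The "model_name" additions above build a fresh dict with {**model_options, ...} rather than writing into the dict the caller passed in. A minimal sketch of that copy-on-write pattern (plain Python, not ComfyUI code):

    caller_options = {"scaled_fp8": None}
    model_options = {**caller_options, "model_name": "mt5xl"}

    print(model_options)   # {'scaled_fp8': None, 'model_name': 'mt5xl'}
    print(caller_options)  # {'scaled_fp8': None}  (caller's dict left untouched)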
@@ -268,11 +268,17 @@ class Llama2_(nn.Module):
         optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

         intermediate = None
+        all_intermediate = None
         if intermediate_output is not None:
-            if intermediate_output < 0:
+            if intermediate_output == "all":
+                all_intermediate = []
+                intermediate_output = None
+            elif intermediate_output < 0:
                 intermediate_output = len(self.layers) + intermediate_output

         for i, layer in enumerate(self.layers):
+            if all_intermediate is not None:
+                all_intermediate.append(x.unsqueeze(1).clone())
             x = layer(
                 x=x,
                 attention_mask=mask,
@@ -283,6 +289,12 @@ class Llama2_(nn.Module):
                 intermediate = x.clone()

         x = self.norm(x)
+        if all_intermediate is not None:
+            all_intermediate.append(x.unsqueeze(1).clone())
+
+        if all_intermediate is not None:
+            intermediate = torch.cat(all_intermediate, dim=1)

         if intermediate is not None and final_layer_norm_intermediate:
             intermediate = self.norm(intermediate)

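The new "all" mode for intermediate_output gathers the hidden state ahead of every transformer layer plus the final normed output and concatenates them along a new dim 1. A self-contained sketch of that collection pattern with toy layers (not the ComfyUI Llama2_ class itself):

    import torch
    import torch.nn as nn

    layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
    norm = nn.LayerNorm(8)

    x = torch.randn(2, 5, 8)            # (batch, tokens, hidden)
    all_intermediate = []
    for layer in layers:
        all_intermediate.append(x.unsqueeze(1).clone())   # state *before* each layer
        x = layer(x)
    x = norm(x)
    all_intermediate.append(x.unsqueeze(1).clone())        # final normed state

    intermediate = torch.cat(all_intermediate, dim=1)
    print(intermediate.shape)            # torch.Size([2, 4, 5, 8]) -> num_layers + 1 stacked states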
@@ -1,30 +1,27 @@
-from comfy import sd1_clip
-import os
-
-class LongClipTokenizer_(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
-        super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
-
-class LongClipModel_(sd1_clip.SDClipModel):
-    def __init__(self, *args, **kwargs):
-        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
-        super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
-
-class LongClipTokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
-        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
-
-class LongClipModel(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
-        super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
-
 def model_options_long_clip(sd, tokenizer_data, model_options):
     w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
+    if w is None:
+        w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None)
+    else:
+        model_name = "clip_g"
+
     if w is None:
         w = sd.get("text_model.embeddings.position_embedding.weight", None)
-    if w is not None and w.shape[0] == 248:
+        if w is not None:
+            if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
+                model_name = "clip_g"
+            elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
+                model_name = "clip_l"
+    else:
+        model_name = "clip_l"
+
+    if w is not None:
         tokenizer_data = tokenizer_data.copy()
         model_options = model_options.copy()
-        tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
-        model_options["clip_l_class"] = LongClipModel_
+        model_config = model_options.get("model_config", {})
+        model_config["max_position_embeddings"] = w.shape[0]
+        model_options["{}_model_config".format(model_name)] = model_config
+        tokenizer_data["{}_max_length".format(model_name)] = w.shape[0]
     return tokenizer_data, model_options
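The rewritten model_options_long_clip() no longer swaps in dedicated LongClip classes; it only reads the position-embedding length out of the state dict and forwards it as a max_position_embeddings model config plus a *_max_length tokenizer hint. A hypothetical illustration of that detection idea on a toy state dict (plain torch, not the ComfyUI function itself):

    import torch

    sd = {
        "text_model.embeddings.position_embedding.weight": torch.zeros(248, 768),
        "text_model.encoder.layers.1.mlp.fc1.weight": torch.zeros(3072, 768),
    }

    w = sd.get("text_model.embeddings.position_embedding.weight", None)
    model_name = "clip_g" if "text_model.encoder.layers.30.mlp.fc1.weight" in sd else "clip_l"

    tokenizer_data = {}
    model_options = {}
    if w is not None:
        model_config = {"max_position_embeddings": w.shape[0]}
        model_options["{}_model_config".format(model_name)] = model_config
        tokenizer_data["{}_max_length".format(model_name)] = w.shape[0]

    print(tokenizer_data)   # {'clip_l_max_length': 248}
    print(model_options)    # {'clip_l_model_config': {'max_position_embeddings': 248}}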
@@ -6,7 +6,7 @@ import comfy.text_encoders.genmo
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128?


 class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
@@ -6,7 +6,7 @@ import comfy.text_encoders.llama
 class Gemma2BTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False})
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)

     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
@@ -24,7 +24,7 @@ class PixArtT5XXL(sd1_clip.SD1ClipModel):
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding

 class PixArtTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -11,7 +11,7 @@ class T5BaseModel(sd1_clip.SDClipModel):
 class T5BaseTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data)

 class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -12,7 +12,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):

 class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='clip_h', tokenizer_data=tokenizer_data)

 class SD2Tokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -15,6 +15,7 @@ class T5XXLModel(sd1_clip.SDClipModel):
             model_options = model_options.copy()
             model_options["scaled_fp8"] = t5xxl_scaled_fp8

+        model_options = {**model_options, "model_name": "t5xxl"}
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

@@ -31,17 +32,16 @@ def t5_xxl_detect(state_dict, prefix=""):
     return out

 class T5XXLTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77, max_length=99999999):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=max_length, min_length=min_length, tokenizer_data=tokenizer_data)


 class SD3Tokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
-        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
-        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
-        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
+        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
@@ -61,8 +61,7 @@ class SD3ClipModel(torch.nn.Module):
         super().__init__()
         self.dtypes = set()
         if clip_l:
-            clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
-            self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
+            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
             self.dtypes.add(dtype)
         else:
             self.clip_l = None
@@ -1,4 +1,5 @@
 import torch
+import os

 class SPieceTokenizer:
     @staticmethod
@@ -15,6 +16,8 @@ class SPieceTokenizer:
         if isinstance(tokenizer_path, bytes):
             self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
         else:
+            if not os.path.isfile(tokenizer_path):
+                raise ValueError("invalid tokenizer")
             self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)

     def get_vocab(self):
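The added guard gives SPieceTokenizer a clearer failure mode: raw bytes are treated as a serialized model proto, anything else must be an existing file. A small sketch of the same branch, assuming sentencepiece is installed and "spiece.model" is a stand-in path to a real SentencePiece model:

    import os
    import sentencepiece

    def load_spiece(tokenizer_path, add_bos=False, add_eos=True):
        if isinstance(tokenizer_path, bytes):
            # serialized model proto, e.g. what a tokenizer state_dict round-trips
            return sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=add_bos, add_eos=add_eos)
        if not os.path.isfile(tokenizer_path):
            raise ValueError("invalid tokenizer")
        return sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=add_bos, add_eos=add_eos)

    sp = load_spiece("spiece.model")      # stand-in path
    print(sp.encode("hello world"))       # list of token ids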
@@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel):
 class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data)

     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
@@ -1,6 +1,9 @@
-import nodes
+from __future__ import annotations
+from typing import Type, Literal
+
+import nodes
 from comfy_execution.graph_utils import is_link
+from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions

 class DependencyCycleError(Exception):
     pass
@@ -54,7 +57,22 @@ class DynamicPrompt:
     def get_original_prompt(self):
         return self.original_prompt

-def get_input_info(class_def, input_name, valid_inputs=None):
+def get_input_info(
+    class_def: Type[ComfyNodeABC],
+    input_name: str,
+    valid_inputs: InputTypeDict | None = None
+) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
+    """Get the input type, category, and extra info for a given input name.
+
+    Arguments:
+        class_def: The class definition of the node.
+        input_name: The name of the input to get info for.
+        valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES().
+
+    Returns:
+        tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name.
+    """
+
     valid_inputs = valid_inputs or class_def.INPUT_TYPES()
     input_info = None
     input_category = None
@@ -126,7 +144,7 @@ class TopologicalSort:
             from_node_id, from_socket = value
             if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
                 continue
-            input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
+            _, _, input_info = self.get_input_info(unique_id, input_name)
             is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
             if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
                 node_ids.append(from_node_id)
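With the new signature, get_input_info() documents a (type, category, options) contract that callers can partially discard, which is what the TopologicalSort change does. A usage sketch, assuming it runs inside a ComfyUI checkout with its dependencies installed (the node and input names below are just examples):

    from comfy_execution.graph import get_input_info
    import nodes

    class_def = nodes.NODE_CLASS_MAPPINGS["KSampler"]            # example node
    input_type, input_category, extra_info = get_input_info(class_def, "steps")
    print(input_type, input_category)                             # e.g. INT required
    print(extra_info)                                             # e.g. {'default': 20, 'min': 1, ...}

    # Callers that only need the options can discard the rest, as TopologicalSort now does:
    _, _, info = get_input_info(class_def, "steps")
    is_lazy = info is not None and "lazy" in info and info["lazy"]
    print(is_lazy)                                                # False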
comfy_extras/nodes_fresca.py (new file, 100 lines)
@@ -0,0 +1,100 @@
+# Code based on https://github.com/WikiChao/FreSca (MIT License)
+import torch
+import torch.fft as fft
+
+
+def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
+    """
+    Apply frequency-dependent scaling to an image tensor using Fourier transforms.
+
+    Parameters:
+        x: Input tensor of shape (B, C, H, W)
+        scale_low: Scaling factor for low-frequency components (default: 1.0)
+        scale_high: Scaling factor for high-frequency components (default: 1.5)
+        freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20)
+
+    Returns:
+        x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied.
+    """
+    # Preserve input dtype and device
+    dtype, device = x.dtype, x.device
+
+    # Convert to float32 for FFT computations
+    x = x.to(torch.float32)
+
+    # 1) Apply FFT and shift low frequencies to center
+    x_freq = fft.fftn(x, dim=(-2, -1))
+    x_freq = fft.fftshift(x_freq, dim=(-2, -1))
+
+    # Initialize mask with high-frequency scaling factor
+    mask = torch.ones(x_freq.shape, device=device) * scale_high
+    m = mask
+    for d in range(len(x_freq.shape) - 2):
+        dim = d + 2
+        cc = x_freq.shape[dim] // 2
+        f_c = min(freq_cutoff, cc)
+        m = m.narrow(dim, cc - f_c, f_c * 2)
+
+    # Apply low-frequency scaling factor to center region
+    m[:] = scale_low
+
+    # 3) Apply frequency-specific scaling
+    x_freq = x_freq * mask
+
+    # 4) Convert back to spatial domain
+    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
+    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
+
+    # 5) Restore original dtype
+    x_filtered = x_filtered.to(dtype)
+
+    return x_filtered
+
+
+class FreSca:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01,
+                                        "tooltip": "Scaling factor for low-frequency components"}),
+                "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01,
+                                         "tooltip": "Scaling factor for high-frequency components"}),
+                "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1,
+                                        "tooltip": "Number of frequency indices around center to consider as low-frequency"}),
+            }
+        }
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+    CATEGORY = "_for_testing"
+    DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
+    def patch(self, model, scale_low, scale_high, freq_cutoff):
+        def custom_cfg_function(args):
+            cond = args["conds_out"][0]
+            uncond = args["conds_out"][1]
+
+            guidance = cond - uncond
+            filtered_guidance = Fourier_filter(
+                guidance,
+                scale_low=scale_low,
+                scale_high=scale_high,
+                freq_cutoff=freq_cutoff,
+            )
+            filtered_cond = filtered_guidance + uncond
+
+            return [filtered_cond, uncond]
+
+        m = model.clone()
+        m.set_model_sampler_pre_cfg_function(custom_cfg_function)
+
+        return (m,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "FreSca": FreSca,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "FreSca": "FreSca",
+}
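A quick way to sanity-check the new Fourier_filter() outside the node graph is to run it on a random latent-shaped tensor; with both scales at 1.0 it should round-trip back to the input. A sketch, assuming the Fourier_filter definition above is in scope (imported or pasted into a REPL):

    import torch

    guidance = torch.randn(1, 4, 64, 64)             # (B, C, H, W) latent-like tensor
    out = Fourier_filter(guidance, scale_low=1.0, scale_high=1.25, freq_cutoff=20)
    print(out.shape, out.dtype)                       # torch.Size([1, 4, 64, 64]) torch.float32

    # With both scales equal to 1.0 the mask is all ones, so the FFT round-trip is an identity:
    same = Fourier_filter(guidance, scale_low=1.0, scale_high=1.0)
    print(torch.allclose(same, guidance, atol=1e-5))  # True

The FreSca node applies exactly this to cond - uncond before the CFG mix, so scale_high above 1.0 boosts only the high-frequency content of the guidance.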
comfy_extras/nodes_hidream.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+import folder_paths
+import comfy.sd
+import comfy.model_management
+
+
+class QuadrupleCLIPLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
+                              "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
+                              "clip_name3": (folder_paths.get_filename_list("text_encoders"), ),
+                              "clip_name4": (folder_paths.get_filename_list("text_encoders"), )
+                              }}
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "load_clip"
+
+    CATEGORY = "advanced/loaders"
+
+    DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
+
+    def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
+        clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
+        clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
+        clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
+        clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        return (clip,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "QuadrupleCLIPLoader": QuadrupleCLIPLoader,
+}
@@ -21,8 +21,8 @@ class Load3D():
                 "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
             }}

-    RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
-    RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
+    RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
+    RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")

     FUNCTION = "process"
     EXPERIMENTAL = True
@@ -41,7 +41,7 @@ class Load3D():
         normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
         lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)

-        return output_image, output_mask, model_file, normal_image, lineart_image
+        return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']

 class Load3DAnimation():
     @classmethod
@@ -59,8 +59,8 @@ class Load3DAnimation():
                 "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
             }}

-    RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
-    RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
+    RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
+    RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")

     FUNCTION = "process"
     EXPERIMENTAL = True
@@ -77,13 +77,16 @@ class Load3DAnimation():
         ignore_image, output_mask = load_image_node.load_image(image=mask_path)
         normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)

-        return output_image, output_mask, model_file, normal_image
+        return output_image, output_mask, model_file, normal_image, image['camera_info']

 class Preview3D():
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
             "model_file": ("STRING", {"default": "", "multiline": False}),
+        },
+        "optional": {
+            "camera_info": ("LOAD3D_CAMERA", {})
         }}

     OUTPUT_NODE = True
@@ -95,13 +98,22 @@ class Preview3D():
     EXPERIMENTAL = True

     def process(self, model_file, **kwargs):
-        return {"ui": {"model_file": [model_file]}, "result": ()}
+        camera_info = kwargs.get("camera_info", None)
+
+        return {
+            "ui": {
+                "result": [model_file, camera_info]
+            }
+        }

 class Preview3DAnimation():
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
             "model_file": ("STRING", {"default": "", "multiline": False}),
+        },
+        "optional": {
+            "camera_info": ("LOAD3D_CAMERA", {})
         }}

     OUTPUT_NODE = True
@@ -113,7 +125,13 @@ class Preview3DAnimation():
     EXPERIMENTAL = True

     def process(self, model_file, **kwargs):
-        return {"ui": {"model_file": [model_file]}, "result": ()}
+        camera_info = kwargs.get("camera_info", None)
+
+        return {
+            "ui": {
+                "result": [model_file, camera_info]
+            }
+        }

 NODE_CLASS_MAPPINGS = {
     "Load3D": Load3D,
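After this change the 3D preview nodes push the mesh path and the optional camera through a single "result" entry instead of a "model_file" key. Sketch of the payload shape the frontend now receives (plain dict; the camera_info fields shown are hypothetical, the real object comes from the Load3D camera output):

    camera_info = {"position": [0.0, 1.0, 3.0], "target": [0.0, 0.0, 0.0]}   # hypothetical fields
    payload = {
        "ui": {
            "result": ["3d/example_mesh.glb", camera_info]
        }
    }
    print(payload["ui"]["result"][0])   # 3d/example_mesh.glb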
@@ -4,6 +4,7 @@ import torch
 import comfy.model_management
 import comfy.utils
 import comfy.latent_formats
+import comfy.clip_vision


 class WanImageToVideo:
@@ -99,6 +100,72 @@ class WanFunControlToVideo:
         out_latent["samples"] = latent
         return (positive, negative, out_latent)

+class WanFirstLastFrameToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                },
+                "optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
+                             "clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
+                             "start_image": ("IMAGE", ),
+                             "end_image": ("IMAGE", ),
+                }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+        if end_image is not None:
+            end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+
+        image = torch.ones((length, height, width, 3)) * 0.5
+        mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
+
+        if start_image is not None:
+            image[:start_image.shape[0]] = start_image
+            mask[:, :, :start_image.shape[0] + 3] = 0.0
+
+        if end_image is not None:
+            image[-end_image.shape[0]:] = end_image
+            mask[:, :, -end_image.shape[0]:] = 0.0
+
+        concat_latent_image = vae.encode(image[:, :, :, :3])
+        mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
+        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+
+        if clip_vision_start_image is not None:
+            clip_vision_output = clip_vision_start_image
+
+        if clip_vision_end_image is not None:
+            if clip_vision_output is not None:
+                states = torch.cat([clip_vision_output.penultimate_hidden_states, clip_vision_end_image.penultimate_hidden_states], dim=-2)
+                clip_vision_output = comfy.clip_vision.Output()
+                clip_vision_output.penultimate_hidden_states = states
+            else:
+                clip_vision_output = clip_vision_end_image
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
+
 class WanFunInpaintToVideo:
     @classmethod
     def INPUT_TYPES(s):
@@ -122,38 +189,13 @@ class WanFunInpaintToVideo:
     CATEGORY = "conditioning/video_models"

     def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
-        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
-        if start_image is not None:
-            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-        if end_image is not None:
-            end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-
-        image = torch.ones((length, height, width, 3)) * 0.5
-        mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
-
-        if start_image is not None:
-            image[:start_image.shape[0]] = start_image
-            mask[:, :, :start_image.shape[0] + 3] = 0.0
-
-        if end_image is not None:
-            image[-end_image.shape[0]:] = end_image
-            mask[:, :, -end_image.shape[0]:] = 0.0
-
-        concat_latent_image = vae.encode(image[:, :, :, :3])
-        mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
-        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
-        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
-
-        if clip_vision_output is not None:
-            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
-            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
-
-        out_latent = {}
-        out_latent["samples"] = latent
-        return (positive, negative, out_latent)
+        flfv = WanFirstLastFrameToVideo()
+        return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)

 NODE_CLASS_MAPPINGS = {
     "WanImageToVideo": WanImageToVideo,
     "WanFunControlToVideo": WanFunControlToVideo,
     "WanFunInpaintToVideo": WanFunInpaintToVideo,
+    "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
 }
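WanFirstLastFrameToVideo leans on a fixed four-pixel-frames-per-latent-frame layout: 81 frames become 21 latent frames, and the inpaint mask is regrouped to match. A shapes-only sketch of that bookkeeping (plain torch, no VAE or ComfyUI imports):

    import torch

    length, height, width = 81, 480, 832
    latent_t = ((length - 1) // 4) + 1                   # 21 latent frames for 81 pixel frames
    mask = torch.ones((1, 1, latent_t * 4, height // 8, width // 8))

    start_frames, end_frames = 1, 1
    mask[:, :, :start_frames + 3] = 0.0                  # start frame plus 3 padding slots
    mask[:, :, -end_frames:] = 0.0                       # last frame

    # Regroup 4 pixel-frame slots per latent frame, matching the node's reshape:
    mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
    print(mask.shape)                                    # torch.Size([1, 4, 21, 60, 104])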
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.28"
+__version__ = "0.3.29"
execution.py (31 changed lines)
@@ -111,7 +111,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
     missing_keys = {}
     for x in inputs:
         input_data = inputs[x]
-        input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
+        _, input_category, input_info = get_input_info(class_def, x, valid_inputs)
         def mark_missing():
             missing_keys[x] = True
             input_data_all[x] = (None,)
@@ -574,7 +574,7 @@ def validate_inputs(prompt, item, validated):
     received_types = {}

     for x in valid_inputs:
-        type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
+        input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
         assert extra_info is not None
         if x not in inputs:
             if input_category == "required":
@@ -590,7 +590,7 @@ def validate_inputs(prompt, item, validated):
             continue

         val = inputs[x]
-        info = (type_input, extra_info)
+        info = (input_type, extra_info)
         if isinstance(val, list):
             if len(val) != 2:
                 error = {
@@ -611,8 +611,8 @@ def validate_inputs(prompt, item, validated):
             r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
             received_type = r[val[1]]
             received_types[x] = received_type
-            if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
-                details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
+            if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type):
+                details = f"{x}, received_type({received_type}) mismatch input_type({input_type})"
                 error = {
                     "type": "return_type_mismatch",
                     "message": "Return type mismatch between linked nodes",
@@ -660,22 +660,22 @@ def validate_inputs(prompt, item, validated):
                     val = val["__value__"]
                     inputs[x] = val

-                if type_input == "INT":
+                if input_type == "INT":
                     val = int(val)
                     inputs[x] = val
-                if type_input == "FLOAT":
+                if input_type == "FLOAT":
                     val = float(val)
                     inputs[x] = val
-                if type_input == "STRING":
+                if input_type == "STRING":
                     val = str(val)
                     inputs[x] = val
-                if type_input == "BOOLEAN":
+                if input_type == "BOOLEAN":
                     val = bool(val)
                     inputs[x] = val
             except Exception as ex:
                 error = {
                     "type": "invalid_input_type",
-                    "message": f"Failed to convert an input value to a {type_input} value",
+                    "message": f"Failed to convert an input value to a {input_type} value",
                     "details": f"{x}, {val}, {ex}",
                     "extra_info": {
                         "input_name": x,
@@ -715,18 +715,19 @@ def validate_inputs(prompt, item, validated):
                 errors.append(error)
                 continue

-        if isinstance(type_input, list):
-            if val not in type_input:
+        if isinstance(input_type, list):
+            combo_options = input_type
+            if val not in combo_options:
                 input_config = info
                 list_info = ""

                 # Don't send back gigantic lists like if they're lots of
                 # scanned model filepaths
-                if len(type_input) > 20:
-                    list_info = f"(list of length {len(type_input)})"
+                if len(combo_options) > 20:
+                    list_info = f"(list of length {len(combo_options)})"
                     input_config = None
                 else:
-                    list_info = str(type_input)
+                    list_info = str(combo_options)

                 error = {
                     "type": "value_not_in_list",
nodes.py (43 changed lines)
@@ -917,7 +917,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream"], ),
                               },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),
@@ -927,29 +927,10 @@ class CLIPLoader:

     CATEGORY = "advanced/loaders"

-    DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl"
+    DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5"

     def load_clip(self, clip_name, type="stable_diffusion", device="default"):
-        if type == "stable_cascade":
-            clip_type = comfy.sd.CLIPType.STABLE_CASCADE
-        elif type == "sd3":
-            clip_type = comfy.sd.CLIPType.SD3
-        elif type == "stable_audio":
-            clip_type = comfy.sd.CLIPType.STABLE_AUDIO
-        elif type == "mochi":
-            clip_type = comfy.sd.CLIPType.MOCHI
-        elif type == "ltxv":
-            clip_type = comfy.sd.CLIPType.LTXV
-        elif type == "pixart":
-            clip_type = comfy.sd.CLIPType.PIXART
-        elif type == "cosmos":
-            clip_type = comfy.sd.CLIPType.COSMOS
-        elif type == "lumina2":
-            clip_type = comfy.sd.CLIPType.LUMINA2
-        elif type == "wan":
-            clip_type = comfy.sd.CLIPType.WAN
-        else:
-            clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
+        clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)

         model_options = {}
         if device == "cpu":
@@ -964,7 +945,7 @@ class DualCLIPLoader:
     def INPUT_TYPES(s):
         return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
                               "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
+                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ),
                               },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),
@@ -974,19 +955,13 @@ class DualCLIPLoader:

     CATEGORY = "advanced/loaders"

-    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
+    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama"

     def load_clip(self, clip_name1, clip_name2, type, device="default"):
+        clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
+
         clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
         clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
-        if type == "sdxl":
-            clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
-        elif type == "sd3":
-            clip_type = comfy.sd.CLIPType.SD3
-        elif type == "flux":
-            clip_type = comfy.sd.CLIPType.FLUX
-        elif type == "hunyuan_video":
-            clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO

         model_options = {}
         if device == "cpu":
@@ -2280,7 +2255,9 @@ def init_builtin_extra_nodes():
         "nodes_hunyuan3d.py",
         "nodes_primitive.py",
         "nodes_cfg.py",
-        "nodes_optimalsteps.py"
+        "nodes_optimalsteps.py",
+        "nodes_hidream.py",
+        "nodes_fresca.py",
     ]

     import_failed = []
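Both loaders now resolve the CLIP type with a single getattr() lookup and fall back to STABLE_DIFFUSION for names that have no matching enum member. A sketch of that lookup with a stand-in Enum (not comfy.sd.CLIPType):

    from enum import Enum

    class CLIPType(Enum):
        STABLE_DIFFUSION = 1
        SD3 = 2
        WAN = 3
        HIDREAM = 4

    def pick(type_name: str) -> CLIPType:
        # Unknown names fall back to STABLE_DIFFUSION, mirroring the node code.
        return getattr(CLIPType, type_name.upper(), CLIPType.STABLE_DIFFUSION)

    print(pick("wan"))        # CLIPType.WAN
    print(pick("hidream"))    # CLIPType.HIDREAM
    print(pick("sdxl"))       # CLIPType.STABLE_DIFFUSION (no SDXL member in this toy enum -> fallback)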
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.28"
+version = "0.3.29"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@@ -1,4 +1,5 @@
-comfyui-frontend-package==1.15.13
+comfyui-frontend-package==1.16.9
+comfyui-workflow-templates==0.1.3
 torch
 torchsde
 torchvision
@@ -736,6 +736,12 @@ class PromptServer():
         for name, dir in nodes.EXTENSION_WEB_DIRS.items():
             self.app.add_routes([web.static('/extensions/' + name, dir)])

+        workflow_templates_path = FrontendManager.templates_path()
+        if workflow_templates_path:
+            self.app.add_routes([
+                web.static('/templates', workflow_templates_path)
+            ])
+
         self.app.add_routes([
             web.static('/', self.web_root),
         ])
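The new route simply exposes the bundled workflow templates directory as static files next to the existing web root. A minimal aiohttp sketch of the same registration, assuming a templates directory exists at the stand-in path used below:

    from aiohttp import web

    app = web.Application()
    workflow_templates_path = "/tmp/templates"   # stand-in for FrontendManager.templates_path()
    if workflow_templates_path:
        # web.static() requires the directory to exist, hence the truthiness guard above
        app.add_routes([web.static('/templates', workflow_templates_path)])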