Compare commits

..

23 Commits

Author SHA1 Message Date
Chenlei Hu
f3b09b9f2d
[BugFix] Update frontend to 1.16.9 (#7655)
Backport https://github.com/Comfy-Org/ComfyUI_frontend/pull/3505
2025-04-18 15:12:42 -04:00
comfyanonymous
7ecd5e9614 Increase freq_cutoff in FreSca node. 2025-04-18 03:16:16 -04:00
City
2383a39e3b
Replace CLIPType if with getattr (#7589)
* Replace CLIPType if with getattr

* Forgot to remove breakpoint from testing
2025-04-18 02:53:36 -04:00
Terry Jia
34e06bf7ec
add support to output camera state (#7582) 2025-04-18 02:52:18 -04:00
Chenlei Hu
55822faa05
[Type] Annotate graph.get_input_info (#7386)
* [Type] Annotate graph.get_input_info

* nit

* nit
2025-04-17 21:02:24 -04:00
comfyanonymous
880c205df1 Add hidream to readme. 2025-04-17 16:58:27 -04:00
comfyanonymous
3dc240d089 Make fresca work on multi dim. 2025-04-17 15:46:41 -04:00
BVH
19373aee75
Add FreSca node (#7631) 2025-04-17 15:24:33 -04:00
comfyanonymous
93292bc450 ComfyUI version 0.3.29 2025-04-17 14:45:01 -04:00
Christian Byrne
05d5a75cdc
Update frontend to 1.16 (Install templates as pip package) (#7623)
* install templates as pip package

* Update requirements.txt

* bump templates version to include hidream

---------

Co-authored-by: Chenlei Hu <hcl@comfy.org>
2025-04-17 14:25:33 -04:00
comfyanonymous
eba7a25e7a Add WanFirstLastFrameToVideo node to use the new model. 2025-04-17 13:23:22 -04:00
comfyanonymous
dbcfd092a2 Set default context_img_len to 257 2025-04-17 12:42:34 -04:00
comfyanonymous
c14429940f Support loading WAN FLF model. 2025-04-17 12:04:48 -04:00
comfyanonymous
0d720e4367 Don't hardcode length of context_img in wan code. 2025-04-17 06:25:39 -04:00
comfyanonymous
1fc00ba4b6 Make hidream work with any latent resolution. 2025-04-16 18:34:14 -04:00
comfyanonymous
9899d187b1 Limit T5 to 128 tokens for HiDream: #7620 2025-04-16 18:07:55 -04:00
comfyanonymous
f00f340a56 Reuse code from flux model. 2025-04-16 17:43:55 -04:00
Chenlei Hu
cce1d9145e
[Type] Mark input options NotRequired (#7614) 2025-04-16 15:41:00 -04:00
comfyanonymous
b4dc03ad76 Fix issue on old torch. 2025-04-16 04:53:56 -04:00
comfyanonymous
9ad792f927 Basic support for hidream i1 model. 2025-04-15 17:35:05 -04:00
comfyanonymous
6fc5dbd52a Cleanup. 2025-04-15 12:13:28 -04:00
comfyanonymous
3e8155f7a3 More flexible long clip support.
Add clip g long clip support.

Text encoder refactor.

Support llama models with different vocab sizes.
2025-04-15 10:32:21 -04:00
comfyanonymous
8a438115fb add RMSNorm to comfy.ops 2025-04-14 18:00:33 -04:00
41 changed files with 1556 additions and 211 deletions

View File

@ -62,6 +62,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)

View File

@ -184,6 +184,27 @@ comfyui-frontend-package is not installed.
)
sys.exit(-1)
@classmethod
def templates_path(cls) -> str:
try:
import comfyui_workflow_templates
return str(
importlib.resources.files(comfyui_workflow_templates) / "templates"
)
except ImportError:
logging.error(
f"""
********** ERROR ***********
comfyui-workflow-templates is not installed.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""

View File

@ -99,59 +99,59 @@ class InputTypeOptions(TypedDict):
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
"""
default: bool | str | float | int | list | tuple
default: NotRequired[bool | str | float | int | list | tuple]
"""The default value of the widget"""
defaultInput: bool
defaultInput: NotRequired[bool]
"""@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
- defaultInput on required inputs should be dropped.
- defaultInput on optional inputs should be replaced with forceInput.
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
"""
forceInput: bool
forceInput: NotRequired[bool]
"""Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
lazy: bool
lazy: NotRequired[bool]
"""Declares that this input uses lazy evaluation"""
rawLink: bool
rawLink: NotRequired[bool]
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
tooltip: str
tooltip: NotRequired[str]
"""Tooltip for the input (or widget), shown on pointer hover"""
# class InputTypeNumber(InputTypeOptions):
# default: float | int
min: float
min: NotRequired[float]
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
max: float
max: NotRequired[float]
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
step: float
step: NotRequired[float]
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
round: float
round: NotRequired[float]
"""Floats are rounded by this value (``FLOAT``)"""
# class InputTypeBoolean(InputTypeOptions):
# default: bool
label_on: str
label_on: NotRequired[str]
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
label_off: str
label_off: NotRequired[str]
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
# class InputTypeString(InputTypeOptions):
# default: str
multiline: bool
multiline: NotRequired[bool]
"""Use a multiline text box (``STRING``)"""
placeholder: str
placeholder: NotRequired[str]
"""Placeholder text to display in the UI when empty (``STRING``)"""
# Deprecated:
# defaultVal: str
dynamicPrompts: bool
dynamicPrompts: NotRequired[bool]
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
# class InputTypeCombo(InputTypeOptions):
image_upload: bool
image_upload: NotRequired[bool]
"""Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
image_folder: Literal["input", "output", "temp"]
image_folder: NotRequired[Literal["input", "output", "temp"]]
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
"""
remote: RemoteInputOptions
remote: NotRequired[RemoteInputOptions]
"""Specifies the configuration for a remote input.
Available after ComfyUI frontend v1.9.7
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
control_after_generate: bool
control_after_generate: NotRequired[bool]
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
options: NotRequired[list[str | int | float]]
"""COMBO type only. Specifies the selectable options for the combo widget.
@ -169,15 +169,15 @@ class InputTypeOptions(TypedDict):
class HiddenInputTypeDict(TypedDict):
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
node_id: Literal["UNIQUE_ID"]
node_id: NotRequired[Literal["UNIQUE_ID"]]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
unique_id: Literal["UNIQUE_ID"]
unique_id: NotRequired[Literal["UNIQUE_ID"]]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
prompt: Literal["PROMPT"]
prompt: NotRequired[Literal["PROMPT"]]
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
extra_pnginfo: Literal["EXTRA_PNGINFO"]
extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]]
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
dynprompt: Literal["DYNPROMPT"]
dynprompt: NotRequired[Literal["DYNPROMPT"]]
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
@ -187,11 +187,11 @@ class InputTypeDict(TypedDict):
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
"""
required: dict[str, tuple[IO, InputTypeOptions]]
required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
"""Describes all inputs that must be connected for the node to execute."""
optional: dict[str, tuple[IO, InputTypeOptions]]
optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
"""Describes inputs which do not need to be connected."""
hidden: HiddenInputTypeDict
hidden: NotRequired[HiddenInputTypeDict]
"""Offers advanced functionality and server-client communication.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs

View File

@ -1,5 +1,6 @@
import torch
import comfy.ops
import comfy.rmsnorm
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
@ -11,20 +12,5 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
return torch.nn.functional.pad(img, pad, mode=padding_mode)
try:
rms_norm_torch = torch.nn.functional.rms_norm
except:
rms_norm_torch = None
def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
rms_norm = comfy.rmsnorm.rms_norm

799
comfy/ldm/hidream/model.py Normal file
View File

@ -0,0 +1,799 @@
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import einops
from einops import repeat
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
import torch.nn.functional as F
from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.flux.layers import LastLayer
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import comfy.ldm.common_dit
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND(nn.Module):
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(2)
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size=2,
in_channels=4,
out_channels=1024,
dtype=None, device=None, operations=None
):
super().__init__()
self.patch_size = patch_size
self.out_channels = out_channels
self.proj = operations.Linear(in_channels * patch_size * patch_size, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, latent):
latent = self.proj(latent)
return latent
class PooledEmbed(nn.Module):
def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
def forward(self, pooled_embed):
return self.pooled_embedder(pooled_embed)
class TimestepEmbed(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
def forward(self, timesteps, wdtype):
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
t_emb = self.timestep_embedder(t_emb)
return t_emb
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
class HiDreamAttnProcessor_flashattn:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __call__(
self,
attn,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
*args,
**kwargs,
) -> torch.FloatTensor:
dtype = image_tokens.dtype
batch_size = image_tokens.shape[0]
query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
value_i = attn.to_v(image_tokens)
inner_dim = key_i.shape[-1]
head_dim = inner_dim // attn.heads
query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
if image_tokens_masks is not None:
key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)
if not attn.single:
query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
value_t = attn.to_v_t(text_tokens)
query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
num_image_tokens = query_i.shape[1]
num_text_tokens = query_t.shape[1]
query = torch.cat([query_i, query_t], dim=1)
key = torch.cat([key_i, key_t], dim=1)
value = torch.cat([value_i, value_t], dim=1)
else:
query = query_i
key = key_i
value = value_i
if query.shape[-1] == rope.shape[-3] * 2:
query, key = apply_rope(query, key, rope)
else:
query_1, query_2 = query.chunk(2, dim=-1)
key_1, key_2 = key.chunk(2, dim=-1)
query_1, key_1 = apply_rope(query_1, key_1, rope)
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
hidden_states = attention(query, key, value)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
hidden_states_i = attn.to_out(hidden_states_i)
hidden_states_t = attn.to_out_t(hidden_states_t)
return hidden_states_i, hidden_states_t
else:
hidden_states = attn.to_out(hidden_states)
return hidden_states
class HiDreamAttention(nn.Module):
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
upcast_attention: bool = False,
upcast_softmax: bool = False,
scale_qk: bool = True,
eps: float = 1e-5,
processor = None,
out_dim: int = None,
single: bool = False,
dtype=None, device=None, operations=None
):
# super(Attention, self).__init__()
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.out_dim = out_dim if out_dim is not None else query_dim
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
self.sliceable_head_dim = heads
self.single = single
linear_cls = operations.Linear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
if not single:
self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
self.processor = processor
def forward(
self,
norm_image_tokens: torch.FloatTensor,
image_tokens_masks: torch.FloatTensor = None,
norm_text_tokens: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
) -> torch.Tensor:
return self.processor(
self,
image_tokens = norm_image_tokens,
image_tokens_masks = image_tokens_masks,
text_tokens = norm_text_tokens,
rope = rope,
)
class FeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
dtype=None, device=None, operations=None
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * (
(hidden_dim + multiple_of - 1) // multiple_of
)
self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
def forward(self, x):
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MoEGate(nn.Module):
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None):
super().__init__()
self.top_k = num_activated_experts
self.n_routed_experts = num_routed_experts
self.scoring_func = 'softmax'
self.alpha = aux_loss_alpha
self.seq_aux = False
# topk selection algorithm
self.norm_topk_prob = False
self.gating_dim = embed_dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device))
self.reset_parameters()
def reset_parameters(self) -> None:
pass
# import torch.nn.init as init
# init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
### select top-k experts
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
aux_loss = None
return topk_idx, topk_weight, aux_loss
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MOEFeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
num_routed_experts: int,
num_activated_experts: int,
dtype=None, device=None, operations=None
):
super().__init__()
self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations)
self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)])
self.gate = MoEGate(
embed_dim = dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
dtype=dtype, device=device, operations=operations
)
self.num_activated_experts = num_activated_experts
def forward(self, x):
wtype = x.dtype
identity = x
orig_shape = x.shape
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if True: # self.training: # TODO: check which branch performs faster
x = x.repeat_interleave(self.num_activated_experts, dim=0)
y = torch.empty_like(x, dtype=wtype)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape).to(dtype=wtype)
#y = AddAuxiliaryLoss.apply(y, aux_loss)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
y = y + self.shared_experts(identity)
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.num_activated_experts
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# for fp16 and other dtype
expert_cache = expert_cache.to(expert_out.dtype)
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
return expert_cache
class TextProjection(nn.Module):
def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device)
def forward(self, caption):
hidden_states = self.linear(caption)
return hidden_states
class BlockType:
TransformerBlock = 1
SingleTransformerBlock = 2
class HiDreamImageSingleTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
dtype=None, device=None, operations=None
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device)
)
# 1. Attention
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor = HiDreamAttnProcessor_flashattn(),
single = True,
dtype=dtype, device=device, operations=operations
)
# 3. Feed-forward
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim = dim,
hidden_dim = 4 * dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
dtype=dtype, device=device, operations=operations
)
else:
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)
# 1. MM-Attention
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
attn_output_i = self.attn1(
norm_image_tokens,
image_tokens_masks,
rope = rope,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
# 2. Feed-forward
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
image_tokens = ff_output_i + image_tokens
return image_tokens
class HiDreamImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
dtype=None, device=None, operations=None
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device)
)
# nn.init.zeros_(self.adaLN_modulation[1].weight)
# nn.init.zeros_(self.adaLN_modulation[1].bias)
# 1. Attention
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor = HiDreamAttnProcessor_flashattn(),
single = False,
dtype=dtype, device=device, operations=operations
)
# 3. Feed-forward
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim = dim,
hidden_dim = 4 * dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
dtype=dtype, device=device, operations=operations
)
else:
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)
# 1. MM-Attention
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
attn_output_i, attn_output_t = self.attn1(
norm_image_tokens,
image_tokens_masks,
norm_text_tokens,
rope = rope,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
text_tokens = gate_msa_t * attn_output_t + text_tokens
# 2. Feed-forward
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
image_tokens = ff_output_i + image_tokens
text_tokens = ff_output_t + text_tokens
return image_tokens, text_tokens
class HiDreamImageBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
block_type: BlockType = BlockType.TransformerBlock,
dtype=None, device=None, operations=None
):
super().__init__()
block_classes = {
BlockType.TransformerBlock: HiDreamImageTransformerBlock,
BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
}
self.block = block_classes[block_type](
dim,
num_attention_heads,
attention_head_dim,
num_routed_experts,
num_activated_experts,
dtype=dtype, device=device, operations=operations
)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
return self.block(
image_tokens,
image_tokens_masks,
text_tokens,
adaln_input,
rope,
)
class HiDreamImageTransformer2DModel(nn.Module):
def __init__(
self,
patch_size: Optional[int] = None,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 16,
num_single_layers: int = 32,
attention_head_dim: int = 128,
num_attention_heads: int = 20,
caption_channels: List[int] = None,
text_emb_dim: int = 2048,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
axes_dims_rope: Tuple[int, int] = (32, 32),
max_resolution: Tuple[int, int] = (128, 128),
llama_layers: List[int] = None,
image_model=None,
dtype=None, device=None, operations=None
):
self.patch_size = patch_size
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.num_layers = num_layers
self.num_single_layers = num_single_layers
self.gradient_checkpointing = False
super().__init__()
self.dtype = dtype
self.out_channels = out_channels or in_channels
self.inner_dim = self.num_attention_heads * self.attention_head_dim
self.llama_layers = llama_layers
self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations)
self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.x_embedder = PatchEmbed(
patch_size = patch_size,
in_channels = in_channels,
out_channels = self.inner_dim,
dtype=dtype, device=device, operations=operations
)
self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)
self.double_stream_blocks = nn.ModuleList(
[
HiDreamImageBlock(
dim = self.inner_dim,
num_attention_heads = self.num_attention_heads,
attention_head_dim = self.attention_head_dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
block_type = BlockType.TransformerBlock,
dtype=dtype, device=device, operations=operations
)
for i in range(self.num_layers)
]
)
self.single_stream_blocks = nn.ModuleList(
[
HiDreamImageBlock(
dim = self.inner_dim,
num_attention_heads = self.num_attention_heads,
attention_head_dim = self.attention_head_dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
block_type = BlockType.SingleTransformerBlock,
dtype=dtype, device=device, operations=operations
)
for i in range(self.num_single_layers)
]
)
self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
caption_projection = []
for caption_channel in caption_channels:
caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations))
self.caption_projection = nn.ModuleList(caption_projection)
self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
def expand_timesteps(self, timesteps, batch_size, device):
if not torch.is_tensor(timesteps):
is_mps = device.type == "mps"
if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(batch_size)
return timesteps
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]:
x_arr = []
for i, img_size in enumerate(img_sizes):
pH, pW = img_size
x_arr.append(
einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
p1=self.patch_size, p2=self.patch_size)
)
x = torch.cat(x_arr, dim=0)
return x
def patchify(self, x, max_seq, img_sizes=None):
pz2 = self.patch_size * self.patch_size
if isinstance(x, torch.Tensor):
B = x.shape[0]
device = x.device
dtype = x.dtype
else:
B = len(x)
device = x[0].device
dtype = x[0].dtype
x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
if img_sizes is not None:
for i, img_size in enumerate(img_sizes):
x_masks[i, 0:img_size[0] * img_size[1]] = 1
x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
elif isinstance(x, torch.Tensor):
pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size)
img_sizes = [[pH, pW]] * B
x_masks = None
else:
raise NotImplementedError
return x, x_masks, img_sizes
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
encoder_hidden_states_llama3=None,
control = None,
transformer_options = {},
) -> torch.Tensor:
bs, c, h, w = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
timesteps = t
pooled_embeds = y
T5_encoder_hidden_states = context
img_sizes = None
# spatial forward
batch_size = hidden_states.shape[0]
hidden_states_type = hidden_states.dtype
# 0. time
timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
timesteps = self.t_embedder(timesteps, hidden_states_type)
p_embedder = self.p_embedder(pooled_embeds)
adaln_input = timesteps + p_embedder
hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
if image_tokens_masks is None:
pH, pW = img_sizes[0]
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
hidden_states = self.x_embedder(hidden_states)
# T5_encoder_hidden_states = encoder_hidden_states[0]
encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0)
encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
if self.caption_projection is not None:
new_encoder_hidden_states = []
for i, enc_hidden_state in enumerate(encoder_hidden_states):
enc_hidden_state = self.caption_projection[i](enc_hidden_state)
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
new_encoder_hidden_states.append(enc_hidden_state)
encoder_hidden_states = new_encoder_hidden_states
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states.append(T5_encoder_hidden_states)
txt_ids = torch.zeros(
batch_size,
encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
3,
device=img_ids.device, dtype=img_ids.dtype
)
ids = torch.cat((img_ids, txt_ids), dim=1)
rope = self.pe_embedder(ids)
# 2. Blocks
block_id = 0
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
for bid, block in enumerate(self.double_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
hidden_states, initial_encoder_hidden_states = block(
image_tokens = hidden_states,
image_tokens_masks = image_tokens_masks,
text_tokens = cur_encoder_hidden_states,
adaln_input = adaln_input,
rope = rope,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
image_tokens_seq_len = hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
hidden_states_seq_len = hidden_states.shape[1]
if image_tokens_masks is not None:
encoder_attention_mask_ones = torch.ones(
(batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
)
image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
for bid, block in enumerate(self.single_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
hidden_states = block(
image_tokens=hidden_states,
image_tokens_masks=image_tokens_masks,
text_tokens=None,
adaln_input=adaln_input,
rope=rope,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
output = self.final_layer(hidden_states, adaln_input)
output = self.unpatchify(output, img_sizes)
return -output[:, :, :h, :w]

View File

@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context):
def forward(self, x, context, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context):
def forward(self, x, context, context_img_len):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
"""
context_img = context[:, :257]
context = context[:, 257:]
context_img = context[:, :context_img_len]
context = context[:, context_img_len:]
# compute query, key, value
q = self.norm_q(self.q(x))
@ -193,6 +193,7 @@ class WanAttentionBlock(nn.Module):
e,
freqs,
context,
context_img_len=257,
):
r"""
Args:
@ -213,7 +214,7 @@ class WanAttentionBlock(nn.Module):
x = x + y * e[2]
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context)
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
return x
@ -250,7 +251,7 @@ class Head(nn.Module):
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim, operation_settings={}):
def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}):
super().__init__()
self.proj = torch.nn.Sequential(
@ -258,7 +259,15 @@ class MLPProj(torch.nn.Module):
torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
if flf_pos_embed_token_number is not None:
self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
else:
self.emb_pos = None
def forward(self, image_embeds):
if self.emb_pos is not None:
image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device)
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
@ -284,6 +293,7 @@ class WanModel(torch.nn.Module):
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
image_model=None,
device=None,
dtype=None,
@ -373,7 +383,7 @@ class WanModel(torch.nn.Module):
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
if model_type == 'i2v':
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings)
else:
self.img_emb = None
@ -420,9 +430,12 @@ class WanModel(torch.nn.Module):
# context
context = self.text_embedding(context)
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
@ -430,12 +443,12 @@ class WanModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
# head
x = self.head(x, e)

View File

@ -37,6 +37,7 @@ import comfy.ldm.cosmos.model
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.model_management
import comfy.patcher_extension
@ -1056,3 +1057,20 @@ class Hunyuan3Dv2(BaseModel):
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_llama3 = kwargs.get("conditioning_llama3", None)
if conditioning_llama3 is not None:
out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
return out

View File

@ -321,6 +321,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "i2v"
else:
dit_config["model_type"] = "t2v"
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
if flf_weight is not None:
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
@ -338,6 +341,25 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"
dit_config["attention_head_dim"] = 128
dit_config["axes_dims_rope"] = [64, 32, 32]
dit_config["caption_channels"] = [4096, 4096]
dit_config["max_resolution"] = [128, 128]
dit_config["in_channels"] = 16
dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]
dit_config["num_attention_heads"] = 20
dit_config["num_routed_experts"] = 4
dit_config["num_activated_experts"] = 2
dit_config["num_layers"] = 16
dit_config["num_single_layers"] = 32
dit_config["out_channels"] = 16
dit_config["patch_size"] = 2
dit_config["text_emb_dim"] = 2048
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

View File

@ -21,6 +21,7 @@ import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@ -146,6 +147,25 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
def reset_parameters(self):
self.bias = None
return None
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
else:
weight = None
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -243,6 +263,9 @@ class manual_cast(disable_weight_init):
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
comfy_cast_weights = True
class RMSNorm(disable_weight_init.RMSNorm):
comfy_cast_weights = True
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True

55
comfy/rmsnorm.py Normal file
View File

@ -0,0 +1,55 @@
import torch
import comfy.model_management
import numbers
RMSNorm = None
try:
rms_norm_torch = torch.nn.functional.rms_norm
RMSNorm = torch.nn.RMSNorm
except:
rms_norm_torch = None
def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
if RMSNorm is None:
class RMSNorm(torch.nn.Module):
def __init__(
self,
normalized_shape,
eps=None,
elementwise_affine=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = torch.nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.bias = None
def forward(self, x):
return rms_norm(x, self.weight, self.eps)

View File

@ -41,6 +41,7 @@ import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.model_patcher
import comfy.lora
@ -853,6 +854,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif len(clip_data) == 3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif len(clip_data) == 4:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), **llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
parameters = 0
for c in clip_data:

View File

@ -82,7 +82,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
LAYERS = [
"last",
"pooled",
"hidden"
"hidden",
"all"
]
def __init__(self, device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
@ -93,6 +94,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
if "model_name" not in model_options:
model_options = {**model_options, "model_name": "clip_l"}
if isinstance(textmodel_json_config, dict):
config = textmodel_json_config
@ -100,6 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
with open(textmodel_json_config) as f:
config = json.load(f)
te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {})
for k, v in te_model_options.items():
config[k] = v
operations = model_options.get("custom_operations", None)
scaled_fp8 = None
@ -147,7 +154,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if layer_idx is None or abs(layer_idx) > self.num_layers:
if self.layer == "all":
pass
elif layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
@ -244,7 +253,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if self.enable_attention_masks:
attention_mask_model = attention_mask
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
if self.layer == "all":
intermediate_output = "all"
else:
intermediate_output = self.layer_idx
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
if self.layer == "last":
z = outputs[0].float()
@ -447,7 +461,7 @@ class SDTokenizer:
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
self.max_length = max_length
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
self.min_length = min_length
self.end_token = None
@ -645,6 +659,7 @@ class SD1ClipModel(torch.nn.Module):
self.clip = "clip_{}".format(self.clip_name)
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
model_options = {**model_options, "model_name": self.clip}
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
self.dtypes = set()

View File

@ -9,6 +9,7 @@ class SDXLClipG(sd1_clip.SDClipModel):
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
model_options = {**model_options, "model_name": "clip_g"}
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
@ -17,14 +18,13 @@ class SDXLClipG(sd1_clip.SDClipModel):
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)
class SDXLTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
@ -41,8 +41,7 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set([dtype])
@ -75,7 +74,7 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -84,6 +83,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
model_options = {**model_options, "model_name": "clip_g"}
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)

View File

@ -1025,6 +1025,36 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
latent_format = latent_formats.Hunyuan3Dv2mini
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
class HiDream(supported_models_base.BASE):
unet_config = {
"image_model": "hidream",
}
sampling_settings = {
"shift": 3.0,
}
sampling_settings = {
}
# memory_usage_factor = 1.2 # TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HiDream(self, device=device)
return out
def clip_target(self, state_dict={}):
return None # TODO
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
models += [SVD_img2vid]

View File

@ -11,7 +11,7 @@ class PT5XlModel(sd1_clip.SDClipModel):
class PT5XlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data)
class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):

View File

@ -22,7 +22,7 @@ class CosmosT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, tokenizer_data=tokenizer_data)
class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):

View File

@ -9,14 +9,13 @@ import os
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
class FluxTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
@ -35,8 +34,7 @@ class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5])

View File

@ -18,7 +18,7 @@ class MochiT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):

View File

@ -0,0 +1,151 @@
from . import hunyuan_video
from . import sd3_clip
from comfy import sd1_clip
from comfy import sdxl_clip
import comfy.model_management
import torch
import logging
class HiDreamTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data)
self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)
def state_dict(self):
return {}
class HiDreamTEModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None
if clip_g:
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_g = None
if t5:
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True)
self.dtypes.add(dtype_t5)
else:
self.t5xxl = None
if llama:
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
if "vocab_size" not in model_options:
model_options["vocab_size"] = 128256
self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009})
self.dtypes.add(dtype_llama)
else:
self.llama = None
logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))
def set_clip_options(self, options):
if self.clip_l is not None:
self.clip_l.set_clip_options(options)
if self.clip_g is not None:
self.clip_g.set_clip_options(options)
if self.t5xxl is not None:
self.t5xxl.set_clip_options(options)
if self.llama is not None:
self.llama.set_clip_options(options)
def reset_clip_options(self):
if self.clip_l is not None:
self.clip_l.reset_clip_options()
if self.clip_g is not None:
self.clip_g.reset_clip_options()
if self.t5xxl is not None:
self.t5xxl.reset_clip_options()
if self.llama is not None:
self.llama.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_g = token_weight_pairs["g"]
token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
token_weight_pairs_llama = token_weight_pairs["llama"]
lg_out = None
pooled = None
extra = {}
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
if self.clip_l is not None:
lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
else:
l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
if self.clip_g is not None:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
else:
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
if self.t5xxl is not None:
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
t5_out, t5_pooled = t5_output[:2]
if self.llama is not None:
ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
ll_out, ll_pooled = ll_output[:2]
ll_out = ll_out[:, 1:]
if t5_out is None:
t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device())
if ll_out is None:
ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
if pooled is None:
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
extra["conditioning_llama3"] = ll_out
return t5_out, pooled, extra
def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
return self.clip_g.load_sd(sd)
elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
return self.t5xxl.load_sd(sd)
else:
return self.llama.load_sd(sd)
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
class HiDreamTEModel_(HiDreamTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HiDreamTEModel_

View File

@ -21,26 +21,31 @@ def llama_detect(state_dict, prefix=""):
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=pad_token, min_length=min_length, tokenizer_data=tokenizer_data)
class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
if llama_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
textmodel_json_config = {}
vocab_size = model_options.get("vocab_size", None)
if vocab_size is not None:
textmodel_json_config["vocab_size"] = vocab_size
model_options = {**model_options, "model_name": "llama"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class HunyuanVideoTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
out = {}
@ -72,8 +77,7 @@ class HunyuanVideoClipModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
self.dtypes = set([dtype, dtype_llama])

View File

@ -9,24 +9,26 @@ import torch
class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
model_options = {**model_options, "model_name": "hydit_clip"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77)
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data)
class MT5XLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
model_options = {**model_options, "model_name": "mt5xl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class MT5XLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
#tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
@ -35,7 +37,7 @@ class HyditTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}

View File

@ -268,11 +268,17 @@ class Llama2_(nn.Module):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
intermediate = None
all_intermediate = None
if intermediate_output is not None:
if intermediate_output < 0:
if intermediate_output == "all":
all_intermediate = []
intermediate_output = None
elif intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
for i, layer in enumerate(self.layers):
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
x = layer(
x=x,
attention_mask=mask,
@ -283,6 +289,12 @@ class Llama2_(nn.Module):
intermediate = x.clone()
x = self.norm(x)
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1)
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.norm(intermediate)

View File

@ -1,30 +1,27 @@
from comfy import sd1_clip
import os
class LongClipTokenizer_(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
class LongClipModel_(sd1_clip.SDClipModel):
def __init__(self, *args, **kwargs):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
class LongClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
def model_options_long_clip(sd, tokenizer_data, model_options):
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
if w is None:
w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None)
else:
model_name = "clip_g"
if w is None:
w = sd.get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
if w is not None:
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
model_name = "clip_g"
elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
model_name = "clip_l"
else:
model_name = "clip_l"
if w is not None:
tokenizer_data = tokenizer_data.copy()
model_options = model_options.copy()
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
model_options["clip_l_class"] = LongClipModel_
model_config = model_options.get("model_config", {})
model_config["max_position_embeddings"] = w.shape[0]
model_options["{}_model_config".format(model_name)] = model_config
tokenizer_data["{}_max_length".format(model_name)] = w.shape[0]
return tokenizer_data, model_options

View File

@ -6,7 +6,7 @@ import comfy.text_encoders.genmo
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128?
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):

View File

@ -6,7 +6,7 @@ import comfy.text_encoders.llama
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False})
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}

View File

@ -24,7 +24,7 @@ class PixArtT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):

View File

@ -11,7 +11,7 @@ class T5BaseModel(sd1_clip.SDClipModel):
class T5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data)
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):

View File

@ -12,7 +12,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='clip_h', tokenizer_data=tokenizer_data)
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):

View File

@ -15,6 +15,7 @@ class T5XXLModel(sd1_clip.SDClipModel):
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
model_options = {**model_options, "model_name": "t5xxl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@ -31,17 +32,16 @@ def t5_xxl_detect(state_dict, prefix=""):
return out
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77, max_length=99999999):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=max_length, min_length=min_length, tokenizer_data=tokenizer_data)
class SD3Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
@ -61,8 +61,7 @@ class SD3ClipModel(torch.nn.Module):
super().__init__()
self.dtypes = set()
if clip_l:
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None

View File

@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel):
class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}

View File

@ -1,6 +1,9 @@
import nodes
from __future__ import annotations
from typing import Type, Literal
import nodes
from comfy_execution.graph_utils import is_link
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
class DependencyCycleError(Exception):
pass
@ -54,7 +57,22 @@ class DynamicPrompt:
def get_original_prompt(self):
return self.original_prompt
def get_input_info(class_def, input_name, valid_inputs=None):
def get_input_info(
class_def: Type[ComfyNodeABC],
input_name: str,
valid_inputs: InputTypeDict | None = None
) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
"""Get the input type, category, and extra info for a given input name.
Arguments:
class_def: The class definition of the node.
input_name: The name of the input to get info for.
valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES().
Returns:
tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name.
"""
valid_inputs = valid_inputs or class_def.INPUT_TYPES()
input_info = None
input_category = None
@ -126,7 +144,7 @@ class TopologicalSort:
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
_, _, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)

View File

@ -0,0 +1,100 @@
# Code based on https://github.com/WikiChao/FreSca (MIT License)
import torch
import torch.fft as fft
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
"""
Apply frequency-dependent scaling to an image tensor using Fourier transforms.
Parameters:
x: Input tensor of shape (B, C, H, W)
scale_low: Scaling factor for low-frequency components (default: 1.0)
scale_high: Scaling factor for high-frequency components (default: 1.5)
freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20)
Returns:
x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied.
"""
# Preserve input dtype and device
dtype, device = x.dtype, x.device
# Convert to float32 for FFT computations
x = x.to(torch.float32)
# 1) Apply FFT and shift low frequencies to center
x_freq = fft.fftn(x, dim=(-2, -1))
x_freq = fft.fftshift(x_freq, dim=(-2, -1))
# Initialize mask with high-frequency scaling factor
mask = torch.ones(x_freq.shape, device=device) * scale_high
m = mask
for d in range(len(x_freq.shape) - 2):
dim = d + 2
cc = x_freq.shape[dim] // 2
f_c = min(freq_cutoff, cc)
m = m.narrow(dim, cc - f_c, f_c * 2)
# Apply low-frequency scaling factor to center region
m[:] = scale_low
# 3) Apply frequency-specific scaling
x_freq = x_freq * mask
# 4) Convert back to spatial domain
x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
# 5) Restore original dtype
x_filtered = x_filtered.to(dtype)
return x_filtered
class FreSca:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01,
"tooltip": "Scaling factor for low-frequency components"}),
"scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01,
"tooltip": "Scaling factor for high-frequency components"}),
"freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1,
"tooltip": "Number of frequency indices around center to consider as low-frequency"}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
def patch(self, model, scale_low, scale_high, freq_cutoff):
def custom_cfg_function(args):
cond = args["conds_out"][0]
uncond = args["conds_out"][1]
guidance = cond - uncond
filtered_guidance = Fourier_filter(
guidance,
scale_low=scale_low,
scale_high=scale_high,
freq_cutoff=freq_cutoff,
)
filtered_cond = filtered_guidance + uncond
return [filtered_cond, uncond]
m = model.clone()
m.set_model_sampler_pre_cfg_function(custom_cfg_function)
return (m,)
NODE_CLASS_MAPPINGS = {
"FreSca": FreSca,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"FreSca": "FreSca",
}

View File

@ -0,0 +1,32 @@
import folder_paths
import comfy.sd
import comfy.model_management
class QuadrupleCLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name3": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name4": (folder_paths.get_filename_list("text_encoders"), )
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
NODE_CLASS_MAPPINGS = {
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
}

View File

@ -21,8 +21,8 @@ class Load3D():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")
FUNCTION = "process"
EXPERIMENTAL = True
@ -41,7 +41,7 @@ class Load3D():
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
return output_image, output_mask, model_file, normal_image, lineart_image
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']
class Load3DAnimation():
@classmethod
@ -59,8 +59,8 @@ class Load3DAnimation():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")
FUNCTION = "process"
EXPERIMENTAL = True
@ -77,13 +77,16 @@ class Load3DAnimation():
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
return output_image, output_mask, model_file, normal_image
return output_image, output_mask, model_file, normal_image, image['camera_info']
class Preview3D():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
},
"optional": {
"camera_info": ("LOAD3D_CAMERA", {})
}}
OUTPUT_NODE = True
@ -95,13 +98,22 @@ class Preview3D():
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
return {"ui": {"model_file": [model_file]}, "result": ()}
camera_info = kwargs.get("camera_info", None)
return {
"ui": {
"result": [model_file, camera_info]
}
}
class Preview3DAnimation():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
},
"optional": {
"camera_info": ("LOAD3D_CAMERA", {})
}}
OUTPUT_NODE = True
@ -113,7 +125,13 @@ class Preview3DAnimation():
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
return {"ui": {"model_file": [model_file]}, "result": ()}
camera_info = kwargs.get("camera_info", None)
return {
"ui": {
"result": [model_file, camera_info]
}
}
NODE_CLASS_MAPPINGS = {
"Load3D": Load3D,

View File

@ -4,6 +4,7 @@ import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats
import comfy.clip_vision
class WanImageToVideo:
@ -99,6 +100,72 @@ class WanFunControlToVideo:
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanFirstLastFrameToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
"clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, 3)) * 0.5
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
if start_image is not None:
image[:start_image.shape[0]] = start_image
mask[:, :, :start_image.shape[0] + 3] = 0.0
if end_image is not None:
image[-end_image.shape[0]:] = end_image
mask[:, :, -end_image.shape[0]:] = 0.0
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_start_image is not None:
clip_vision_output = clip_vision_start_image
if clip_vision_end_image is not None:
if clip_vision_output is not None:
states = torch.cat([clip_vision_output.penultimate_hidden_states, clip_vision_end_image.penultimate_hidden_states], dim=-2)
clip_vision_output = comfy.clip_vision.Output()
clip_vision_output.penultimate_hidden_states = states
else:
clip_vision_output = clip_vision_end_image
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanFunInpaintToVideo:
@classmethod
def INPUT_TYPES(s):
@ -122,38 +189,13 @@ class WanFunInpaintToVideo:
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
flfv = WanFirstLastFrameToVideo()
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
image = torch.ones((length, height, width, 3)) * 0.5
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
if start_image is not None:
image[:start_image.shape[0]] = start_image
mask[:, :, :start_image.shape[0] + 3] = 0.0
if end_image is not None:
image[-end_image.shape[0]:] = end_image
mask[:, :, -end_image.shape[0]:] = 0.0
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo,
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
}

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.28"
__version__ = "0.3.29"

View File

@ -111,7 +111,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
missing_keys = {}
for x in inputs:
input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
def mark_missing():
missing_keys[x] = True
input_data_all[x] = (None,)
@ -574,7 +574,7 @@ def validate_inputs(prompt, item, validated):
received_types = {}
for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None
if x not in inputs:
if input_category == "required":
@ -590,7 +590,7 @@ def validate_inputs(prompt, item, validated):
continue
val = inputs[x]
info = (type_input, extra_info)
info = (input_type, extra_info)
if isinstance(val, list):
if len(val) != 2:
error = {
@ -611,8 +611,8 @@ def validate_inputs(prompt, item, validated):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
received_type = r[val[1]]
received_types[x] = received_type
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type):
details = f"{x}, received_type({received_type}) mismatch input_type({input_type})"
error = {
"type": "return_type_mismatch",
"message": "Return type mismatch between linked nodes",
@ -660,22 +660,22 @@ def validate_inputs(prompt, item, validated):
val = val["__value__"]
inputs[x] = val
if type_input == "INT":
if input_type == "INT":
val = int(val)
inputs[x] = val
if type_input == "FLOAT":
if input_type == "FLOAT":
val = float(val)
inputs[x] = val
if type_input == "STRING":
if input_type == "STRING":
val = str(val)
inputs[x] = val
if type_input == "BOOLEAN":
if input_type == "BOOLEAN":
val = bool(val)
inputs[x] = val
except Exception as ex:
error = {
"type": "invalid_input_type",
"message": f"Failed to convert an input value to a {type_input} value",
"message": f"Failed to convert an input value to a {input_type} value",
"details": f"{x}, {val}, {ex}",
"extra_info": {
"input_name": x,
@ -715,18 +715,19 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if isinstance(type_input, list):
if val not in type_input:
if isinstance(input_type, list):
combo_options = input_type
if val not in combo_options:
input_config = info
list_info = ""
# Don't send back gigantic lists like if they're lots of
# scanned model filepaths
if len(type_input) > 20:
list_info = f"(list of length {len(type_input)})"
if len(combo_options) > 20:
list_info = f"(list of length {len(combo_options)})"
input_config = None
else:
list_info = str(type_input)
list_info = str(combo_options)
error = {
"type": "value_not_in_list",

View File

@ -930,26 +930,7 @@ class CLIPLoader:
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl"
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
if type == "stable_cascade":
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
elif type == "stable_audio":
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
elif type == "mochi":
clip_type = comfy.sd.CLIPType.MOCHI
elif type == "ltxv":
clip_type = comfy.sd.CLIPType.LTXV
elif type == "pixart":
clip_type = comfy.sd.CLIPType.PIXART
elif type == "cosmos":
clip_type = comfy.sd.CLIPType.COSMOS
elif type == "lumina2":
clip_type = comfy.sd.CLIPType.LUMINA2
elif type == "wan":
clip_type = comfy.sd.CLIPType.WAN
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
model_options = {}
if device == "cpu":
@ -977,16 +958,10 @@ class DualCLIPLoader:
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
def load_clip(self, clip_name1, clip_name2, type, device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
elif type == "flux":
clip_type = comfy.sd.CLIPType.FLUX
elif type == "hunyuan_video":
clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO
model_options = {}
if device == "cpu":
@ -2280,7 +2255,9 @@ def init_builtin_extra_nodes():
"nodes_hunyuan3d.py",
"nodes_primitive.py",
"nodes_cfg.py",
"nodes_optimalsteps.py"
"nodes_optimalsteps.py",
"nodes_hidream.py",
"nodes_fresca.py",
]
import_failed = []

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.28"
version = "0.3.29"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@ -1,4 +1,5 @@
comfyui-frontend-package==1.15.13
comfyui-frontend-package==1.16.9
comfyui-workflow-templates==0.1.1
torch
torchsde
torchvision

View File

@ -736,6 +736,12 @@ class PromptServer():
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
self.app.add_routes([web.static('/extensions/' + name, dir)])
workflow_templates_path = FrontendManager.templates_path()
if workflow_templates_path:
self.app.add_routes([
web.static('/templates', workflow_templates_path)
])
self.app.add_routes([
web.static('/', self.web_root),
])