mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-06 07:17:09 +08:00
Omnigen2 model implementation. (#8669)
This commit is contained in:
parent
7a13f74220
commit
ec70ed6aea
469
comfy/ldm/omnigen/omnigen2.py
Normal file
469
comfy/ldm/omnigen/omnigen2.py
Normal file
@ -0,0 +1,469 @@
|
||||
# Original code: https://github.com/VectorSpaceLab/OmniGen2
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from comfy.ldm.lightricks.model import Timesteps
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape).to(dtype=x.dtype)
|
||||
|
||||
|
||||
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return F.silu(x) * y
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
|
||||
self.act = nn.SiLU()
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, sample: torch.Tensor) -> torch.Tensor:
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class LuminaRMSNormZero(nn.Module):
|
||||
def __init__(self, embedding_dim: int, norm_eps: float = 1e-5, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(min(embedding_dim, 1024), 4 * embedding_dim, dtype=dtype, device=device)
|
||||
self.norm = operations.RMSNorm(embedding_dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None])
|
||||
return x, gate_msa, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class LuminaLayerNormContinuous(nn.Module):
|
||||
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine: bool = False, eps: float = 1e-6, out_dim: Optional[int] = None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear_1 = operations.Linear(conditioning_embedding_dim, embedding_dim, dtype=dtype, device=device)
|
||||
self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(embedding_dim, out_dim, bias=True, dtype=dtype, device=device) if out_dim is not None else None
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
|
||||
x = self.norm(x) * (1 + emb)[:, None, :]
|
||||
if self.linear_2 is not None:
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class LuminaFeedForward(nn.Module):
|
||||
def __init__(self, dim: int, inner_dim: int, multiple_of: int = 256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
|
||||
self.linear_1 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(inner_dim, dim, bias=False, dtype=dtype, device=device)
|
||||
self.linear_3 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h1, h2 = self.linear_1(x), self.linear_3(x)
|
||||
return self.linear_2(swiglu(h1, h2))
|
||||
|
||||
|
||||
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
|
||||
def __init__(self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024), dtype=dtype, device=device, operations=operations)
|
||||
self.caption_embedder = nn.Sequential(
|
||||
operations.RMSNorm(text_feat_dim, eps=norm_eps, dtype=dtype, device=device),
|
||||
operations.Linear(text_feat_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
|
||||
time_embed = self.timestep_embedder(timestep_proj)
|
||||
caption_embed = self.caption_embedder(text_hidden_states)
|
||||
return time_embed, caption_embed
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps: float = 1e-5, bias: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.kv_heads = kv_heads
|
||||
self.dim_head = dim_head
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.to_q = operations.Linear(query_dim, heads * dim_head, bias=bias, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(heads * dim_head, query_dim, bias=bias, dtype=dtype, device=device),
|
||||
nn.Dropout(0.0)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(encoder_hidden_states)
|
||||
value = self.to_v(encoder_hidden_states)
|
||||
|
||||
query = query.view(batch_size, -1, self.heads, self.dim_head)
|
||||
key = key.view(batch_size, -1, self.kv_heads, self.dim_head)
|
||||
value = value.view(batch_size, -1, self.kv_heads, self.dim_head)
|
||||
|
||||
query = self.norm_q(query)
|
||||
key = self.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
if self.kv_heads < self.heads:
|
||||
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
|
||||
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OmniGen2TransformerBlock(nn.Module):
|
||||
def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.modulation = modulation
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
dim_head=dim // num_attention_heads,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_kv_heads,
|
||||
eps=1e-5,
|
||||
bias=False,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
|
||||
self.feed_forward = LuminaFeedForward(
|
||||
dim=dim,
|
||||
inner_dim=4 * dim,
|
||||
multiple_of=multiple_of,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
if modulation:
|
||||
self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
|
||||
self.ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.modulation:
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
|
||||
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
|
||||
hidden_states = hidden_states + self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
||||
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OmniGen2RotaryPosEmbed(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
self.axes_lens = axes_lens
|
||||
self.patch_size = patch_size
|
||||
self.rope_embedder = EmbedND(dim=sum(axes_dim), theta=self.theta, axes_dim=axes_dim)
|
||||
|
||||
def forward(self, batch_size, encoder_seq_len, l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device):
|
||||
p = self.patch_size
|
||||
|
||||
seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
|
||||
|
||||
max_seq_len = max(seq_lengths)
|
||||
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
|
||||
max_img_len = max(l_effective_img_len)
|
||||
|
||||
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
|
||||
|
||||
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
||||
position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
|
||||
|
||||
pe_shift = cap_seq_len
|
||||
pe_shift_len = cap_seq_len
|
||||
|
||||
if ref_img_sizes[i] is not None:
|
||||
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
|
||||
H, W = ref_img_size
|
||||
ref_H_tokens, ref_W_tokens = H // p, W // p
|
||||
|
||||
row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
|
||||
col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
|
||||
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
|
||||
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
|
||||
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
|
||||
|
||||
pe_shift += max(ref_H_tokens, ref_W_tokens)
|
||||
pe_shift_len += ref_img_len
|
||||
|
||||
H, W = img_sizes[i]
|
||||
H_tokens, W_tokens = H // p, W // p
|
||||
|
||||
row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
|
||||
col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
|
||||
|
||||
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
|
||||
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
|
||||
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
|
||||
|
||||
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
|
||||
|
||||
cap_freqs_cis_shape = list(freqs_cis.shape)
|
||||
cap_freqs_cis_shape[1] = encoder_seq_len
|
||||
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
ref_img_freqs_cis_shape = list(freqs_cis.shape)
|
||||
ref_img_freqs_cis_shape[1] = max_ref_img_len
|
||||
ref_img_freqs_cis = torch.zeros(*ref_img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
img_freqs_cis_shape = list(freqs_cis.shape)
|
||||
img_freqs_cis_shape[1] = max_img_len
|
||||
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
|
||||
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
|
||||
ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
|
||||
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
|
||||
|
||||
return cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
|
||||
|
||||
|
||||
class OmniGen2Transformer2DModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = None,
|
||||
hidden_size: int = 2304,
|
||||
num_layers: int = 26,
|
||||
num_refiner_layers: int = 2,
|
||||
num_attention_heads: int = 24,
|
||||
num_kv_heads: int = 8,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
|
||||
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
||||
text_feat_dim: int = 1024,
|
||||
timestep_scale: float = 1.0,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
|
||||
self.rope_embedder = OmniGen2RotaryPosEmbed(
|
||||
theta=10000,
|
||||
axes_dim=axes_dim_rope,
|
||||
axes_lens=axes_lens,
|
||||
patch_size=patch_size,
|
||||
)
|
||||
|
||||
self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
|
||||
self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
|
||||
hidden_size=hidden_size,
|
||||
text_feat_dim=text_feat_dim,
|
||||
norm_eps=norm_eps,
|
||||
timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.noise_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size, num_attention_heads, num_kv_heads,
|
||||
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
|
||||
) for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.ref_image_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size, num_attention_heads, num_kv_heads,
|
||||
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
|
||||
) for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.context_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size, num_attention_heads, num_kv_heads,
|
||||
multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations
|
||||
) for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size, num_attention_heads, num_kv_heads,
|
||||
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
|
||||
) for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = LuminaLayerNormContinuous(
|
||||
embedding_dim=hidden_size,
|
||||
conditioning_embedding_dim=min(hidden_size, 1024),
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))
|
||||
|
||||
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
|
||||
batch_size = len(hidden_states)
|
||||
p = self.patch_size
|
||||
|
||||
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
|
||||
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
|
||||
|
||||
if ref_image_hidden_states is not None:
|
||||
ref_image_hidden_states = list(map(lambda ref: comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p)), ref_image_hidden_states))
|
||||
ref_img_sizes = [[(imgs.size(2), imgs.size(3)) if imgs is not None else None for imgs in ref_image_hidden_states]] * batch_size
|
||||
l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
|
||||
else:
|
||||
ref_img_sizes = [None for _ in range(batch_size)]
|
||||
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
|
||||
|
||||
flat_ref_img_hidden_states = None
|
||||
if ref_image_hidden_states is not None:
|
||||
imgs = []
|
||||
for ref_img in ref_image_hidden_states:
|
||||
B, C, H, W = ref_img.size()
|
||||
ref_img = rearrange(ref_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
|
||||
imgs.append(ref_img)
|
||||
flat_ref_img_hidden_states = torch.cat(imgs, dim=1)
|
||||
|
||||
img = hidden_states
|
||||
B, C, H, W = img.size()
|
||||
flat_hidden_states = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
|
||||
|
||||
return (
|
||||
flat_hidden_states, flat_ref_img_hidden_states,
|
||||
None, None,
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
ref_img_sizes, img_sizes,
|
||||
)
|
||||
|
||||
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
|
||||
batch_size = len(hidden_states)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
if ref_image_hidden_states is not None:
|
||||
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
|
||||
image_index_embedding = comfy.model_management.cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
for i in range(batch_size):
|
||||
shift = 0
|
||||
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
|
||||
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + image_index_embedding[j]
|
||||
shift += ref_img_len
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
|
||||
|
||||
if ref_image_hidden_states is not None:
|
||||
for layer in self.ref_image_refiner:
|
||||
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
|
||||
|
||||
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
_, _, H_padded, W_padded = hidden_states.shape
|
||||
timestep = 1.0 - timesteps
|
||||
text_hidden_states = context
|
||||
text_attention_mask = attention_mask
|
||||
ref_image_hidden_states = ref_latents
|
||||
device = hidden_states.device
|
||||
|
||||
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
|
||||
|
||||
(
|
||||
hidden_states, ref_image_hidden_states,
|
||||
img_mask, ref_img_mask,
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
ref_img_sizes, img_sizes,
|
||||
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
|
||||
|
||||
(
|
||||
context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
|
||||
rotary_emb, encoder_seq_lengths, seq_lengths,
|
||||
) = self.rope_embedder(
|
||||
hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
ref_img_sizes, img_sizes, device,
|
||||
)
|
||||
|
||||
for layer in self.context_refiner:
|
||||
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
|
||||
|
||||
img_len = hidden_states.shape[1]
|
||||
combined_img_hidden_states = self.img_patch_embed_and_refine(
|
||||
hidden_states, ref_image_hidden_states,
|
||||
img_mask, ref_img_mask,
|
||||
noise_rotary_emb, ref_img_rotary_emb,
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
temb,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
|
||||
attention_mask = None
|
||||
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
p = self.patch_size
|
||||
output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded// p, p1=p, p2=p)[:, :, :H, :W]
|
||||
|
||||
return -output
|
@ -41,6 +41,7 @@ import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -1230,3 +1231,33 @@ class ACEStep(BaseModel):
|
||||
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
|
||||
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
|
||||
return out
|
||||
|
||||
class Omnigen2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
@ -459,6 +459,26 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "omnigen2"
|
||||
dit_config["axes_dim_rope"] = [40, 40, 40]
|
||||
dit_config["axes_lens"] = [1024, 1664, 1664]
|
||||
dit_config["ffn_dim_multiplier"] = None
|
||||
dit_config["hidden_size"] = 2520
|
||||
dit_config["in_channels"] = 16
|
||||
dit_config["multiple_of"] = 256
|
||||
dit_config["norm_eps"] = 1e-05
|
||||
dit_config["num_attention_heads"] = 21
|
||||
dit_config["num_kv_heads"] = 7
|
||||
dit_config["num_layers"] = 32
|
||||
dit_config["num_refiner_layers"] = 2
|
||||
dit_config["out_channels"] = None
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["text_feat_dim"] = 2048
|
||||
dit_config["timestep_scale"] = 1000.0
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
|
@ -44,6 +44,7 @@ import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.hidream
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -754,6 +755,7 @@ class CLIPType(Enum):
|
||||
HIDREAM = 14
|
||||
CHROMA = 15
|
||||
ACE = 16
|
||||
OMNIGEN2 = 17
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@ -773,6 +775,7 @@ class TEModel(Enum):
|
||||
LLAMA3_8 = 7
|
||||
T5_XXL_OLD = 8
|
||||
GEMMA_2_2B = 9
|
||||
QWEN25_3B = 10
|
||||
|
||||
def detect_te_model(sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
@ -793,6 +796,8 @@ def detect_te_model(sd):
|
||||
return TEModel.T5_BASE
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
return TEModel.GEMMA_2_2B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
return TEModel.QWEN25_3B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.LLAMA3_8
|
||||
return None
|
||||
@ -894,6 +899,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
elif te_model == TEModel.QWEN25_3B:
|
||||
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
|
@ -482,7 +482,8 @@ class SDTokenizer:
|
||||
if end_token is not None:
|
||||
self.end_token = end_token
|
||||
else:
|
||||
self.end_token = empty[0]
|
||||
if has_end_token:
|
||||
self.end_token = empty[0]
|
||||
|
||||
if pad_token is not None:
|
||||
self.pad_token = pad_token
|
||||
|
@ -18,6 +18,7 @@ import comfy.text_encoders.cosmos
|
||||
import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -1181,6 +1182,36 @@ class ACEStep(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
|
||||
class Omnigen2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "omnigen2",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 2.6,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.65 #TODO
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Omnigen2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
@ -24,6 +24,24 @@ class Llama2Config:
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
|
||||
@dataclass
|
||||
class Qwen25_3BConfig:
|
||||
vocab_size: int = 151936
|
||||
hidden_size: int = 2048
|
||||
intermediate_size: int = 11008
|
||||
num_hidden_layers: int = 36
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 2
|
||||
max_position_embeddings: int = 128000
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = True
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
@ -40,6 +58,7 @@ class Gemma2_2B_Config:
|
||||
head_dim = 256
|
||||
rms_norm_add = True
|
||||
mlp_activation = "gelu_pytorch_tanh"
|
||||
qkv_bias = False
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||
@ -98,9 +117,9 @@ class Attention(nn.Module):
|
||||
self.inner_size = self.num_heads * self.head_dim
|
||||
|
||||
ops = ops or nn
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
@ -320,6 +339,14 @@ class Llama2(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen25_3BConfig(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
|
44
comfy/text_encoders/omnigen2.py
Normal file
44
comfy/text_encoders/omnigen2.py
Normal file
@ -0,0 +1,44 @@
|
||||
from transformers import Qwen2Tokenizer
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.llama
|
||||
import os
|
||||
|
||||
|
||||
class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
|
||||
self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
|
||||
|
||||
class Qwen25_3BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class Omnigen2Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
class Omnigen2TEModel_(Omnigen2Model):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Omnigen2TEModel_
|
151388
comfy/text_encoders/qwen25_tokenizer/merges.txt
Normal file
151388
comfy/text_encoders/qwen25_tokenizer/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
241
comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json
Normal file
241
comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json
Normal file
@ -0,0 +1,241 @@
|
||||
{
|
||||
"add_bos_token": false,
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"151643": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151644": {
|
||||
"content": "<|im_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151645": {
|
||||
"content": "<|im_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151646": {
|
||||
"content": "<|object_ref_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151647": {
|
||||
"content": "<|object_ref_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151648": {
|
||||
"content": "<|box_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151649": {
|
||||
"content": "<|box_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151650": {
|
||||
"content": "<|quad_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151651": {
|
||||
"content": "<|quad_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151652": {
|
||||
"content": "<|vision_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151653": {
|
||||
"content": "<|vision_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151654": {
|
||||
"content": "<|vision_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151655": {
|
||||
"content": "<|image_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151656": {
|
||||
"content": "<|video_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151657": {
|
||||
"content": "<tool_call>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151658": {
|
||||
"content": "</tool_call>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151659": {
|
||||
"content": "<|fim_prefix|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151660": {
|
||||
"content": "<|fim_middle|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151661": {
|
||||
"content": "<|fim_suffix|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151662": {
|
||||
"content": "<|fim_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151663": {
|
||||
"content": "<|repo_name|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151664": {
|
||||
"content": "<|file_sep|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151665": {
|
||||
"content": "<|img|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151666": {
|
||||
"content": "<|endofimg|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151667": {
|
||||
"content": "<|meta|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151668": {
|
||||
"content": "<|endofmeta|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
"<|object_ref_start|>",
|
||||
"<|object_ref_end|>",
|
||||
"<|box_start|>",
|
||||
"<|box_end|>",
|
||||
"<|quad_start|>",
|
||||
"<|quad_end|>",
|
||||
"<|vision_start|>",
|
||||
"<|vision_end|>",
|
||||
"<|vision_pad|>",
|
||||
"<|image_pad|>",
|
||||
"<|video_pad|>"
|
||||
],
|
||||
"bos_token": null,
|
||||
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "<|im_end|>",
|
||||
"errors": "replace",
|
||||
"extra_special_tokens": {},
|
||||
"model_max_length": 131072,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"processor_class": "Qwen2_5_VLProcessor",
|
||||
"split_special_tokens": false,
|
||||
"tokenizer_class": "Qwen2Tokenizer",
|
||||
"unk_token": null
|
||||
}
|
1
comfy/text_encoders/qwen25_tokenizer/vocab.json
Normal file
1
comfy/text_encoders/qwen25_tokenizer/vocab.json
Normal file
File diff suppressed because one or more lines are too long
26
comfy_extras/nodes_edit_model.py
Normal file
26
comfy_extras/nodes_edit_model.py
Normal file
@ -0,0 +1,26 @@
|
||||
import node_helpers
|
||||
|
||||
|
||||
class ReferenceLatent:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
},
|
||||
"optional": {"latent": ("LATENT", ),}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "append"
|
||||
|
||||
CATEGORY = "advanced/conditioning/edit_models"
|
||||
DESCRIPTION = "This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images."
|
||||
|
||||
def append(self, conditioning, latent=None):
|
||||
if latent is not None:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True)
|
||||
return (conditioning, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ReferenceLatent": ReferenceLatent,
|
||||
}
|
5
nodes.py
5
nodes.py
@ -920,7 +920,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@ -930,7 +930,7 @@ class CLIPLoader:
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5"
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
@ -2279,6 +2279,7 @@ def init_builtin_extra_nodes():
|
||||
"nodes_ace.py",
|
||||
"nodes_string.py",
|
||||
"nodes_camera_trajectory.py",
|
||||
"nodes_edit_model.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user