# Original code: https://github.com/VectorSpaceLab/OmniGen2 from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from comfy.ldm.lightricks.model import Timesteps from comfy.ldm.flux.layers import EmbedND from comfy.ldm.modules.attention import optimized_attention_masked import comfy.model_management import comfy.ldm.common_dit def apply_rotary_emb(x, freqs_cis): if x.shape[1] == 0: return x t_ = x.reshape(*x.shape[:-1], -1, 1, 2) t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] return t_out.reshape(*x.shape).to(dtype=x.dtype) def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return F.silu(x) * y class TimestepEmbedding(nn.Module): def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None): super().__init__() self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device) self.act = nn.SiLU() self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device) def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = self.linear_1(sample) sample = self.act(sample) sample = self.linear_2(sample) return sample class LuminaRMSNormZero(nn.Module): def __init__(self, embedding_dim: int, norm_eps: float = 1e-5, dtype=None, device=None, operations=None): super().__init__() self.silu = nn.SiLU() self.linear = operations.Linear(min(embedding_dim, 1024), 4 * embedding_dim, dtype=dtype, device=device) self.norm = operations.RMSNorm(embedding_dim, eps=norm_eps, dtype=dtype, device=device) def forward(self, x: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: emb = self.linear(self.silu(emb)) scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) return x, gate_msa, scale_mlp, gate_mlp class LuminaLayerNormContinuous(nn.Module): def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine: bool = False, eps: float = 1e-6, out_dim: Optional[int] = None, dtype=None, device=None, operations=None): super().__init__() self.silu = nn.SiLU() self.linear_1 = operations.Linear(conditioning_embedding_dim, embedding_dim, dtype=dtype, device=device) self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine, dtype=dtype, device=device) self.linear_2 = operations.Linear(embedding_dim, out_dim, bias=True, dtype=dtype, device=device) if out_dim is not None else None def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) x = self.norm(x) * (1 + emb)[:, None, :] if self.linear_2 is not None: x = self.linear_2(x) return x class LuminaFeedForward(nn.Module): def __init__(self, dim: int, inner_dim: int, multiple_of: int = 256, dtype=None, device=None, operations=None): super().__init__() inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) self.linear_1 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device) self.linear_2 = operations.Linear(inner_dim, dim, bias=False, dtype=dtype, device=device) self.linear_3 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device) def forward(self, x: torch.Tensor) -> torch.Tensor: h1, h2 = self.linear_1(x), self.linear_3(x) return self.linear_2(swiglu(h1, h2)) class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): def __init__(self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, dtype=None, device=None, operations=None): super().__init__() self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale) self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024), dtype=dtype, device=device, operations=operations) self.caption_embedder = nn.Sequential( operations.RMSNorm(text_feat_dim, eps=norm_eps, dtype=dtype, device=device), operations.Linear(text_feat_dim, hidden_size, bias=True, dtype=dtype, device=device), ) def forward(self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]: timestep_proj = self.time_proj(timestep).to(dtype=dtype) time_embed = self.timestep_embedder(timestep_proj) caption_embed = self.caption_embedder(text_hidden_states) return time_embed, caption_embed class Attention(nn.Module): def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps: float = 1e-5, bias: bool = False, dtype=None, device=None, operations=None): super().__init__() self.heads = heads self.kv_heads = kv_heads self.dim_head = dim_head self.scale = dim_head ** -0.5 self.to_q = operations.Linear(query_dim, heads * dim_head, bias=bias, dtype=dtype, device=device) self.to_k = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device) self.to_v = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device) self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) self.to_out = nn.Sequential( operations.Linear(heads * dim_head, query_dim, bias=bias, dtype=dtype, device=device), nn.Dropout(0.0) ) def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape query = self.to_q(hidden_states) key = self.to_k(encoder_hidden_states) value = self.to_v(encoder_hidden_states) query = query.view(batch_size, -1, self.heads, self.dim_head) key = key.view(batch_size, -1, self.kv_heads, self.dim_head) value = value.view(batch_size, -1, self.kv_heads, self.dim_head) query = self.norm_q(query) key = self.norm_k(key) if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) if self.kv_heads < self.heads: key = key.repeat_interleave(self.heads // self.kv_heads, dim=1) value = value.repeat_interleave(self.heads // self.kv_heads, dim=1) hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True) hidden_states = self.to_out[0](hidden_states) return hidden_states class OmniGen2TransformerBlock(nn.Module): def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, dtype=None, device=None, operations=None): super().__init__() self.modulation = modulation self.attn = Attention( query_dim=dim, dim_head=dim // num_attention_heads, heads=num_attention_heads, kv_heads=num_kv_heads, eps=1e-5, bias=False, dtype=dtype, device=device, operations=operations, ) self.feed_forward = LuminaFeedForward( dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, dtype=dtype, device=device, operations=operations ) if modulation: self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations) else: self.norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) self.ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: if self.modulation: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) else: norm_hidden_states = self.norm1(hidden_states) attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) hidden_states = hidden_states + self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) hidden_states = hidden_states + self.ffn_norm2(mlp_output) return hidden_states class OmniGen2RotaryPosEmbed(nn.Module): def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2): super().__init__() self.theta = theta self.axes_dim = axes_dim self.axes_lens = axes_lens self.patch_size = patch_size self.rope_embedder = EmbedND(dim=sum(axes_dim), theta=self.theta, axes_dim=axes_dim) def forward(self, batch_size, encoder_seq_len, l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device): p = self.patch_size seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)] max_seq_len = max(seq_lengths) max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]) max_img_len = max(l_effective_img_len) position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3") pe_shift = cap_seq_len pe_shift_len = cap_seq_len if ref_img_sizes[i] is not None: for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]): H, W = ref_img_size ref_H_tokens, ref_W_tokens = H // p, W // p row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten() col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten() position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids pe_shift += max(ref_H_tokens, ref_W_tokens) pe_shift_len += ref_img_len H, W = img_sizes[i] H_tokens, W_tokens = H // p, W // p row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten() col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten() position_ids[i, pe_shift_len: seq_len, 0] = pe_shift position_ids[i, pe_shift_len: seq_len, 1] = row_ids position_ids[i, pe_shift_len: seq_len, 2] = col_ids freqs_cis = self.rope_embedder(position_ids).movedim(1, 2) cap_freqs_cis_shape = list(freqs_cis.shape) cap_freqs_cis_shape[1] = encoder_seq_len cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) ref_img_freqs_cis_shape = list(freqs_cis.shape) ref_img_freqs_cis_shape[1] = max_ref_img_len ref_img_freqs_cis = torch.zeros(*ref_img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) img_freqs_cis_shape = list(freqs_cis.shape) img_freqs_cis_shape[1] = max_img_len img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)): cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len] ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)] img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len] return cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths class OmniGen2Transformer2DModel(nn.Module): def __init__( self, patch_size: int = 2, in_channels: int = 16, out_channels: Optional[int] = None, hidden_size: int = 2304, num_layers: int = 26, num_refiner_layers: int = 2, num_attention_heads: int = 24, num_kv_heads: int = 8, multiple_of: int = 256, ffn_dim_multiplier: Optional[float] = None, norm_eps: float = 1e-5, axes_dim_rope: Tuple[int, int, int] = (32, 32, 32), axes_lens: Tuple[int, int, int] = (300, 512, 512), text_feat_dim: int = 1024, timestep_scale: float = 1.0, image_model=None, device=None, dtype=None, operations=None, ): super().__init__() self.patch_size = patch_size self.out_channels = out_channels or in_channels self.hidden_size = hidden_size self.dtype = dtype self.rope_embedder = OmniGen2RotaryPosEmbed( theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size, ) self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device) self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device) self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( hidden_size=hidden_size, text_feat_dim=text_feat_dim, norm_eps=norm_eps, timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations ) self.noise_refiner = nn.ModuleList([ OmniGen2TransformerBlock( hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations ) for _ in range(num_refiner_layers) ]) self.ref_image_refiner = nn.ModuleList([ OmniGen2TransformerBlock( hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations ) for _ in range(num_refiner_layers) ]) self.context_refiner = nn.ModuleList([ OmniGen2TransformerBlock( hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations ) for _ in range(num_refiner_layers) ]) self.layers = nn.ModuleList([ OmniGen2TransformerBlock( hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations ) for _ in range(num_layers) ]) self.norm_out = LuminaLayerNormContinuous( embedding_dim=hidden_size, conditioning_embedding_dim=min(hidden_size, 1024), elementwise_affine=False, eps=1e-6, out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations ) self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype)) def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states): batch_size = len(hidden_states) p = self.patch_size img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes] if ref_image_hidden_states is not None: ref_image_hidden_states = list(map(lambda ref: comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p)), ref_image_hidden_states)) ref_img_sizes = [[(imgs.size(2), imgs.size(3)) if imgs is not None else None for imgs in ref_image_hidden_states]] * batch_size l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes] else: ref_img_sizes = [None for _ in range(batch_size)] l_effective_ref_img_len = [[0] for _ in range(batch_size)] flat_ref_img_hidden_states = None if ref_image_hidden_states is not None: imgs = [] for ref_img in ref_image_hidden_states: B, C, H, W = ref_img.size() ref_img = rearrange(ref_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p) imgs.append(ref_img) flat_ref_img_hidden_states = torch.cat(imgs, dim=1) img = hidden_states B, C, H, W = img.size() flat_hidden_states = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p) return ( flat_hidden_states, flat_ref_img_hidden_states, None, None, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, ) def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb): batch_size = len(hidden_states) hidden_states = self.x_embedder(hidden_states) if ref_image_hidden_states is not None: ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states) image_index_embedding = comfy.model_management.cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device) for i in range(batch_size): shift = 0 for j, ref_img_len in enumerate(l_effective_ref_img_len[i]): ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + image_index_embedding[j] shift += ref_img_len for layer in self.noise_refiner: hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb) if ref_image_hidden_states is not None: for layer in self.ref_image_refiner: ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb) hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1) return hidden_states def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs): B, C, H, W = x.shape hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) _, _, H_padded, W_padded = hidden_states.shape timestep = 1.0 - timesteps text_hidden_states = context text_attention_mask = attention_mask ref_image_hidden_states = ref_latents device = hidden_states.device temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype) ( hidden_states, ref_image_hidden_states, img_mask, ref_img_mask, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states) ( context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb, rotary_emb, encoder_seq_lengths, seq_lengths, ) = self.rope_embedder( hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0], l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device, ) for layer in self.context_refiner: text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb) img_len = hidden_states.shape[1] combined_img_hidden_states = self.img_patch_embed_and_refine( hidden_states, ref_image_hidden_states, img_mask, ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, ) hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1) attention_mask = None for layer in self.layers: hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) hidden_states = self.norm_out(hidden_states, temb) p = self.patch_size output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded// p, p1=p, p2=p)[:, :, :H, :W] return -output