import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from comfy.ldm.modules.diffusionmodules.mmdit import (
    DismantledBlock,
    PatchEmbed,
    VectorEmbedder,
    TimestepEmbedder,
    get_2d_sincos_pos_embed_torch,
)


class ControlNetEmbedder(nn.Module):
    """Embed a latent plus a control hint and emit per-block control outputs.

    The module patch-embeds the input, adds the patch-embedded ``hint``,
    runs the result through a stack of ``DismantledBlock`` layers conditioned
    on a timestep (and optional ``y``) embedding, and projects every block's
    output through its own linear layer.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        attention_head_dim: int,
        num_attention_heads: int,
        adm_in_channels: int,
        num_layers: int,
        main_model_double: int,
        double_y_emb: bool,
        device: torch.device,
        dtype: torch.dtype,
        pos_embed_max_size: Optional[int] = None,
        operations=None,
    ):
        """Build the embedders, transformer blocks and output projections.

        Args:
            img_size: Forwarded to both ``PatchEmbed`` instances.
            patch_size: Patch edge length; also used in ``forward`` to size
                the sin/cos positional grid.
            in_chans: Input channel count for both ``PatchEmbed`` instances.
            attention_head_dim: Per-head width; the hidden size is
                ``num_attention_heads * attention_head_dim``.
            num_attention_heads: Number of heads in every block.
            adm_in_channels: Input width of the first ``y`` embedder.
            num_layers: Number of transformer blocks (and of output
                projections).
            main_model_double: Number of main-model blocks to feed; used in
                ``forward`` to decide how often each output is repeated.
            double_y_emb: If true, ``y`` passes through ``orig_y_embedder``
                and then ``y_embedder``; ``forward`` also changes behaviour
                (see there).
            device: Device passed to every sub-module.
            dtype: Dtype passed to every sub-module.
            pos_embed_max_size: Only checked against ``None``: the main
                patch embedder uses strict image sizes when it is ``None``.
            operations: Namespace providing layer factories.
                ``operations.Linear`` is called here whenever
                ``num_layers > 0``, so a real value is required despite the
                ``None`` default.
        """
        super().__init__()
        self.main_model_double = main_model_double
        self.dtype = dtype
        self.hidden_size = num_attention_heads * attention_head_dim
        self.patch_size = patch_size

        self.x_embedder = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=pos_embed_max_size is None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.t_embedder = TimestepEmbedder(
            self.hidden_size, dtype=dtype, device=device, operations=operations
        )

        self.double_y_emb = double_y_emb
        if self.double_y_emb:
            # Two-stage embedding: adm_in_channels -> hidden -> hidden.
            self.orig_y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )
            self.y_embedder = VectorEmbedder(
                self.hidden_size, self.hidden_size, dtype, device, operations=operations
            )
        else:
            self.y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )

        self.transformer_blocks = nn.ModuleList(
            DismantledBlock(
                hidden_size=self.hidden_size,
                num_heads=num_attention_heads,
                qkv_bias=True,
                dtype=dtype,
                device=device,
                operations=operations,
            )
            for _ in range(num_layers)
        )

        # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
        # TODO double check this logic when 8b
        self.use_y_embedder = True

        # One hidden -> hidden projection per transformer block.
        self.controlnet_blocks = nn.ModuleList(
            [
                operations.Linear(
                    self.hidden_size, self.hidden_size, dtype=dtype, device=device
                )
                for _ in range(len(self.transformer_blocks))
            ]
        )

        # Separate patch embedder for the control hint (never strict on size).
        self.pos_embed_input = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint=None,
    ) -> Dict[str, Tuple[Tensor, ...]]:
        """Compute the control outputs for one step.

        Args:
            x: Input tensor; its last two dimensions are read as height and
                width before patch embedding.
            timesteps: Passed to the timestep embedder.
            y: Optional conditioning vector; when given, its embedding is
                added to the timestep embedding.
            context: Accepted for interface compatibility; not used here.
            hint: Control input, patch-embedded by ``pos_embed_input`` and
                added to the embedded ``x``. NOTE(review): presumably must
                have the same spatial layout as ``x`` -- confirm with caller.

        Returns:
            ``{"output": outputs}`` where ``outputs`` is a tuple holding each
            block's projected output repeated
            ``ceil(main_model_double / num_blocks)`` times in a row.
        """
        height, width = x.shape[-2:]
        x = self.x_embedder(x)

        if not self.double_y_emb:
            h = (height + 1) // self.patch_size
            w = (width + 1) // self.patch_size
            # In-place add on purpose: the sum keeps x's own dtype.
            x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)

        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            if self.double_y_emb:
                y = self.orig_y_embedder(y)
            y = self.y_embedder(y)
            c = c + y

        x = x + self.pos_embed_input(hint)

        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))

        outputs: List[Tensor] = []
        for block, control_proj in zip(self.transformer_blocks, self.controlnet_blocks):
            out = block(x, c)
            if not self.double_y_emb:
                # Chain the blocks; with double_y_emb every block instead
                # sees the same embedded input.
                x = out
            # The same projected tensor object is referenced `repeat` times.
            outputs.extend((control_proj(out),) * repeat)

        return {"output": tuple(outputs)}