# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI.

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Union, Tuple, List, Callable, Optional

import numpy as np
from einops import repeat, rearrange
from tqdm import tqdm
import logging

import comfy.ops
ops = comfy.ops.disable_weight_init


def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    """Build a dense, axis-aligned grid of 3D sample points inside a box.

    Each axis is sampled with ``octree_resolution + 1`` evenly spaced float32
    values running from ``bbox_min`` to ``bbox_max`` (both ends included).

    Args:
        bbox_min: the three lower corner coordinates of the box.
        bbox_max: the three upper corner coordinates of the box.
        octree_resolution: number of cells per axis.
        indexing: forwarded to ``np.meshgrid`` (``"ij"`` or ``"xy"``).

    Returns:
        A tuple ``(xyz, grid_size, length)`` where ``xyz`` stacks the meshgrid
        coordinates on a trailing axis of size 3, ``grid_size`` is the list of
        points per axis and ``length`` is ``bbox_max - bbox_min``.
    """
    length = bbox_max - bbox_min
    # A grid with N cells along an axis has N + 1 vertices along that axis.
    points_per_axis = int(octree_resolution) + 1

    x = np.linspace(bbox_min[0], bbox_max[0], points_per_axis, dtype=np.float32)
    y = np.linspace(bbox_min[1], bbox_max[1], points_per_axis, dtype=np.float32)
    z = np.linspace(bbox_min[2], bbox_max[2], points_per_axis, dtype=np.float32)
    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1)
    grid_size = [points_per_axis, points_per_axis, points_per_axis]

    return xyz, grid_size, length


class VanillaVolumeDecoder:
    """Evaluate a geometry decoder on a dense grid, chunk by chunk."""

    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        octree_resolution: Optional[int] = None,
        enable_pbar: bool = True,
        **kwargs,
    ):
        """Decode ``latents`` into a dense grid of logits.

        Args:
            latents: latent tensor; its first dimension is used as the batch
                size and its device/dtype are used for the query points.
            geo_decoder: callable invoked as
                ``geo_decoder(queries=..., latents=...)`` for every chunk.
            bounds: either a scalar ``b`` (the box becomes ``[-b, b]`` on
                every axis) or a 6-element sequence
                ``[xmin, ymin, zmin, xmax, ymax, zmax]``.
            num_chunks: number of query points sent to ``geo_decoder`` per
                call.
            octree_resolution: number of grid cells per axis; required.
            enable_pbar: show the tqdm progress bar when True.
            **kwargs: accepted and ignored.

        Returns:
            A float32 tensor of shape ``(batch, R + 1, R + 1, R + 1)`` with
            ``R = octree_resolution``.

        Raises:
            TypeError: if ``octree_resolution`` is None.
        """
        if octree_resolution is None:
            # Fail early with a clear message instead of inside the grid helper.
            raise TypeError("octree_resolution must be an integer, got None")

        device = latents.device
        dtype = latents.dtype
        batch_size = latents.shape[0]

        # 1. generate query points
        # Accept ints as well as floats for a scalar bound: an int used to fall
        # through to the slicing below and crash with "not subscriptable".
        if isinstance(bounds, (int, float)):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=octree_resolution,
            indexing="ij"
        )
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)

        # 2. latents to 3d volume
        batch_logits = []
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding", disable=not enable_pbar):
            chunk_queries = xyz_samples[start: start + num_chunks, :]
            # Every batch element is queried at the same set of points.
            chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
            logits = geo_decoder(queries=chunk_queries, latents=latents)
            batch_logits.append(logits)

        # Chunks were taken along the point axis, so stitch them back on dim 1.
        grid_logits = torch.cat(batch_logits, dim=1)
        grid_logits = grid_logits.view((batch_size, *grid_size)).float()

        return grid_logits


class FourierEmbedder(nn.Module):
    """The sin/cosine positional embedding.

    Given an input tensor ``x`` of shape ``[..., input_dim]``, every feature
    ``x[..., i]`` is multiplied by each frequency ``f_0 ... f_{N-1}`` and the
    result is embedded as::

        [x,                                  # only if include_input is True
         sin(f_0 * x_0), ..., sin(f_{N-1} * x_0), sin(f_0 * x_1), ...,
         cos(f_0 * x_0), ..., cos(f_{N-1} * x_0), cos(f_0 * x_1), ...]

    If ``logspace`` is True the frequencies are ``2^0, 2^1, ..., 2^(N-1)``;
    otherwise they are ``N`` values linearly spaced between ``1.0`` and
    ``2^(N-1)``. When ``include_pi`` is True every frequency is additionally
    multiplied by pi.

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): powers-of-two frequencies if True, linearly spaced
            frequencies otherwise;
        input_dim (int): the input dimension, default is 3;
        include_input (bool): prepend the raw input to the embedding or not,
            default is True;
        include_pi (bool): scale the frequencies by pi, default is True.

    Attributes:
        frequencies (torch.Tensor): non-persistent buffer holding the
            frequencies described above;
        out_dim (int): the embedding size. It is
            ``input_dim * (num_freqs * 2 + 1)`` when ``include_input`` is True
            (or when ``num_freqs`` is 0, in which case the input is returned
            unchanged), otherwise ``input_dim * num_freqs * 2``.
    """

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:
        """Build the frequency buffer and compute the output width."""
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                num_freqs,
                dtype=torch.float32
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (num_freqs - 1),
                num_freqs,
                dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        # Not persistent: the buffer is rebuilt from the constructor arguments
        # and is therefore not stored in (or expected from) the state dict.
        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        """Return the size of the last dimension produced by ``forward``."""
        # With zero frequencies forward() returns the input itself, so the
        # input counts towards the width even if include_input is False.
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x``.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape
                [..., dim * (num_freqs * 2 + temp)] where temp is 1 if
                include_input is True and 0 otherwise. When ``num_freqs`` is 0
                the input is returned unchanged.
        """
        if self.num_freqs > 0:
            # Outer product of every feature with every frequency, flattened
            # back into the last dimension.
            embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x


class CrossAttentionProcessor:
    """Default attention kernel: plain scaled dot-product attention."""

    def __call__(self, attn, q, k, v):
        """Return ``F.scaled_dot_product_attention(q, k, v)``; ``attn`` is unused."""
        out = F.scaled_dot_product_attention(q, k, v)
        return out


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: probability of zeroing a whole sample.
        scale_by_keep: rescale surviving samples by ``1 / (1 - drop_prob)``.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Randomly zero whole samples of ``x`` while training.

        The input is returned untouched when ``drop_prob`` is 0 or the module
        is in eval mode. The technique is also known (misleadingly) as
        "DropConnect" in some EfficientNet implementations; see
        https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
        for the naming discussion.
        """
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims, so
        # this works with tensors of any rank, not just 2D ConvNets.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'


class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear -> DropPath.

    Args:
        width: input feature size.
        expand_ratio: the hidden size is ``width * expand_ratio``.
        output_width: output feature size; defaults to ``width``.
        drop_path_rate: stochastic-depth rate; 0 disables it.
    """

    def __init__(
        self, *,
        width: int,
        expand_ratio: int = 4,
        output_width: int = None,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.c_fc = ops.Linear(width, width * expand_ratio)
        self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
        self.gelu = nn.GELU()
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))


class QKVMultiheadCrossAttention(nn.Module):
    """Multi-head cross attention over an already projected query and packed key/value.

    Args:
        heads: number of attention heads.
        width: model width; only used to size the optional q/k norms.
        qk_norm: apply ``norm_layer`` to the per-head queries and keys.
        norm_layer: normalisation layer class used when ``qk_norm`` is True.
    """

    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

        self.attn_processor = CrossAttentionProcessor()

    def forward(self, q, kv):
        """Attend from ``q`` to ``kv``.

        Args:
            q: rank-3 query tensor ``(batch, n_ctx, channels)``.
            kv: rank-3 tensor ``(batch, n_data, channels)`` whose last
                dimension packs, for every head, the key followed by the
                value.

        Returns:
            Tensor of shape ``(batch, n_ctx, heads * head_dim)``.
        """
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        # The packed width holds keys and values, hence the division by 2.
        attn_ch = width // self.heads // 2
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        # Move the head axis in front of the sequence axis for the kernel.
        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = self.attn_processor(self, q, k, v)
        out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
        return out


class MultiheadCrossAttention(nn.Module):
    """Cross attention with its own query, key/value and output projections.

    Args:
        width: query / output feature size.
        heads: number of attention heads.
        qkv_bias: add a bias to the query and key/value projections.
        data_width: feature size of the attended data; defaults to ``width``.
        norm_layer: normalisation layer class for the optional q/k norms.
        qk_norm: normalise per-head queries and keys.
        kv_cache: when True the projected key/value of the first call is
            stored and reused by every later call.
    """

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        data_width: Optional[int] = None,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        kv_cache: bool = False,
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.data_width = width if data_width is None else data_width
        self.c_q = ops.Linear(width, width, bias=qkv_bias)
        self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.kv_cache = kv_cache
        # Holds the cached key/value projection when kv_cache is enabled.
        self.data = None

    def forward(self, x, data):
        """Project ``x`` to queries, ``data`` to keys/values and attend."""
        x = self.c_q(x)
        if self.kv_cache:
            # The cache is filled once and never invalidated here, so later
            # calls ignore their ``data`` argument.
            if self.data is None:
                self.data = self.c_kv(data)
                logging.info('Save kv cache,this should be called only once for one mesh')
            data = self.data
        else:
            data = self.c_kv(data)
        x = self.attention(x, data)
        x = self.c_proj(x)
        return x


class ResidualCrossAttentionBlock(nn.Module):
    """Pre-norm residual block: cross attention followed by an MLP.

    Args:
        width: query feature size.
        heads: number of attention heads.
        mlp_expand_ratio: hidden expansion of the MLP.
        data_width: feature size of the attended data; defaults to ``width``.
        qkv_bias: add a bias to the attention projections.
        norm_layer: normalisation layer class.
        qk_norm: normalise per-head queries and keys.
    """

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        data_width: Optional[int] = None,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False
    ):
        super().__init__()

        if data_width is None:
            data_width = width

        self.attn = MultiheadCrossAttention(
            width=width,
            heads=heads,
            data_width=data_width,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        """Update ``x`` by attending to ``data``, then apply the MLP."""
        x = x + self.attn(self.ln_1(x), self.ln_2(data))
        x = x + self.mlp(self.ln_3(x))
        return x


class QKVMultiheadAttention(nn.Module):
    """Multi-head self attention over a packed query/key/value tensor.

    Args:
        heads: number of attention heads.
        width: model width; only used to size the optional q/k norms.
        qk_norm: apply ``norm_layer`` to the per-head queries and keys.
        norm_layer: normalisation layer class used when ``qk_norm`` is True.
    """

    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

    def forward(self, qkv):
        """Run self attention.

        Args:
            qkv: rank-3 tensor ``(batch, n_ctx, channels)`` whose last
                dimension packs, for every head, query, key and value in that
                order.

        Returns:
            Tensor of shape ``(batch, n_ctx, heads * head_dim)``.
        """
        bs, n_ctx, width = qkv.shape
        # The packed width holds q, k and v, hence the division by 3.
        attn_ch = width // self.heads // 3
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # Move the head axis in front of the sequence axis for the kernel.
        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
        return out


class MultiheadAttention(nn.Module):
    """Self attention with fused qkv projection, output projection and DropPath.

    Args:
        width: feature size.
        heads: number of attention heads.
        qkv_bias: add a bias to the fused qkv projection.
        norm_layer: normalisation layer class for the optional q/k norms.
        qk_norm: normalise per-head queries and keys.
        drop_path_rate: stochastic-depth rate; 0 disables it.
    """

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        x = self.c_qkv(x)
        x = self.attention(x)
        x = self.drop_path(self.c_proj(x))
        return x


class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self attention followed by an MLP.

    Args:
        width: feature size.
        heads: number of attention heads.
        qkv_bias: add a bias to the fused qkv projection.
        norm_layer: normalisation layer class.
        qk_norm: normalise per-head queries and keys.
        drop_path_rate: stochastic-depth rate used by both sub-layers.
    """

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.attn = MultiheadAttention(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    """A stack of ``layers`` identical ``ResidualAttentionBlock`` modules.

    Args:
        width: feature size.
        layers: number of blocks.
        heads: number of attention heads per block.
        qkv_bias: add a bias to the fused qkv projections.
        norm_layer: normalisation layer class.
        qk_norm: normalise per-head queries and keys.
        drop_path_rate: stochastic-depth rate passed to every block.
    """

    def __init__(
        self,
        *,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width=width,
                    heads=heads,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    qk_norm=qk_norm,
                    drop_path_rate=drop_path_rate
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        """Apply every block in order."""
        for block in self.resblocks:
            x = block(x)
        return x


class CrossAttentionDecoder(nn.Module):
    """Decode per-point outputs by cross attending from query points to latents.

    Args:
        out_channels: number of output channels per query point.
        fourier_embedder: positional embedder applied to raw queries.
        width: feature size of the decoder.
        heads: number of attention heads.
        mlp_expand_ratio: hidden expansion of the MLP in the attention block.
        downsample_ratio: when different from 1, latents of size
            ``width * downsample_ratio`` are first projected down to ``width``.
        enable_ln_post: apply a LayerNorm before the output projection. When
            False, ``qk_norm`` is forced to False as well.
        qkv_bias: add a bias to the attention projections.
        qk_norm: normalise per-head queries and keys.
        label_type: stored on the module; not otherwise used here.
    """

    def __init__(
        self,
        *,
        out_channels: int,
        fourier_embedder: FourierEmbedder,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        downsample_ratio: int = 1,
        enable_ln_post: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary"
    ):
        super().__init__()

        self.enable_ln_post = enable_ln_post
        self.fourier_embedder = fourier_embedder
        self.downsample_ratio = downsample_ratio
        self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
        if self.downsample_ratio != 1:
            self.latents_proj = ops.Linear(width * downsample_ratio, width)
        if not self.enable_ln_post:
            # Variants without the post LayerNorm also run without q/k norms.
            qk_norm = False
        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            width=width,
            mlp_expand_ratio=mlp_expand_ratio,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm
        )

        if self.enable_ln_post:
            self.ln_post = ops.LayerNorm(width)
        self.output_proj = ops.Linear(width, out_channels)
        self.label_type = label_type
        # Running total of query points processed by forward().
        self.count = 0

    def forward(self, queries=None, query_embeddings=None, latents=None):
        """Decode the given query points against ``latents``.

        Args:
            queries: raw query points; embedded and projected when
                ``query_embeddings`` is not supplied.
            query_embeddings: pre-computed query embeddings; takes precedence
                over ``queries``.
            latents: latent set to attend to.

        Returns:
            The output of ``output_proj`` for every query point.
        """
        if query_embeddings is None:
            query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
        self.count += query_embeddings.shape[1]
        if self.downsample_ratio != 1:
            latents = self.latents_proj(latents)
        x = self.cross_attn_decoder(query_embeddings, latents)
        if self.enable_ln_post:
            x = self.ln_post(x)
        occ = self.output_proj(x)
        return occ


class ShapeVAE(nn.Module):
    """Decoder half of the Hunyuan3D shape VAE.

    Latents are projected to the model width, refined by a transformer and
    then queried on a dense grid through a cross-attention geometry decoder.

    Args:
        embed_dim: channel size of the incoming latents.
        width: transformer feature size.
        heads: number of attention heads in the transformer.
        num_decoder_layers: number of transformer blocks.
        geo_decoder_downsample_ratio: the geometry decoder runs with
            ``width`` and ``heads`` integer-divided by this ratio.
        geo_decoder_mlp_expand_ratio: MLP expansion inside the geometry decoder.
        geo_decoder_ln_post: enable the post LayerNorm of the geometry decoder.
        num_freqs: number of Fourier frequencies for query points.
        include_pi: scale the Fourier frequencies by pi.
        qkv_bias: add a bias to the attention projections.
        qk_norm: normalise per-head queries and keys.
        label_type: forwarded to the geometry decoder.
        drop_path_rate: stochastic-depth rate of the transformer.
        scale_factor: stored on the module; not used by the code in this file.
    """

    def __init__(
        self,
        *,
        embed_dim: int,
        width: int,
        heads: int,
        num_decoder_layers: int,
        geo_decoder_downsample_ratio: int = 1,
        geo_decoder_mlp_expand_ratio: int = 4,
        geo_decoder_ln_post: bool = True,
        num_freqs: int = 8,
        include_pi: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
        drop_path_rate: float = 0.0,
        scale_factor: float = 1.0,
    ):
        super().__init__()
        self.geo_decoder_ln_post = geo_decoder_ln_post

        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        self.post_kl = ops.Linear(embed_dim, width)

        self.transformer = Transformer(
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )

        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
            downsample_ratio=geo_decoder_downsample_ratio,
            enable_ln_post=self.geo_decoder_ln_post,
            width=width // geo_decoder_downsample_ratio,
            heads=heads // geo_decoder_downsample_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            label_type=label_type,
        )

        self.volume_decoder = VanillaVolumeDecoder()
        self.scale_factor = scale_factor

    def decode(self, latents, **kwargs):
        """Decode latents into a dense grid of logits.

        Args:
            latents: latent tensor; its last two dimensions are swapped before
                the ``post_kl`` projection, so the channel axis is expected at
                position -2 (presumably ``(batch, embed_dim, tokens)`` -
                confirm against the caller).
            **kwargs: optional ``bounds`` (default 1.01), ``num_chunks``
                (default 8000), ``octree_resolution`` (default 256) and
                ``enable_pbar`` (default True), forwarded to the volume
                decoder.

        Returns:
            The grid logits with their last two dimensions swapped.
        """
        latents = self.post_kl(latents.movedim(-2, -1))
        latents = self.transformer(latents)

        bounds = kwargs.get("bounds", 1.01)
        num_chunks = kwargs.get("num_chunks", 8000)
        octree_resolution = kwargs.get("octree_resolution", 256)
        enable_pbar = kwargs.get("enable_pbar", True)

        grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
        return grid_logits.movedim(-2, -1)

    def encode(self, x):
        """Encoding is not implemented; always returns None."""
        return None