diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 622c1df54..556c39512 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -456,3 +456,13 @@ class Wan21(LatentFormat):
         latents_mean = self.latents_mean.to(latent.device, latent.dtype)
         latents_std = self.latents_std.to(latent.device, latent.dtype)
         return latent * latents_std / self.scale_factor + latents_mean
+
+class Hunyuan3Dv2(LatentFormat):
+    latent_channels = 64
+    latent_dimensions = 1
+    scale_factor = 0.9990943042622529
+
+class Hunyuan3Dv2mini(LatentFormat):
+    latent_channels = 64
+    latent_dimensions = 1
+    scale_factor = 1.0188137142395404
diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py
new file mode 100644
index 000000000..4e18358f0
--- /dev/null
+++ b/comfy/ldm/hunyuan3d/model.py
@@ -0,0 +1,135 @@
+import torch
+from torch import nn
+from comfy.ldm.flux.layers import (
+    DoubleStreamBlock,
+    LastLayer,
+    MLPEmbedder,
+    SingleStreamBlock,
+    timestep_embedding,
+)
+
+
+class Hunyuan3Dv2(nn.Module):
+    def __init__(
+        self,
+        in_channels=64,
+        context_in_dim=1536,
+        hidden_size=1024,
+        mlp_ratio=4.0,
+        num_heads=16,
+        depth=16,
+        depth_single_blocks=32,
+        qkv_bias=True,
+        guidance_embed=False,
+        image_model=None,
+        dtype=None,
+        device=None,
+        operations=None
+    ):
+        super().__init__()
+        self.dtype = dtype
+
+        if hidden_size % num_heads != 0:
+            raise ValueError(
+                f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
+            )
+
+        self.max_period = 1000  # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead
+        self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device)
+        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations)
+        self.guidance_in = (
+            MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None
+        )
+        self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device)
+        self.double_blocks = nn.ModuleList(
+            [
+                DoubleStreamBlock(
+                    hidden_size,
+                    num_heads,
+                    mlp_ratio=mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    dtype=dtype, device=device, operations=operations
+                )
+                for _ in range(depth)
+            ]
+        )
+        self.single_blocks = nn.ModuleList(
+            [
+                SingleStreamBlock(
+                    hidden_size,
+                    num_heads,
+                    mlp_ratio=mlp_ratio,
+                    dtype=dtype, device=device, operations=operations
+                )
+                for _ in range(depth_single_blocks)
+            ]
+        )
+        self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
+
+    def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
+        x = x.movedim(-1, -2)
+        timestep = 1.0 - timestep
+        txt = context
+        img = self.latent_in(x)
+
+        vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype))
+        if self.guidance_in is not None:
+            if guidance is not None:
+                vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype))
+
+        txt = self.cond_in(txt)
+        pe = None
+        attn_mask = None
+
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.double_blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"], out["txt"] = block(img=args["img"],
+                                                   txt=args["txt"],
+                                                   vec=args["vec"],
+                                                   pe=args["pe"],
+                                                   attn_mask=args.get("attn_mask"))
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": img,
+                                                           "txt": txt,
+                                                           "vec": vec,
+                                                           "pe": pe,
+                                                           "attn_mask": attn_mask},
+                                                          {"original_block": block_wrap})
+                txt = out["txt"]
+                img = out["img"]
+            else:
+                img, txt = block(img=img,
+                                 txt=txt,
+                                 vec=vec,
+                                 pe=pe,
+                                 attn_mask=attn_mask)
+
+        img = torch.cat((txt, img), 1)
+
+        for i, block in enumerate(self.single_blocks):
+            if ("single_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"],
+                                       vec=args["vec"],
+                                       pe=args["pe"],
+                                       attn_mask=args.get("attn_mask"))
+                    return out
+
+                out = blocks_replace[("single_block", i)]({"img": img,
+                                                           "vec": vec,
+                                                           "pe": pe,
+                                                           "attn_mask": attn_mask},
+                                                          {"original_block": block_wrap})
+                img = out["img"]
+            else:
+                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+
+        img = img[:, txt.shape[1]:, ...]
+        img = self.final_layer(img, vec)
+        return img.movedim(-2, -1) * (-1.0)
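For orientation, a minimal shape-contract sketch of the new DiT once this patch is applied (illustrative only: the reduced depths and zero tensors are placeholders, and comfy.ops.disable_weight_init is assumed to supply the Linear/LayerNorm ops the flux-style blocks expect, so weights are uninitialized and only the shapes are meaningful):

import torch
import comfy.ops
from comfy.ldm.hunyuan3d.model import Hunyuan3Dv2

model = Hunyuan3Dv2(depth=1, depth_single_blocks=2, operations=comfy.ops.disable_weight_init)
x = torch.zeros(1, 64, 3072)     # latent: [batch, latent_channels, tokens]
t = torch.ones(1)                # flow timestep in [0, 1]; the forward pass uses 1.0 - t
ctx = torch.zeros(1, 257, 1536)  # CLIP vision hidden states (context_in_dim=1536)
out = model(x, t, ctx)
print(out.shape)                 # torch.Size([1, 64, 3072]); the trailing * (-1.0) flips the flow direction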
diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py
new file mode 100644
index 000000000..311c9b416
--- /dev/null
+++ b/comfy/ldm/hunyuan3d/vae.py
@@ -0,0 +1,587 @@
+# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
+# Since the header on their VAE source file was a bit confusing, we asked Tencent for permission to use this code under the GPL license used in ComfyUI.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+from typing import Union, Tuple, List, Callable, Optional
+
+import numpy as np
+from einops import repeat, rearrange
+from tqdm import tqdm
+import logging
+
+import comfy.ops
+ops = comfy.ops.disable_weight_init
+
+def generate_dense_grid_points(
+    bbox_min: np.ndarray,
+    bbox_max: np.ndarray,
+    octree_resolution: int,
+    indexing: str = "ij",
+):
+    length = bbox_max - bbox_min
+    num_cells = octree_resolution
+
+    x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
+    y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
+    z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
+    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
+    xyz = np.stack((xs, ys, zs), axis=-1)
+    grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
+
+    return xyz, grid_size, length
+
+
+class VanillaVolumeDecoder:
+    @torch.no_grad()
+    def __call__(
+        self,
+        latents: torch.FloatTensor,
+        geo_decoder: Callable,
+        bounds: Union[Tuple[float], List[float], float] = 1.01,
+        num_chunks: int = 10000,
+        octree_resolution: int = None,
+        enable_pbar: bool = True,
+        **kwargs,
+    ):
+        device = latents.device
+        dtype = latents.dtype
+        batch_size = latents.shape[0]
+
+        # 1. generate query points
+        if isinstance(bounds, float):
+            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
+
+        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
+        xyz_samples, grid_size, length = generate_dense_grid_points(
+            bbox_min=bbox_min,
+            bbox_max=bbox_max,
+            octree_resolution=octree_resolution,
+            indexing="ij"
+        )
+        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
+
+        # 2. latents to 3d volume
+        batch_logits = []
+        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
+                          disable=not enable_pbar):
+            chunk_queries = xyz_samples[start: start + num_chunks, :]
+            chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
+            logits = geo_decoder(queries=chunk_queries, latents=latents)
+            batch_logits.append(logits)
+
+        grid_logits = torch.cat(batch_logits, dim=1)
+        grid_logits = grid_logits.view((batch_size, *grid_size)).float()
+
+        return grid_logits
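A quick sizing note on the dense decode above, using the defaults that ShapeVAE.decode and the decode node pass in (plain arithmetic, not a measurement):

octree_resolution = 256
points = (octree_resolution + 1) ** 3   # 16,974,593 query positions in the (257, 257, 257) grid
num_chunks = 8000
calls = -(-points // num_chunks)        # 2122 geo_decoder cross-attention passes per batch item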
+
+
+class FourierEmbedder(nn.Module):
+    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
+    each feature dimension of `x[..., i]` into:
+        [
+            sin(x[..., i]),
+            sin(f_1*x[..., i]),
+            sin(f_2*x[..., i]),
+            ...
+            sin(f_N * x[..., i]),
+            cos(x[..., i]),
+            cos(f_1*x[..., i]),
+            cos(f_2*x[..., i]),
+            ...
+            cos(f_N * x[..., i]),
+            x[..., i]     # only present if include_input is True.
+        ], here f_i is the frequency.
+
+    Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
+    If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
+    Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
+
+    Args:
+        num_freqs (int): the number of frequencies, default is 6;
+        logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
+            otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
+        input_dim (int): the input dimension, default is 3;
+        include_input (bool): include the input tensor or not, default is True.
+
+    Attributes:
+        frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
+            otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
+
+        out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
+            otherwise, it is input_dim * num_freqs * 2.
+
+    """
+
+    def __init__(self,
+                 num_freqs: int = 6,
+                 logspace: bool = True,
+                 input_dim: int = 3,
+                 include_input: bool = True,
+                 include_pi: bool = True) -> None:
+
+        """The initialization"""
+
+        super().__init__()
+
+        if logspace:
+            frequencies = 2.0 ** torch.arange(
+                num_freqs,
+                dtype=torch.float32
+            )
+        else:
+            frequencies = torch.linspace(
+                1.0,
+                2.0 ** (num_freqs - 1),
+                num_freqs,
+                dtype=torch.float32
+            )
+
+        if include_pi:
+            frequencies *= torch.pi
+
+        self.register_buffer("frequencies", frequencies, persistent=False)
+        self.include_input = include_input
+        self.num_freqs = num_freqs
+
+        self.out_dim = self.get_dims(input_dim)
+
+    def get_dims(self, input_dim):
+        temp = 1 if self.include_input or self.num_freqs == 0 else 0
+        out_dim = input_dim * (self.num_freqs * 2 + temp)
+
+        return out_dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """ Forward process.
+
+        Args:
+            x: tensor of shape [..., dim]
+
+        Returns:
+            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
+                where temp is 1 if include_input is True and 0 otherwise.
+        """
+
+        if self.num_freqs > 0:
+            embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
+            if self.include_input:
+                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
+            else:
+                return torch.cat((embed.sin(), embed.cos()), dim=-1)
+        else:
+            return x
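A worked instance of the embedding width, matching the configuration the ShapeVAE further down uses (num_freqs=8, 3-D query points, include_input=True):

emb = FourierEmbedder(num_freqs=8, include_pi=False)
print(emb.out_dim)   # 3 * (2 * 8 + 1) = 51, the in_features of the geo_decoder query_proj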
+
+
+class CrossAttentionProcessor:
+    def __call__(self, attn, q, k, v):
+        out = F.scaled_dot_product_attention(q, k, v)
+        return out
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+        This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+        the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+        changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+        'survival rate' as the argument.
+
+        """
+        if self.drop_prob == 0. or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+        if keep_prob > 0.0 and self.scale_by_keep:
+            random_tensor.div_(keep_prob)
+        return x * random_tensor
+
+    def extra_repr(self):
+        return f'drop_prob={round(self.drop_prob, 3):0.3f}'
+
+
+class MLP(nn.Module):
+    def __init__(
+        self, *,
+        width: int,
+        expand_ratio: int = 4,
+        output_width: int = None,
+        drop_path_rate: float = 0.0
+    ):
+        super().__init__()
+        self.width = width
+        self.c_fc = ops.Linear(width, width * expand_ratio)
+        self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
+        self.gelu = nn.GELU()
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+    def forward(self, x):
+        return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
+
+
+class QKVMultiheadCrossAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        heads: int,
+        width=None,
+        qk_norm=False,
+        norm_layer=ops.LayerNorm
+    ):
+        super().__init__()
+        self.heads = heads
+        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+
+        self.attn_processor = CrossAttentionProcessor()
+
+    def forward(self, q, kv):
+        _, n_ctx, _ = q.shape
+        bs, n_data, width = kv.shape
+        attn_ch = width // self.heads // 2
+        q = q.view(bs, n_ctx, self.heads, -1)
+        kv = kv.view(bs, n_data, self.heads, -1)
+        k, v = torch.split(kv, attn_ch, dim=-1)
+
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
+        out = self.attn_processor(self, q, k, v)
+        out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
+        return out
+
+
+class MultiheadCrossAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        width: int,
+        heads: int,
+        qkv_bias: bool = True,
+        data_width: Optional[int] = None,
+        norm_layer=ops.LayerNorm,
+        qk_norm: bool = False,
+        kv_cache: bool = False,
+    ):
+        super().__init__()
+        self.width = width
+        self.heads = heads
+        self.data_width = width if data_width is None else data_width
+        self.c_q = ops.Linear(width, width, bias=qkv_bias)
+        self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
+        self.c_proj = ops.Linear(width, width)
+        self.attention = QKVMultiheadCrossAttention(
+            heads=heads,
+            width=width,
+            norm_layer=norm_layer,
+            qk_norm=qk_norm
+        )
+        self.kv_cache = kv_cache
+        self.data = None
+
+    def forward(self, x, data):
+        x = self.c_q(x)
+        if self.kv_cache:
+            if self.data is None:
+                self.data = self.c_kv(data)
+                logging.info('Save kv cache, this should be called only once for one mesh')
+            data = self.data
+        else:
+            data = self.c_kv(data)
+        x = self.attention(x, data)
+        x = self.c_proj(x)
+        return x
+
+
+class ResidualCrossAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        width: int,
+        heads: int,
+        mlp_expand_ratio: int = 4,
+        data_width: Optional[int] = None,
+        qkv_bias: bool = True,
+        norm_layer=ops.LayerNorm,
+        qk_norm: bool = False
+    ):
+        super().__init__()
+
+        if data_width is None:
+            data_width = width
+
+        self.attn = MultiheadCrossAttention(
+            width=width,
+            heads=heads,
+            data_width=data_width,
+            qkv_bias=qkv_bias,
+            norm_layer=norm_layer,
+            qk_norm=qk_norm
+        )
+        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
+        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
+        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
+        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
+
+    def forward(self, x: torch.Tensor, data: torch.Tensor):
+        x = x + self.attn(self.ln_1(x), self.ln_2(data))
+        x = x + self.mlp(self.ln_3(x))
+        return x
+
+
+class QKVMultiheadAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        heads: int,
+        width=None,
+        qk_norm=False,
+        norm_layer=ops.LayerNorm
+    ):
+        super().__init__()
+        self.heads = heads
+        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+
+    def forward(self, qkv):
+        bs, n_ctx, width = qkv.shape
+        attn_ch = width // self.heads // 3
+        qkv = qkv.view(bs, n_ctx, self.heads, -1)
+        q, k, v = torch.split(qkv, attn_ch, dim=-1)
+
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
+        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
+        return out
+
+
+class MultiheadAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        width: int,
+        heads: int,
+        qkv_bias: bool,
+        norm_layer=ops.LayerNorm,
+        qk_norm: bool = False,
+        drop_path_rate: float = 0.0
+    ):
+        super().__init__()
+        self.width = width
+        self.heads = heads
+        self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
+        self.c_proj = ops.Linear(width, width)
+        self.attention = QKVMultiheadAttention(
+            heads=heads,
+            width=width,
+            norm_layer=norm_layer,
+            qk_norm=qk_norm
+        )
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+    def forward(self, x):
+        x = self.c_qkv(x)
+        x = self.attention(x)
+        x = self.drop_path(self.c_proj(x))
+        return x
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        width: int,
+        heads: int,
+        qkv_bias: bool = True,
+        norm_layer=ops.LayerNorm,
+        qk_norm: bool = False,
+        drop_path_rate: float = 0.0,
+    ):
+        super().__init__()
+        self.attn = MultiheadAttention(
+            width=width,
+            heads=heads,
+            qkv_bias=qkv_bias,
+            norm_layer=norm_layer,
+            qk_norm=qk_norm,
+            drop_path_rate=drop_path_rate
+        )
+        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
+        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
+        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attn(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        *,
+        width: int,
+        layers: int,
+        heads: int,
+        qkv_bias: bool = True,
+        norm_layer=ops.LayerNorm,
+        qk_norm: bool = False,
+        drop_path_rate: float = 0.0
+    ):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.ModuleList(
+            [
+                ResidualAttentionBlock(
+                    width=width,
+                    heads=heads,
+                    qkv_bias=qkv_bias,
+                    norm_layer=norm_layer,
+                    qk_norm=qk_norm,
+                    drop_path_rate=drop_path_rate
+                )
+                for _ in range(layers)
+            ]
+        )
+
+    def forward(self, x: torch.Tensor):
+        for block in self.resblocks:
+            x = block(x)
+        return x
+
+
+class CrossAttentionDecoder(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        out_channels: int,
+        fourier_embedder: FourierEmbedder,
+        width: int,
+        heads: int,
+        mlp_expand_ratio: int = 4,
+        downsample_ratio: int = 1,
+        enable_ln_post: bool = True,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
+        label_type: str = "binary"
+    ):
+        super().__init__()
+
+        self.enable_ln_post = enable_ln_post
+        self.fourier_embedder = fourier_embedder
+        self.downsample_ratio = downsample_ratio
+        self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
+        if self.downsample_ratio != 1:
+            self.latents_proj = ops.Linear(width * downsample_ratio, width)
+        if self.enable_ln_post == False:
+            qk_norm = False
+        self.cross_attn_decoder = ResidualCrossAttentionBlock(
+            width=width,
+            mlp_expand_ratio=mlp_expand_ratio,
+            heads=heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm
+        )
+
+        if self.enable_ln_post:
+            self.ln_post = ops.LayerNorm(width)
+        self.output_proj = ops.Linear(width, out_channels)
+        self.label_type = label_type
+        self.count = 0
+
+    def forward(self, queries=None, query_embeddings=None, latents=None):
+        if query_embeddings is None:
+            query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
+        self.count += query_embeddings.shape[1]
+        if self.downsample_ratio != 1:
+            latents = self.latents_proj(latents)
+        x = self.cross_attn_decoder(query_embeddings, latents)
+        if self.enable_ln_post:
+            x = self.ln_post(x)
+        occ = self.output_proj(x)
+        return occ
+
+
+class ShapeVAE(nn.Module):
+    def __init__(
+        self,
+        *,
+        embed_dim: int,
+        width: int,
+        heads: int,
+        num_decoder_layers: int,
+        geo_decoder_downsample_ratio: int = 1,
+        geo_decoder_mlp_expand_ratio: int = 4,
+        geo_decoder_ln_post: bool = True,
+        num_freqs: int = 8,
+        include_pi: bool = True,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
+        label_type: str = "binary",
+        drop_path_rate: float = 0.0,
+        scale_factor: float = 1.0,
+    ):
+        super().__init__()
+        self.geo_decoder_ln_post = geo_decoder_ln_post
+
+        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
+
+        self.post_kl = ops.Linear(embed_dim, width)
+
+        self.transformer = Transformer(
+            width=width,
+            layers=num_decoder_layers,
+            heads=heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            drop_path_rate=drop_path_rate
+        )
+
+        self.geo_decoder = CrossAttentionDecoder(
+            fourier_embedder=self.fourier_embedder,
+            out_channels=1,
+            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
+            downsample_ratio=geo_decoder_downsample_ratio,
+            enable_ln_post=self.geo_decoder_ln_post,
+            width=width // geo_decoder_downsample_ratio,
+            heads=heads // geo_decoder_downsample_ratio,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            label_type=label_type,
+        )
+
+        self.volume_decoder = VanillaVolumeDecoder()
+        self.scale_factor = scale_factor
+
+    def decode(self, latents, **kwargs):
+        latents = self.post_kl(latents.movedim(-2, -1))
+        latents = self.transformer(latents)
+
+        bounds = kwargs.get("bounds", 1.01)
+        num_chunks = kwargs.get("num_chunks", 8000)
+        octree_resolution = kwargs.get("octree_resolution", 256)
+        enable_pbar = kwargs.get("enable_pbar", True)
+
+        grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
+        return grid_logits
+
+    def encode(self, x):
+        return None
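A minimal decode sketch for the shape VAE (illustrative only: the explicit arguments mirror the ddconfig that comfy/sd.py builds below, the derived ratios are left at their defaults, and the weights are uninitialized, so only the shapes are meaningful; the tiny octree_resolution just keeps it cheap):

import torch
from comfy.ldm.hunyuan3d.vae import ShapeVAE

vae = ShapeVAE(embed_dim=64, width=1024, heads=16, num_decoder_layers=16,
               num_freqs=8, include_pi=False, qkv_bias=False, qk_norm=True)
latents = torch.zeros(1, 64, 3072)   # [batch, latent_channels, tokens], as produced by the sampler
grid = vae.decode(latents, octree_resolution=32, num_chunks=8000, enable_pbar=False)
print(grid.shape)                    # torch.Size([1, 33, 33, 33]) occupancy logits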
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 976702b60..f02406ace 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -36,6 +36,7 @@ import comfy.ldm.hunyuan_video.model
 import comfy.ldm.cosmos.model
 import comfy.ldm.lumina.model
 import comfy.ldm.wan.model
+import comfy.ldm.hunyuan3d.model
 
 import comfy.model_management
 import comfy.patcher_extension
@@ -1013,3 +1014,18 @@ class WAN21(BaseModel):
         if clip_vision_output is not None:
             out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
         return out
+
+class Hunyuan3Dv2(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        guidance = kwargs.get("guidance", 5.0)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+        return out
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 403da5855..f9e96ab7e 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -154,7 +154,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["guidance_embed"] = len(guidance_keys) > 0
         return dit_config
 
-    if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
+    if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
         dit_config = {}
         dit_config["image_model"] = "flux"
         dit_config["in_channels"] = 16
@@ -323,6 +323,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["model_type"] = "t2v"
         return dit_config
 
+    if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
+        in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
+        dit_config = {}
+        dit_config["image_model"] = "hunyuan3d2"
+        dit_config["in_channels"] = in_shape[1]
+        dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
+        dit_config["hidden_size"] = in_shape[0]
+        dit_config["mlp_ratio"] = 4.0
+        dit_config["num_heads"] = 16
+        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
+        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
+        dit_config["qkv_bias"] = True
+        dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+        return dit_config
+
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
         return None
 
diff --git a/comfy/sd.py b/comfy/sd.py
index 3d72a04d6..4160fa893 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model
 import comfy.ldm.lightricks.vae.causal_video_autoencoder
 import comfy.ldm.cosmos.vae
 import comfy.ldm.wan.vae
+import comfy.ldm.hunyuan3d.vae
 import yaml
 import math
 
@@ -412,6 +413,16 @@ class VAE:
                 self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
                 self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
                 self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+            elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
+                self.latent_dim = 1
+                ln_post = "geo_decoder.ln_post.weight" in sd
+                inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
+                downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
+                mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
+                self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
+                ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
+                self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
             else:
                 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
                 self.first_stage_model = None
@@ -498,7 +509,7 @@ class VAE:
             encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
             return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
 
-    def decode(self, samples_in):
+    def decode(self, samples_in, vae_options={}):
         self.throw_exception_if_invalid()
         pixel_samples = None
         try:
@@ -510,7 +521,7 @@ class VAE:
 
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
+                out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
                 if pixel_samples is None:
                     pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                 pixel_samples[x:x+batch_number] = out
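The vae_options plumbing added here is what lets a node hand decoder-specific settings through the generic VAE wrapper. A hedged illustration (assuming `vae` is a comfy.sd.VAE built from a Hunyuan3D shape-VAE checkpoint and `latent` is a sampled [batch, 64, tokens] tensor); this mirrors the call the VAEDecodeHunyuan3D node makes further down:

grid_logits = vae.decode(latent, vae_options={"num_chunks": 8000, "octree_resolution": 256})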
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index b4d7bfe20..b5c3194cf 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -959,6 +959,42 @@ class WAN21_I2V(WAN21_T2V):
         out = model_base.WAN21(self, image_to_video=True, device=device)
         return out
 
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
+class Hunyuan3Dv2(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "hunyuan3d2",
+    }
+
+    unet_extra_config = {}
+
+    sampling_settings = {
+        "multiplier": 1.0,
+        "shift": 1.0,
+    }
+
+    clip_vision_prefix = "conditioner.main_image_encoder.model."
+    vae_key_prefix = ["vae."]
+
+    latent_format = latent_formats.Hunyuan3Dv2
+
+    def process_unet_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "model."}
+        return utils.state_dict_prefix_replace(state_dict, replace_prefix)
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Hunyuan3Dv2(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        return None
+
+class Hunyuan3Dv2mini(Hunyuan3Dv2):
+    unet_config = {
+        "image_model": "hunyuan3d2",
+        "depth": 8,
+    }
+
+    latent_format = latent_formats.Hunyuan3Dv2mini
+
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
 
 models += [SVD_img2vid]
"generate" + + CATEGORY = "latent/3d" + + def generate(self, resolution, batch_size): + latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device()) + return ({"samples": latent, "type": "hunyuan3dv2"}, ) + + +class Hunyuan3Dv2Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, clip_vision_output): + embeds = clip_vision_output.last_hidden_state + positive = [[embeds, {}]] + negative = [[torch.zeros_like(embeds), {}]] + return (positive, negative) + + +class Hunyuan3Dv2ConditioningMultiView: + @classmethod + def INPUT_TYPES(s): + return {"required": {}, + "optional": {"front": ("CLIP_VISION_OUTPUT",), + "left": ("CLIP_VISION_OUTPUT",), + "back": ("CLIP_VISION_OUTPUT",), + "right": ("CLIP_VISION_OUTPUT",), }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, front=None, left=None, back=None, right=None): + all_embeds = [front, left, back, right] + out = [] + pos_embeds = None + for i, e in enumerate(all_embeds): + if e is not None: + if pos_embeds is None: + pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4)) + out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1)) + + embeds = torch.cat(out, dim=1) + positive = [[embeds, {}]] + negative = [[torch.zeros_like(embeds), {}]] + return (positive, negative) + + +class VOXEL: + def __init__(self, data): + self.data = data + + +class VAEDecodeHunyuan3D: + @classmethod + def INPUT_TYPES(s): + return {"required": {"samples": ("LATENT", ), + "vae": ("VAE", ), + "num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}), + "octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}), + }} + RETURN_TYPES = ("VOXEL",) + FUNCTION = "decode" + + CATEGORY = "latent/3d" + + def decode(self, vae, samples, num_chunks, octree_resolution): + voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) + return (voxels, ) + + +def voxel_to_mesh(voxels, threshold=0.5, device=None): + if device is None: + device = torch.device("cpu") + voxels = voxels.to(device) + + binary = (voxels > threshold).float() + padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0) + + D, H, W = binary.shape + + neighbors = torch.tensor([ + [0, 0, 1], + [0, 0, -1], + [0, 1, 0], + [0, -1, 0], + [1, 0, 0], + [-1, 0, 0] + ], device=device) + + z, y, x = torch.meshgrid( + torch.arange(D, device=device), + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1) + + solid_mask = binary.flatten() > 0 + solid_indices = voxel_indices[solid_mask] + + corner_offsets = [ + torch.tensor([ + [0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0] + ], device=device), + torch.tensor([ + [1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], 
+
+
+def voxel_to_mesh(voxels, threshold=0.5, device=None):
+    if device is None:
+        device = torch.device("cpu")
+    voxels = voxels.to(device)
+
+    binary = (voxels > threshold).float()
+    padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0)
+
+    D, H, W = binary.shape
+
+    neighbors = torch.tensor([
+        [0, 0, 1],
+        [0, 0, -1],
+        [0, 1, 0],
+        [0, -1, 0],
+        [1, 0, 0],
+        [-1, 0, 0]
+    ], device=device)
+
+    z, y, x = torch.meshgrid(
+        torch.arange(D, device=device),
+        torch.arange(H, device=device),
+        torch.arange(W, device=device),
+        indexing='ij'
+    )
+    voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
+
+    solid_mask = binary.flatten() > 0
+    solid_indices = voxel_indices[solid_mask]
+
+    corner_offsets = [
+        torch.tensor([
+            [0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1]
+        ], device=device),
+        torch.tensor([
+            [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]
+        ], device=device),
+        torch.tensor([
+            [0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1]
+        ], device=device),
+        torch.tensor([
+            [0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]
+        ], device=device),
+        torch.tensor([
+            [1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]
+        ], device=device),
+        torch.tensor([
+            [0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0]
+        ], device=device)
+    ]
+
+    all_vertices = []
+    all_indices = []
+
+    vertex_count = 0
+
+    for face_idx, offset in enumerate(neighbors):
+        neighbor_indices = solid_indices + offset
+
+        padded_indices = neighbor_indices + 1
+
+        is_exposed = padded[
+            padded_indices[:, 0],
+            padded_indices[:, 1],
+            padded_indices[:, 2]
+        ] == 0
+
+        if not is_exposed.any():
+            continue
+
+        exposed_indices = solid_indices[is_exposed]
+
+        corners = corner_offsets[face_idx].unsqueeze(0)
+
+        face_vertices = exposed_indices.unsqueeze(1) + corners
+
+        all_vertices.append(face_vertices.reshape(-1, 3))
+
+        num_faces = exposed_indices.shape[0]
+        face_indices = torch.arange(
+            vertex_count,
+            vertex_count + 4 * num_faces,
+            device=device
+        ).reshape(-1, 4)
+
+        all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 1]], dim=1))
+        all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 3], face_indices[:, 2]], dim=1))
+
+        vertex_count += 4 * num_faces
+
+    vertices = torch.cat(all_vertices, dim=0)
+    faces = torch.cat(all_indices, dim=0)
+
+    v_min = 0
+    v_max = max(voxels.shape)
+
+    vertices = vertices - (v_min + v_max) / 2
+
+    scale = (v_max - v_min) / 2
+    if scale > 0:
+        vertices = vertices / scale
+
+    return vertices, faces
+
+
+class MESH:
+    def __init__(self, vertices, faces):
+        self.vertices = vertices
+        self.faces = faces
+
+
+class VoxelToMeshBasic:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"voxel": ("VOXEL", ),
+                             "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
+                             }}
+    RETURN_TYPES = ("MESH",)
+    FUNCTION = "decode"
+
+    CATEGORY = "3d"
+
+    def decode(self, voxel, threshold):
+        vertices = []
+        faces = []
+        for x in voxel.data:
+            v, f = voxel_to_mesh(x, threshold=threshold, device=None)
+            vertices.append(v)
+            faces.append(f)
+
+        return (MESH(torch.stack(vertices), torch.stack(faces)), )
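A hedged smoke test for the surfacer above: a single solid voxel exposes all six faces, so voxel_to_mesh emits 6 quads, i.e. 24 per-face vertices and 12 triangles, recentred and rescaled by the grid size:

import torch

vox = torch.zeros(3, 3, 3)
vox[1, 1, 1] = 1.0
v, f = voxel_to_mesh(vox, threshold=0.5)
print(v.shape, f.shape)   # torch.Size([24, 3]) torch.Size([12, 3])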
+
+
+def save_glb(vertices, faces, filepath, metadata=None):
+    """
+    Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
+
+    Parameters:
+    vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
+    faces: torch.Tensor of shape (M, 4) or (M, 3) - The face indices (quad or triangle faces)
+    filepath: str - Output filepath (should end with .glb)
+    """
+
+    # Convert tensors to numpy arrays
+    vertices_np = vertices.cpu().numpy().astype(np.float32)
+    faces_np = faces.cpu().numpy().astype(np.uint32)
+
+    vertices_buffer = vertices_np.tobytes()
+    indices_buffer = faces_np.tobytes()
+
+    def pad_to_4_bytes(buffer):
+        padding_length = (4 - (len(buffer) % 4)) % 4
+        return buffer + b'\x00' * padding_length
+
+    vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
+    indices_buffer_padded = pad_to_4_bytes(indices_buffer)
+
+    buffer_data = vertices_buffer_padded + indices_buffer_padded
+
+    vertices_byte_length = len(vertices_buffer)
+    vertices_byte_offset = 0
+    indices_byte_length = len(indices_buffer)
+    indices_byte_offset = len(vertices_buffer_padded)
+
+    gltf = {
+        "asset": {"version": "2.0", "generator": "ComfyUI"},
+        "buffers": [
+            {
+                "byteLength": len(buffer_data)
+            }
+        ],
+        "bufferViews": [
+            {
+                "buffer": 0,
+                "byteOffset": vertices_byte_offset,
+                "byteLength": vertices_byte_length,
+                "target": 34962  # ARRAY_BUFFER
+            },
+            {
+                "buffer": 0,
+                "byteOffset": indices_byte_offset,
+                "byteLength": indices_byte_length,
+                "target": 34963  # ELEMENT_ARRAY_BUFFER
+            }
+        ],
+        "accessors": [
+            {
+                "bufferView": 0,
+                "byteOffset": 0,
+                "componentType": 5126,  # FLOAT
+                "count": len(vertices_np),
+                "type": "VEC3",
+                "max": vertices_np.max(axis=0).tolist(),
+                "min": vertices_np.min(axis=0).tolist()
+            },
+            {
+                "bufferView": 1,
+                "byteOffset": 0,
+                "componentType": 5125,  # UNSIGNED_INT
+                "count": faces_np.size,
+                "type": "SCALAR"
+            }
+        ],
+        "meshes": [
+            {
+                "primitives": [
+                    {
+                        "attributes": {
+                            "POSITION": 0
+                        },
+                        "indices": 1,
+                        "mode": 4  # TRIANGLES
+                    }
+                ]
+            }
+        ],
+        "nodes": [
+            {
+                "mesh": 0
+            }
+        ],
+        "scenes": [
+            {
+                "nodes": [0]
+            }
+        ],
+        "scene": 0
+    }
+
+    if metadata is not None:
+        gltf["asset"]["extras"] = metadata
+
+    # Convert the JSON to bytes
+    gltf_json = json.dumps(gltf).encode('utf8')
+
+    def pad_json_to_4_bytes(buffer):
+        padding_length = (4 - (len(buffer) % 4)) % 4
+        return buffer + b' ' * padding_length
+
+    gltf_json_padded = pad_json_to_4_bytes(gltf_json)
+
+    # Create the GLB header
+    # Magic glTF
+    glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
+
+    # Create JSON chunk header (chunk type 0)
+    json_chunk_header = struct.pack('