diff --git a/comfy/sd.py b/comfy/sd.py
index 4160fa893..d096f496c 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -419,10 +419,11 @@ class VAE:
             inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
             downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
             mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
-            self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
-            self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
+            self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
             ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
             self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
+            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
             self.first_stage_model = None
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index b5c3194cf..be3aede60 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -971,6 +971,8 @@ class Hunyuan3Dv2(supported_models_base.BASE):
         "shift": 1.0,
     }
 
+    memory_usage_factor = 3.5
+
     clip_vision_prefix = "conditioner.main_image_encoder.model."
 
     vae_key_prefix = ["vae."]
diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py
index ac2cff3a9..1ca7c2fe6 100644
--- a/comfy_extras/nodes_hunyuan3d.py
+++ b/comfy_extras/nodes_hunyuan3d.py
@@ -190,8 +190,12 @@ def voxel_to_mesh(voxels, threshold=0.5, device=None):
 
             vertex_count += 4 * num_faces
 
-    vertices = torch.cat(all_vertices, dim=0)
-    faces = torch.cat(all_indices, dim=0)
+    if len(all_vertices) > 0:
+        vertices = torch.cat(all_vertices, dim=0)
+        faces = torch.cat(all_indices, dim=0)
+    else:
+        vertices = torch.zeros((1, 3))
+        faces = torch.zeros((1, 3))
 
     v_min = 0
     v_max = max(voxels.shape)
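
Note on the comfy/sd.py hunk: the old decode estimate scaled with the latent's shape[2], while the replacement reserves a flat 2 * 1024^3 elements times the dtype width regardless of input size (both lambdas are now marked # TODO). A minimal sketch of the arithmetic, assuming dtype_size behaves like comfy.model_management.dtype_size, i.e. returns the element width in bytes:

    import torch

    def dtype_size(dtype):
        # Assumed stand-in for comfy.model_management.dtype_size:
        # the element width of the dtype in bytes.
        return torch.tensor([], dtype=dtype).element_size()

    # Flat decode-memory estimate from the patch, independent of shape:
    for dt in (torch.float16, torch.bfloat16, torch.float32):
        est_bytes = (1024 * 1024 * 1024 * 2.0) * dtype_size(dt)
        print(dt, est_bytes / 2**30, "GiB")  # 4.0, 4.0, 8.0

So the patched estimate reserves 4 GiB for fp16/bf16 decodes and 8 GiB for fp32, which the new memory_usage_factor = 3.5 in supported_models.py complements on the diffusion-model side.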
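
Note on the comfy_extras/nodes_hunyuan3d.py hunk: torch.cat raises a RuntimeError on an empty list of tensors, so a voxel grid with no cells above the threshold used to crash voxel_to_mesh; the patch substitutes a degenerate one-vertex placeholder mesh instead. A standalone sketch of the same guard pattern (cat_or_placeholder is a hypothetical helper, not part of the patch):

    import torch

    def cat_or_placeholder(chunks, fallback_shape=(1, 3)):
        # torch.cat() rejects an empty tensor list, so return a single
        # all-zero placeholder row when nothing was collected.
        if len(chunks) > 0:
            return torch.cat(chunks, dim=0)
        return torch.zeros(fallback_shape)

    empty = cat_or_placeholder([])          # shape (1, 3), like the patch's fallback
    merged = cat_or_placeholder(
        [torch.ones(4, 3), torch.ones(2, 3)])  # shape (6, 3)

Downstream consumers always receive well-formed (N, 3) vertex and face tensors either way, which keeps mesh-export nodes from having to special-case the empty result.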