From 9d8b6c1f464e01a730bc4039cce483895dc888a4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Jan 2025 03:48:40 -0500 Subject: [PATCH] More accurate memory estimation for cosmos and hunyuan video. --- comfy/sd.py | 4 ++-- comfy/supported_models.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 6ba6af474..d7e89f726 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -388,8 +388,8 @@ class VAE: ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8} self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig) #TODO: these values are a bit off because this is not a standard VAE - self.memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (500 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.working_dtypes = [torch.bfloat16, torch.float32] else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") diff --git a/comfy/supported_models.py b/comfy/supported_models.py index ff3f14329..87fecde5a 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -788,7 +788,7 @@ class HunyuanVideo(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.HunyuanVideo - memory_usage_factor = 2.0 #TODO + memory_usage_factor = 1.7 #TODO supported_inference_dtypes = [torch.bfloat16, torch.float32] @@ -839,7 +839,7 @@ class CosmosT2V(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.Cosmos1CV8x8x8 - memory_usage_factor = 2.4 #TODO + memory_usage_factor = 1.6 #TODO supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO