Better wan memory estimation.

This commit is contained in:
comfyanonymous 2025-02-26 07:51:22 -05:00
parent fa62287f1f
commit b6fefe686b

View File

@ -916,6 +916,10 @@ class WAN21_T2V(supported_models_base.BASE):
vae_key_prefix = ["vae."] vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, device=device) out = model_base.WAN21(self, device=device)
return out return out