diff --git a/comfy/supported_models.py b/comfy/supported_models.py index e28bd1382..a8212c1fa 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -916,6 +916,10 @@ class WAN21_T2V(supported_models_base.BASE): vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000 + def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21(self, device=device) return out