diff --git a/comfy/model_management.py b/comfy/model_management.py index 054291432..816caf18f 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1290,6 +1290,13 @@ def supports_fp8_compute(device=None): return True +def extended_fp16_support(): + # TODO: check why some models work with fp16 on newer torch versions but not on older + if torch_version_numeric < (2, 7): + return False + + return True + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index f4413d647..2669ca01e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1197,11 +1197,16 @@ class Omnigen2(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.Flux - supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + supported_inference_dtypes = [torch.bfloat16, torch.float32] vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] + def __init__(self, unet_config): + super().__init__(unet_config) + if comfy.model_management.extended_fp16_support(): + self.supported_inference_dtypes = [torch.float16] + self.supported_inference_dtypes + def get_model(self, state_dict, prefix="", device=None): out = model_base.Omnigen2(self, device=device) return out