From 6969fc9ba457067dbf61d478256c7dbe9adc4f61 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 7 Aug 2024 15:00:06 -0400
Subject: [PATCH] Make supported_dtypes a priority list.

---
 comfy/model_management.py | 34 ++++++++++++++++++++++------------
 1 file changed, 22 insertions(+), 12 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 994fcd83..ec80afea 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -562,12 +562,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         if model_params * 2 > free_model_memory:
             return fp8_dtype
 
-    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
-        if torch.float16 in supported_dtypes:
-            return torch.float16
-    if should_use_bf16(device, model_params=model_params, manual_cast=True):
-        if torch.bfloat16 in supported_dtypes:
-            return torch.bfloat16
+    for dt in supported_dtypes:
+        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
+            if torch.float16 in supported_dtypes:
+                return torch.float16
+        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
+            if torch.bfloat16 in supported_dtypes:
+                return torch.bfloat16
+
+    for dt in supported_dtypes:
+        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
+            if torch.float16 in supported_dtypes:
+                return torch.float16
+        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
+            if torch.bfloat16 in supported_dtypes:
+                return torch.bfloat16
+
     return torch.float32
 
 # None means no manual cast
@@ -583,13 +593,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
     if bf16_supported and weight_dtype == torch.bfloat16:
         return None
 
-    if fp16_supported and torch.float16 in supported_dtypes:
-        return torch.float16
+    for dt in supported_dtypes:
+        if dt == torch.float16 and fp16_supported:
+            return torch.float16
+        if dt == torch.bfloat16 and bf16_supported:
+            return torch.bfloat16
 
-    elif bf16_supported and torch.bfloat16 in supported_dtypes:
-        return torch.bfloat16
-    else:
-        return torch.float32
+    return torch.float32
 
 def text_encoder_offload_device():
     if args.gpu_only:
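
Note appended below the patch, not part of the commit: this change makes the
order of supported_dtypes significant. Previously fp16 was always probed before
bf16 regardless of list order; now the first dtype in the list that the device
can use wins, with a second pass that permits manual casting before falling
back to float32. The sketch below illustrates that two-pass selection;
pick_dtype and native_ok are hypothetical stand-ins, not ComfyUI API (the real
code calls should_use_fp16()/should_use_bf16()):

    import torch

    def pick_dtype(supported_dtypes, native_ok):
        # First pass: take the first listed dtype the device runs natively.
        for dt in supported_dtypes:
            if native_ok(dt, manual_cast=False):
                return dt
        # Second pass: accept a dtype that requires manual casting.
        for dt in supported_dtypes:
            if native_ok(dt, manual_cast=True):
                return dt
        # Nothing usable: fall back to full precision.
        return torch.float32

    # e.g. a model declaring [torch.bfloat16, torch.float16] now gets bf16 on
    # hardware where both are native, where the old code would pick fp16.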