diff --git a/comfy/model_management.py b/comfy/model_management.py index 987b45e41..afbb133d4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -741,6 +741,9 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo return None fp16_supported = should_use_fp16(inference_device, prioritize_performance=True) + if PRIORITIZE_FP16 and fp16_supported and torch.float16 in supported_dtypes: + return torch.float16 + for dt in supported_dtypes: if dt == torch.float16 and fp16_supported: return torch.float16