Use fp16 as the intermediate dtype for fp8 weights with --fast, if supported.

comfyanonymous 2025-02-28 02:17:50 -05:00
parent 1804397952
commit eb4543474b


@@ -741,6 +741,9 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
         return None
 
     fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
+    if PRIORITIZE_FP16 and fp16_supported and torch.float16 in supported_dtypes:
+        return torch.float16
+
     for dt in supported_dtypes:
         if dt == torch.float16 and fp16_supported:
             return torch.float16
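
The added lines make the cast-dtype selection prefer fp16 as the intermediate compute dtype whenever PRIORITIZE_FP16 (enabled via --fast) is set, fp16 is usable on the inference device, and the model lists fp16 among its supported dtypes, rather than only reaching fp16 through the fallback loop. Below is a minimal, self-contained sketch of that selection logic, not the actual ComfyUI implementation: the should_use_fp16 stub is a simplified assumption, and pick_intermediate_dtype is a hypothetical name standing in for unet_manual_cast.

import torch

PRIORITIZE_FP16 = True  # assumption: stands in for the flag that --fast enables


def should_use_fp16(inference_device, prioritize_performance=True):
    # Simplified stand-in: the real check inspects GPU capability and VRAM.
    return inference_device.type == "cuda"


def pick_intermediate_dtype(weight_dtype, inference_device,
                            supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
    # fp32 weights need no manual cast at all.
    if weight_dtype == torch.float32:
        return None

    fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)

    # New behavior sketched from the diff: with --fast, prefer fp16 as the
    # intermediate dtype (e.g. for fp8 weights) when device and model allow it.
    if PRIORITIZE_FP16 and fp16_supported and torch.float16 in supported_dtypes:
        return torch.float16

    # Otherwise fall back to the first supported dtype the device can use.
    for dt in supported_dtypes:
        if dt == torch.float16 and fp16_supported:
            return torch.float16

    return torch.float32


# Example (assumes a CUDA device and a PyTorch build with fp8 dtypes):
# fp8 weights with --fast resolve to an fp16 intermediate dtype.
# print(pick_intermediate_dtype(torch.float8_e4m3fn, torch.device("cuda")))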