Fix slow performance on 10 series Nvidia GPUs.

This commit is contained in:
comfyanonymous 2024-08-21 16:38:26 -04:00
parent 015f73dc49
commit a60620dcea

View File

@ -668,6 +668,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
if bf16_supported and weight_dtype == torch.bfloat16:
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
for dt in supported_dtypes:
if dt == torch.float16 and fp16_supported:
return torch.float16