fp8 casting is fast on GPUs that support fp8 compute.

comfyanonymous 2024-10-20 00:54:47 -04:00
parent a68bbafddb
commit 471cd3eace


@@ -647,6 +647,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         pass
 
     if fp8_dtype is not None:
+        if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
+            return fp8_dtype
+
         free_model_memory = maximum_vram_for_weights(device)
         if model_params * 2 > free_model_memory:
             return fp8_dtype
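
The change above makes unet_dtype() prefer the stored fp8 weight dtype whenever the device can run fp8 compute natively, and otherwise fall back to fp8 only when fp16 weights would not fit in the VRAM budget. Below is a minimal, standalone sketch of that selection logic; the two helpers are hypothetical stand-ins for ComfyUI's supports_fp8_compute() and maximum_vram_for_weights(), not the real implementations, and the 2 bytes-per-parameter factor mirrors the model_params * 2 check in the diff (requires a PyTorch build with float8 dtypes, i.e. 2.1+).

import torch

def supports_fp8_compute(device=None):
    # Stand-in: assume an Ada/Hopper-class GPU with native fp8 support.
    return True

def maximum_vram_for_weights(device=None):
    # Stand-in: pretend 8 GiB of VRAM is available for model weights.
    return 8 * 1024 ** 3

def pick_unet_dtype(fp8_dtype, model_params, device=None):
    if fp8_dtype is not None:
        # fp8 compute makes the cast cheap, so use fp8 unconditionally.
        if supports_fp8_compute(device):
            return fp8_dtype
        # Otherwise use fp8 only if fp16 weights (2 bytes per parameter)
        # would overflow the memory budget.
        if model_params * 2 > maximum_vram_for_weights(device):
            return fp8_dtype
    # Default path in this sketch; the real function weighs several dtypes.
    return torch.float16

# A 12B-parameter model needs ~24 GiB in fp16, so the fallback path also picks fp8.
print(pick_unet_dtype(torch.float8_e4m3fn, model_params=12_000_000_000))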