Make supported_dtypes a priority list.

This commit is contained in:
comfyanonymous 2024-08-07 15:00:06 -04:00
parent cb7c4b4be3
commit 6969fc9ba4

View File

@ -562,12 +562,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
if model_params * 2 > free_model_memory:
return fp8_dtype
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
for dt in supported_dtypes:
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
if torch.float16 in supported_dtypes:
return torch.float16
if should_use_bf16(device, model_params=model_params, manual_cast=True):
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
if torch.bfloat16 in supported_dtypes:
return torch.bfloat16
for dt in supported_dtypes:
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
if torch.float16 in supported_dtypes:
return torch.float16
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
if torch.bfloat16 in supported_dtypes:
return torch.bfloat16
return torch.float32
# None means no manual cast
@ -583,12 +593,12 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
if bf16_supported and weight_dtype == torch.bfloat16:
return None
if fp16_supported and torch.float16 in supported_dtypes:
for dt in supported_dtypes:
if dt == torch.float16 and fp16_supported:
return torch.float16
elif bf16_supported and torch.bfloat16 in supported_dtypes:
if dt == torch.bfloat16 and bf16_supported:
return torch.bfloat16
else:
return torch.float32
def text_encoder_offload_device():