mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Make supported_dtypes a priority list.
This commit is contained in:
parent
cb7c4b4be3
commit
6969fc9ba4
@ -562,12 +562,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
||||
if model_params * 2 > free_model_memory:
|
||||
return fp8_dtype
|
||||
|
||||
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
if should_use_bf16(device, model_params=model_params, manual_cast=True):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
|
||||
return torch.float32
|
||||
|
||||
# None means no manual cast
|
||||
@ -583,13 +593,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
|
||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
||||
return None
|
||||
|
||||
if fp16_supported and torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and fp16_supported:
|
||||
return torch.float16
|
||||
if dt == torch.bfloat16 and bf16_supported:
|
||||
return torch.bfloat16
|
||||
|
||||
elif bf16_supported and torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
else:
|
||||
return torch.float32
|
||||
return torch.float32
|
||||
|
||||
def text_encoder_offload_device():
|
||||
if args.gpu_only:
|
||||
|
Loading…
Reference in New Issue
Block a user