Manual cast for bf16 on older GPUs.

Author: comfyanonymous
Date:   2024-02-17 08:13:17 -05:00
Parent: 6c875d846b
Commit: 929e266f3e

File: comfy/model_management.py

@@ -499,7 +499,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
     if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         if torch.float16 in supported_dtypes:
             return torch.float16
-    if should_use_bf16(device):
+    if should_use_bf16(device, model_params=model_params, manual_cast=True):
         if torch.bfloat16 in supported_dtypes:
             return torch.bfloat16
     return torch.float32
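
Note on the hunk above: `unet_dtype` now forwards `model_params` and sets `manual_cast=True`, mirroring the `should_use_fp16` call just above it. With manual casting, bf16 can be chosen as the storage dtype even on a GPU without native bf16 kernels, because the weights are converted to a supported dtype at compute time. A minimal sketch of that idea (hypothetical layer and dtype choices, not ComfyUI's actual cast machinery):

```python
import torch

# Weights live in bf16 to halve their memory footprint; they are cast to a
# natively supported compute dtype right before each operation.
storage_dtype = torch.bfloat16   # dtype the weights occupy in memory
compute_dtype = torch.float32    # dtype the device actually computes in

layer = torch.nn.Linear(64, 64).to(storage_dtype)

def manual_cast_linear(x):
    # Cast per call; only the short-lived temporaries use compute_dtype.
    w = layer.weight.to(compute_dtype)
    b = layer.bias.to(compute_dtype)
    return torch.nn.functional.linear(x.to(compute_dtype), w, b)

out = manual_cast_linear(torch.randn(1, 64))
print(out.dtype)  # compute_dtype; the stored weights remain bf16
```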
@@ -771,10 +771,24 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     return True
 
-def should_use_bf16(device=None):
+def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if device is not None:
         if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
             return False
 
     if device is not None: #TODO not sure about mps bf16 support
         if is_device_mps(device):
             return False
 
+    if FORCE_FP32:
+        return False
+
+    if directml_enabled:
+        return False
+
+    if cpu_mode() or mps_mode():
+        return False
+
     if is_intel_xpu():
         return True
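
The widened signature brings `should_use_bf16` in line with `should_use_fp16`, and the new guards bail out early for configurations where bf16 can never pay off: forced fp32, the DirectML backend, and CPU/MPS-only modes. Only after these global overrides does the CUDA-specific probing below run. A self-contained sketch of that guard ordering (the flag names mimic the module-level config referenced in the diff; the values here are hypothetical stand-ins):

```python
# Hypothetical stand-ins for the module-level state consulted in the diff
# (FORCE_FP32, directml_enabled, cpu_mode()/mps_mode()).
FORCE_FP32 = False
DIRECTML_ENABLED = False
CPU_OR_MPS_MODE = False

def bf16_globally_allowed():
    """True when no global override rules bf16 out before device probing."""
    if FORCE_FP32:          # user forced fp32 everywhere
        return False
    if DIRECTML_ENABLED:    # no bf16 path on the DirectML backend
        return False
    if CPU_OR_MPS_MODE:     # whole pipeline pinned to CPU or MPS
        return False
    return True

print(bf16_globally_allowed())  # True with the defaults above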
@@ -785,6 +799,13 @@ def should_use_bf16(device=None):
     if props.major >= 8:
         return True
 
     bf16_works = torch.cuda.is_bf16_supported()
 
+    if bf16_works or manual_cast:
+        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        if (not prioritize_performance) or model_params * 4 > free_model_memory:
+            return True
+
     return False
 
 def soft_empty_cache(force=False):
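
This last hunk is the heart of the commit: the gate now passes when the GPU natively supports bf16 *or* when `manual_cast` is set, and in either case the same free-memory heuristic used for fp16 decides the outcome. `model_params * 4` approximates the fp32 weight footprint in bytes (4 bytes per parameter); if that would not fit in 90% of free memory less the inference reserve, halving the weights to bf16 (2 bytes per parameter) and manually casting is worth the cast overhead. A worked example with hypothetical numbers (`inference_reserve` stands in for `minimum_inference_memory()`; nothing here is measured):

```python
# Hypothetical numbers to trace the heuristic (illustrative only).
model_params = 2_600_000_000        # ~2.6e9 parameters, roughly SDXL-UNet scale
free_memory = 8 * 1024**3           # pretend get_free_memory() reports 8 GiB
inference_reserve = 1 * 1024**3     # stand-in for minimum_inference_memory()

free_model_memory = free_memory * 0.9 - inference_reserve
fp32_bytes = model_params * 4       # fp32 weights: 4 bytes per parameter

print(f"fp32 weights: {fp32_bytes / 1024**3:.1f} GiB")         # ~9.7 GiB
print(f"budget:       {free_model_memory / 1024**3:.1f} GiB")  # ~6.2 GiB
# 9.7 GiB > 6.2 GiB, so should_use_bf16 returns True: store the weights in
# bf16 (~4.8 GiB) and manually cast to a supported dtype at compute time.
print(fp32_bytes > free_model_memory)                          # True
```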