From 24129d78e6ac5349389ca99349242a13cdedf1d2 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 4 Feb 2024 13:23:43 -0500
Subject: [PATCH] Speed up SDXL on 16xx series with fp16 weights and manual
 cast.

---
 comfy/model_management.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index cbaa8087..aa40c502 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -496,7 +496,7 @@ def unet_dtype(device=None, model_params=0):
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
-    if should_use_fp16(device=device, model_params=model_params):
+    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         return torch.float16
     return torch.float32
 
@@ -696,7 +696,7 @@ def is_device_mps(device):
         return True
     return False
 
-def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
+def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     global directml_enabled
 
     if device is not None:
@@ -738,7 +738,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
         if x in props.name.lower():
             fp16_works = True
 
-    if fp16_works:
+    if fp16_works or manual_cast:
         free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
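
Note: for readers unfamiliar with the "manual cast" named in the subject line,
the sketch below illustrates the general pattern. It is an illustrative sketch
under stated assumptions, not ComfyUI's actual implementation: the class name
ManualCastLinear and its compute_dtype attribute are hypothetical. The idea is
that on GPUs whose fp16 arithmetic is unreliable (the GTX 16xx series), weights
can still be stored in fp16 to halve memory, with each layer casting them to a
safe compute dtype (fp32) on the fly in forward().

    import torch
    import torch.nn.functional as F

    class ManualCastLinear(torch.nn.Linear):
        # Hypothetical sketch, not ComfyUI's class: parameters are stored in
        # a compact dtype (fp16) while the matmul runs in a safe dtype (fp32).
        compute_dtype = torch.float32

        def forward(self, x):
            # Cast weight/bias per call; storage stays fp16, compute is fp32.
            weight = self.weight.to(self.compute_dtype)
            bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
            return F.linear(x.to(self.compute_dtype), weight, bias)

    # Usage: construct in fp16 for storage; compute happens in fp32.
    layer = ManualCastLinear(320, 320).to(dtype=torch.float16)
    out = layer(torch.randn(1, 320, dtype=torch.float16))  # out.dtype == torch.float32

This trade (a small per-call cast in exchange for fp16-sized weights) is why
unet_dtype now passes manual_cast=True: fp16 storage can pay off even when
fp16_works is False, because compute is cast to a working dtype anyway.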