From b0aab1e4ea3dfefe09c4f07de0e5237558097e22 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 11 Dec 2023 18:36:29 -0500
Subject: [PATCH] Add an option --fp16-unet to force using fp16 for the unet.

---
 comfy/cli_args.py         | 1 +
 comfy/model_management.py | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 58d034802..d9c8668f4 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -57,6 +57,7 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
 
 fpunet_group = parser.add_mutually_exclusive_group()
 fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
+fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
 fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
 fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index fe0374a8b..b6a9471bf 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -466,6 +466,8 @@ def unet_inital_load_device(parameters, dtype):
 def unet_dtype(device=None, model_params=0):
     if args.bf16_unet:
         return torch.bfloat16
+    if args.fp16_unet:
+        return torch.float16
     if args.fp8_e4m3fn_unet:
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
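
Usage note (not part of the patch): with this change, launching ComfyUI as
`python main.py --fp16-unet` stores the UNet weights in fp16. The new flag
is registered in the same mutually exclusive argparse group as the other
unet dtype overrides, so combining it with e.g. --bf16-unet is rejected at
argument-parse time.

Below is a minimal standalone sketch of the resulting precedence in
unet_dtype(); the function name and the final float32 fallback are
assumptions for illustration only, since the real function instead falls
back to device-dependent defaults not shown in this hunk:

    import torch

    def unet_dtype_sketch(args):
        # At most one of these flags can be True, because cli_args.py
        # registers them in a mutually exclusive group.
        if args.bf16_unet:
            return torch.bfloat16
        if args.fp16_unet:  # the flag added by this patch
            return torch.float16
        if args.fp8_e4m3fn_unet:
            return torch.float8_e4m3fn
        if args.fp8_e5m2_unet:
            return torch.float8_e5m2
        # Placeholder assumption: the real unet_dtype() probes the device
        # and model size here instead of returning a fixed default.
        return torch.float32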