Mirror of https://github.com/comfyanonymous/ComfyUI.git
Add some command line arguments to store text encoder weights in fp8.

PyTorch supports two variants of fp8:
--fp8_e4m3fn-text-enc (the one that seems to give better results)
--fp8_e5m2-text-enc
commit 0cf4e86939
parent 107e78b1cb
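For context (not part of the diff that follows): e4m3fn spends an extra bit on the mantissa, so it is more precise over a narrower range, while e5m2 keeps fp16's five exponent bits and gives up precision instead. A minimal sketch of the difference, assuming a PyTorch build (>= 2.1) that exposes both fp8 dtypes:

import torch

for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    # e4m3fn: 4 exponent / 3 mantissa bits; e5m2: 5 exponent / 2 mantissa bits
    print(dtype, "max:", info.max, "smallest normal:", info.smallest_normal)

# Storing a tensor in fp8 is an ordinary dtype cast:
w = torch.randn(4, 4)
print(w.to(torch.float8_e4m3fn).dtype)  # torch.float8_e4m3fn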
@@ -62,6 +62,13 @@ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in
 fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
 fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
 
+fpte_group = parser.add_mutually_exclusive_group()
+fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
+fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
+fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
+fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
+
+
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
 
 parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
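Because the new flags are registered in a mutually exclusive group, argparse rejects any combination of them at startup. A standalone sketch of that behaviour (illustrative only, not the actual cli_args module):

import argparse

parser = argparse.ArgumentParser()
fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true")
fpte_group.add_argument("--fp16-text-enc", action="store_true")
fpte_group.add_argument("--fp32-text-enc", action="store_true")

args = parser.parse_args(["--fp8_e4m3fn-text-enc"])
print(args.fp8_e4m3fn_text_enc)  # True; dashes in the flag become underscores in the attribute

# Combining two flags from the same group aborts with an error, e.g.:
# parser.parse_args(["--fp8_e4m3fn-text-enc", "--fp16-text-enc"])
# error: argument --fp16-text-enc: not allowed with argument --fp8_e4m3fn-text-enc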
@@ -482,6 +482,21 @@ def text_encoder_device():
     else:
         return torch.device("cpu")
 
+def text_encoder_dtype(device=None):
+    if args.fp8_e4m3fn_text_enc:
+        return torch.float8_e4m3fn
+    elif args.fp8_e5m2_text_enc:
+        return torch.float8_e5m2
+    elif args.fp16_text_enc:
+        return torch.float16
+    elif args.fp32_text_enc:
+        return torch.float32
+
+    if should_use_fp16(device, prioritize_performance=False):
+        return torch.float16
+    else:
+        return torch.float32
+
 def vae_device():
     return get_torch_device()
 
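The flags only control the dtype the text encoder weights are stored in, as the help strings say. A rough usage sketch with a hypothetical DummyEncoder module (purely illustrative, not a ComfyUI class):

import torch

class DummyEncoder(torch.nn.Module):  # hypothetical stand-in for a text encoder
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

enc = DummyEncoder()
dtype = torch.float8_e4m3fn  # e.g. what text_encoder_dtype() returns with --fp8_e4m3fn-text-enc
for p in enc.parameters():
    p.data = p.data.to(dtype)  # store the weights in fp8
print(next(enc.parameters()).dtype)  # torch.float8_e4m3fn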
@@ -95,10 +95,7 @@ class CLIP:
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         params['device'] = offload_device
-        if model_management.should_use_fp16(load_device, prioritize_performance=False):
-            params['dtype'] = torch.float16
-        else:
-            params['dtype'] = torch.float32
+        params['dtype'] = model_management.text_encoder_dtype(load_device)
 
         self.cond_stage_model = clip(**(params))
 
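The selected dtype simply rides along in the params dict that is unpacked into the text encoder constructor. A minimal sketch of that pattern, with an illustrative constructor signature (not the real ComfyUI class):

import torch

def make_text_encoder(device=None, dtype=None):  # illustrative signature
    print("building text encoder on", device, "with storage dtype", dtype)

params = {}
params['device'] = torch.device("cpu")
params['dtype'] = torch.float8_e4m3fn  # whatever text_encoder_dtype() picked
make_text_encoder(**params)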