Add option for using fp8_e8m0fnu for model weights. (#7733)

Seems to break every model I have tried, but worth testing?
comfyanonymous 2025-04-22 03:17:38 -07:00 committed by GitHub
parent a8f63c0d5b
commit 2d6805ce57
2 changed files with 3 additions and 0 deletions
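
fp8_e8m0fnu is an 8-bit float format with 8 exponent bits and no mantissa or sign bits, so every representable value is a non-negative power of two; that likely explains why casting full weight tensors to it breaks most models. A minimal sketch of the effect (assumes a PyTorch build with float8_e8m0fnu support, 2.7 or newer):

    import torch

    # e8m0fnu has no sign or mantissa bits, so casting snaps each element
    # to a representable value, i.e. a power of two.
    x = torch.tensor([0.3, 1.0, 1.5, 3.0])
    q = x.to(torch.float8_e8m0fnu)
    print(q.to(torch.float32))  # every element is now an exact power of two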

comfy/cli_args.py

@@ -66,6 +66,7 @@ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diff
 fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
 fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
 fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
+fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")
 fpvae_group = parser.add_mutually_exclusive_group()
 fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
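
With the flag registered in the mutually exclusive fpunet_group, it is enabled at launch like the other fp8 switches, e.g. python main.py --fp8_e8m0fnu-unet (assuming the standard main.py entry point); it cannot be combined with the other --*-unet precision flags.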

comfy/model_management.py

@@ -725,6 +725,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
+    if args.fp8_e8m0fnu_unet:
+        return torch.float8_e8m0fnu
     fp8_dtype = None
     if weight_dtype in FLOAT8_TYPES:
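
Taken together, the two hunks follow the existing flag-to-dtype pattern: argparse turns --fp8_e8m0fnu-unet into args.fp8_e8m0fnu_unet, and unet_dtype short-circuits on it before the automatic fp8 selection below. A self-contained sketch of that dispatch (simplified; the fallback here is hypothetical, not the function's real default logic):

    import argparse
    import torch

    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--fp8_e4m3fn-unet", action="store_true")
    group.add_argument("--fp8_e5m2-unet", action="store_true")
    group.add_argument("--fp8_e8m0fnu-unet", action="store_true")
    args = parser.parse_args(["--fp8_e8m0fnu-unet"])

    def unet_dtype():
        # argparse maps "--fp8_e8m0fnu-unet" to the attribute fp8_e8m0fnu_unet
        if args.fp8_e4m3fn_unet:
            return torch.float8_e4m3fn
        if args.fp8_e5m2_unet:
            return torch.float8_e5m2
        if args.fp8_e8m0fnu_unet:
            return torch.float8_e8m0fnu  # requires PyTorch >= 2.7
        return torch.float16  # hypothetical fallback for this sketch

    print(unet_dtype())  # torch.float8_e8m0fnu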