diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml
index c7843d402..c7ef93ce1 100644
--- a/.github/workflows/windows_release_nightly_pytorch.yml
+++ b/.github/workflows/windows_release_nightly_pytorch.yml
@@ -54,7 +54,7 @@ jobs:
         cd ..

-        "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
+        "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
         mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z

         cd ComfyUI_windows_portable_nightly_pytorch
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 260a51bb2..bef1868b9 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -46,6 +46,10 @@ fp_group = parser.add_mutually_exclusive_group()
 fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
 fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")

+fpvae_group = parser.add_mutually_exclusive_group()
+fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
+fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.")
+
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

 class LatentPreviewMethod(enum.Enum):
diff --git a/comfy/model_management.py b/comfy/model_management.py
index a918a81f6..09dcaa295 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -366,6 +366,14 @@ def vae_offload_device():
     else:
         return torch.device("cpu")

+def vae_dtype():
+    if args.fp16_vae:
+        return torch.float16
+    elif args.bf16_vae:
+        return torch.bfloat16
+    else:
+        return torch.float32
+
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type
diff --git a/comfy/sd.py b/comfy/sd.py
index 7e64536c1..76eaa5b59 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -505,6 +505,8 @@ class VAE:
             device = model_management.vae_device()
         self.device = device
         self.offload_device = model_management.vae_offload_device()
+        self.vae_dtype = model_management.vae_dtype()
+        self.first_stage_model.to(self.vae_dtype)

     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
         steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
@@ -512,7 +514,7 @@ class VAE:
         steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = utils.ProgressBar(steps)

-        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.device)) + 1.0)
+        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
         output = torch.clamp((
             (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
             utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
@@ -526,7 +528,7 @@ class VAE:
         steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = utils.ProgressBar(steps)

-        encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample()
+        encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.vae_dtype).to(self.device) - 1.).sample().float()
         samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
@@ -543,8 +545,8 @@ class VAE:

             pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
             for x in range(0, samples_in.shape[0], batch_number):
-                samples = samples_in[x:x+batch_number].to(self.device)
-                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
+                samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
+                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu().float()
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)
@@ -570,8 +572,8 @@ class VAE:
             batch_number = max(1, batch_number)
             samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
             for x in range(0, pixel_samples.shape[0], batch_number):
-                pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
-                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu()
+                pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
+                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float()
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
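
A minimal standalone sketch of how the new mutually exclusive --fp16-vae / --bf16-vae flags select the dtype that VAE.__init__ then applies with self.first_stage_model.to(self.vae_dtype). The argument group and the vae_dtype() body mirror the hunks above; the argparse setup, the args parameter, and the example prints are illustrative only, not part of the patch.

# Sketch: flag-to-dtype mapping introduced by this patch (illustrative framing).
import argparse
import torch

parser = argparse.ArgumentParser()
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.")

def vae_dtype(args):
    # Mirrors model_management.vae_dtype(): fp16 or bf16 only when requested,
    # otherwise the VAE stays in fp32 as before.
    if args.fp16_vae:
        return torch.float16
    elif args.bf16_vae:
        return torch.bfloat16
    return torch.float32

print(vae_dtype(parser.parse_args(["--bf16-vae"])))  # torch.bfloat16
print(vae_dtype(parser.parse_args([])))              # torch.float32 (default unchanged)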