Make ACE VAE tiling work. (#8004)

This commit is contained in:
comfyanonymous 2025-05-08 04:25:45 -07:00 committed by GitHub
parent 5d3cc85e13
commit a692c3cca4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -282,6 +282,7 @@ class VAE:
self.downscale_index_formula = None
self.upscale_index_formula = None
self.extra_1d_channel = None
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@ -445,13 +446,14 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = 8
self.output_channels = 2
# self.upscale_ratio = 2048
# self.downscale_ratio = 2048
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -510,7 +512,13 @@ class VAE:
return output
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
else:
og_shape = samples.shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
@ -530,9 +538,24 @@ class VAE:
samples /= 3.0
return samples
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio
else:
extra_channel_size = self.extra_1d_channel
out_channels = self.latent_channels * extra_channel_size
tile_x = tile_x // extra_channel_size
overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1:
return out
else:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
@ -557,7 +580,7 @@ class VAE:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
dims = samples_in.ndim - 2
if dims == 1:
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
@ -624,7 +647,7 @@ class VAE:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1:
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)