mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-07 03:52:10 +08:00
Make ACE VAE tiling work. (#8004)
This commit is contained in:
parent
5d3cc85e13
commit
a692c3cca4
39
comfy/sd.py
39
comfy/sd.py
@ -282,6 +282,7 @@ class VAE:
|
||||
|
||||
self.downscale_index_formula = None
|
||||
self.upscale_index_formula = None
|
||||
self.extra_1d_channel = None
|
||||
|
||||
if config is None:
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
@ -445,13 +446,14 @@ class VAE:
|
||||
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 8
|
||||
self.output_channels = 2
|
||||
# self.upscale_ratio = 2048
|
||||
# self.downscale_ratio = 2048
|
||||
self.upscale_ratio = 4096
|
||||
self.downscale_ratio = 4096
|
||||
self.latent_dim = 2
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.disable_offload = True
|
||||
self.extra_1d_channel = 16
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@ -510,7 +512,13 @@ class VAE:
|
||||
return output
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
if samples.ndim == 3:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
else:
|
||||
og_shape = samples.shape
|
||||
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
|
||||
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||
@ -530,9 +538,24 @@ class VAE:
|
||||
samples /= 3.0
|
||||
return samples
|
||||
|
||||
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
||||
if self.latent_dim == 1:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
out_channels = self.latent_channels
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
else:
|
||||
extra_channel_size = self.extra_1d_channel
|
||||
out_channels = self.latent_channels * extra_channel_size
|
||||
tile_x = tile_x // extra_channel_size
|
||||
overlap = overlap // extra_channel_size
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
|
||||
|
||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||
if self.latent_dim == 1:
|
||||
return out
|
||||
else:
|
||||
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
|
||||
|
||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
@ -557,7 +580,7 @@ class VAE:
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1:
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
@ -624,7 +647,7 @@ class VAE:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1:
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
Loading…
x
Reference in New Issue
Block a user