Make ACE VAE tiling work. (#8004)

parent 5d3cc85e13
commit a692c3cca4

comfy/sd.py | 39
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -282,6 +282,7 @@ class VAE:
 
         self.downscale_index_formula = None
         self.upscale_index_formula = None
+        self.extra_1d_channel = None
 
         if config is None:
            if "decoder.mid.block_1.mix_factor" in sd:
@@ -445,13 +446,14 @@ class VAE:
                 self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
                 self.latent_channels = 8
                 self.output_channels = 2
-                # self.upscale_ratio = 2048
-                # self.downscale_ratio = 2048
+                self.upscale_ratio = 4096
+                self.downscale_ratio = 4096
                 self.latent_dim = 2
                 self.process_output = lambda audio: audio
                 self.process_input = lambda audio: audio
                 self.working_dtypes = [torch.bfloat16, torch.float32]
                 self.disable_offload = True
+                self.extra_1d_channel = 16
             else:
                 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
                 self.first_stage_model = None
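
The constants above imply the rough shape arithmetic sketched below. This is an orientation sketch only, not code from the commit; the latent length of 256 steps and the 44.1 kHz figure are assumptions, while the 8 channels, the extra dimension of 16, the 2 output channels, and the 4096x ratios come straight from the hunk.

    import torch

    batch, latent_t = 1, 256                      # hypothetical latent length
    latent = torch.zeros(batch, 8, 16, latent_t)  # [B, latent_channels, extra_1d_channel, T]

    audio_samples = latent_t * 4096               # upscale_ratio maps latent steps to audio samples
    audio = torch.zeros(batch, 2, audio_samples)  # [B, output_channels, samples]

    print(latent.shape, "->", audio.shape)        # roughly 23.8 s if the model runs at 44.1 kHz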
@@ -510,7 +512,13 @@ class VAE:
         return output
 
     def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
-        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+        if samples.ndim == 3:
+            decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+        else:
+            og_shape = samples.shape
+            samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
+            decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
+
         return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
 
     def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
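
The new else branch handles latents that carry the extra channel dimension: it folds [B, C, extra, T] into [B, C * extra, T] so the existing 1-D tiler can slide along the time axis, and decode_fn unfolds each tile again before handing it to the model. A minimal round-trip of that fold/unfold, with assumed shapes (B=1, C=8, extra=16), looks like this:

    import torch

    samples = torch.randn(1, 8, 16, 100)              # assumed ACE-style latent [B, C, extra, T]
    og_shape = samples.shape

    # Fold the extra channel into the channel axis so a 1-D tiler can slice along T.
    flat = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))       # [1, 128, 100]

    # What decode_fn does per tile: restore [*, C, extra, tile_t] before decoding.
    tile = flat[:, :, :32]                                                     # a hypothetical 1-D tile
    restored = tile.reshape((-1, og_shape[1], og_shape[2], tile.shape[-1]))    # [1, 8, 16, 32]

    assert torch.equal(restored, samples[:, :, :, :32])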
@@ -530,9 +538,24 @@ class VAE:
             samples /= 3.0
         return samples
 
-    def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
-        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
-        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
+    def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
+        if self.latent_dim == 1:
+            encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
+            out_channels = self.latent_channels
+            upscale_amount = 1 / self.downscale_ratio
+        else:
+            extra_channel_size = self.extra_1d_channel
+            out_channels = self.latent_channels * extra_channel_size
+            tile_x = tile_x // extra_channel_size
+            overlap = overlap // extra_channel_size
+            upscale_amount = 1 / self.downscale_ratio
+            encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
+
+        out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
+        if self.latent_dim == 1:
+            return out
+        else:
+            return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
 
     def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
         encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
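
On the encode side the tiler still slices a flat 1-D signal, so when extra_1d_channel is set the tile size, overlap, and output channel count are rescaled by it and the flat result is reshaped back to [B, latent_channels, extra, T]. With the ACE values assumed, the bookkeeping works out as follows (a sketch, not the commit's code):

    latent_channels = 8
    extra_channel_size = 16                    # self.extra_1d_channel for the ACE VAE
    downscale_ratio = 4096

    tile_x, overlap = 256 * 2048, 64 * 2048    # defaults from the new signature

    # The 1-D tiler sees a flat [B, latent_channels * extra, T] result per tile,
    # so tile/overlap sizes and output channels are rescaled accordingly.
    out_channels = latent_channels * extra_channel_size   # 128
    tile_x //= extra_channel_size                          # 32768
    overlap //= extra_channel_size                         # 8192
    upscale_amount = 1 / downscale_ratio                   # 1/4096

    print(out_channels, tile_x, overlap, upscale_amount)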
@@ -557,7 +580,7 @@ class VAE:
         except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             dims = samples_in.ndim - 2
-            if dims == 1:
+            if dims == 1 or self.extra_1d_channel is not None:
                 pixel_samples = self.decode_tiled_1d(samples_in)
             elif dims == 2:
                 pixel_samples = self.decode_tiled_(samples_in)
@@ -624,7 +647,7 @@ class VAE:
                 tile = 256
                 overlap = tile // 4
                 samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
-            elif self.latent_dim == 1:
+            elif self.latent_dim == 1 or self.extra_1d_channel is not None:
                 samples = self.encode_tiled_1d(pixel_samples)
             else:
                 samples = self.encode_tiled_(pixel_samples)
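
Taken together, the two condition changes route ACE latents (which are 4-D, so dims == 2) to the 1-D tiled paths instead of the 2-D image tilers. A hypothetical way to exercise the new path directly, assuming ace_vae_sd is an ACE-Step VAE state dict already loaded with comfy.utils.load_torch_file:

    import torch
    import comfy.sd

    vae = comfy.sd.VAE(sd=ace_vae_sd)             # hypothetical: ACE audio VAE weights
    latent = torch.randn(1, 8, 16, 512)           # assumed ACE latent layout [B, C, extra, T]

    # decode() tries the non-tiled path first; decode_tiled_1d is what the OOM
    # fallback (and the new extra_1d_channel check) now reaches for this layout.
    audio = vae.decode_tiled_1d(latent)
    print(audio.shape)                            # expected [1, 2, 512 * 4096]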