From 6e8cdcd3cb542ba9eb5a5e5a420eff06f59dd268 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 22 Nov 2024 18:00:34 -0500
Subject: [PATCH] Fix some tiled VAE decoding issues with LTX-Video.

---
Review notes (everything between the three-dash line and the first
"diff --git" line is commentary and is not part of the commit):

* Line breaks, blank context lines and leading indentation were restored
  from a whitespace-collapsed copy of this patch. The number of lines in
  each hunk matches its @@ header, but the absolute indentation is inferred
  from the Python nesting -- TODO confirm against comfy/sd.py and nodes.py
  before applying.
* The copy this was restored from carried no author e-mail address on the
  From: header.
* VAE.upscale_ratio becomes a 3-tuple (callable, 32, 32) for the video VAE
  while the removed line shows it as a plain int (8). The new helper
  spacial_compression_decode() returns upscale_ratio[-1] for the tuple form
  and the value itself otherwise.
* Follow-up suggestion: the helper uses a bare "except:", which also
  swallows KeyboardInterrupt/SystemExit. Indexing an int raises TypeError,
  so "except TypeError:" would cover the scalar case.

 comfy/sd.py | 12 ++++++++++--
 nodes.py    |  3 ++-
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index b07b5fe3..e2af7078 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -269,7 +269,7 @@ class VAE:
                 self.latent_dim = 3
                 self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
-                self.upscale_ratio = 8
+                self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
             else:
                 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
@@ -370,7 +370,9 @@ class VAE:
             elif dims == 2:
                 pixel_samples = self.decode_tiled_(samples_in)
             elif dims == 3:
-                pixel_samples = self.decode_tiled_3d(samples_in)
+                tile = 256 // self.spacial_compression_decode()
+                overlap = tile // 4
+                pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
 
         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
@@ -434,6 +436,12 @@ class VAE:
     def get_sd(self):
         return self.first_stage_model.state_dict()
 
+    def spacial_compression_decode(self):
+        try:
+            return self.upscale_ratio[-1]
+        except:
+            return self.upscale_ratio
+
 class StyleModel:
     def __init__(self, model, device="cpu"):
         self.model = model
diff --git a/nodes.py b/nodes.py
index 01af6c68..3a68d43c 100644
--- a/nodes.py
+++ b/nodes.py
@@ -301,7 +301,8 @@ class VAEDecodeTiled:
     def decode(self, vae, samples, tile_size, overlap=64):
         if tile_size < overlap * 4:
             overlap = tile_size // 4
-        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
+        compression = vae.spacial_compression_decode()
+        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
         if len(images.shape) == 5: #Combine batches
             images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
         return (images, )