From bc6dac4327a838f8583f6272cc3cc612b9b16134 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 23 Dec 2024 20:03:37 -0500 Subject: [PATCH] Add temporal tiling to VAE Decode (Tiled) node. You can now do tiled VAE decoding on the temporal direction for videos. --- comfy/sd.py | 22 ++++++++++++++++++++-- comfy/utils.py | 26 ++++++++++++++++++++++++-- nodes.py | 14 ++++++++++++-- 3 files changed, 56 insertions(+), 6 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index f79eacc2..e85f2ed7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -259,6 +259,9 @@ class VAE: self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.working_dtypes = [torch.bfloat16, torch.float32] + self.downscale_index_formula = None + self.upscale_index_formula = None + if config is None: if "decoder.mid.block_1.mix_factor" in sd: encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -338,6 +341,7 @@ class VAE: self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8) + self.upscale_index_formula = (lambda a: max(0, a * 6), 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8) self.working_dtypes = [torch.float16, torch.float32] elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv @@ -353,6 +357,7 @@ class VAE: self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32) + self.upscale_index_formula = (lambda a: max(0, a * 8), 32, 32) self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] elif "decoder.conv_in.conv.weight" in sd: @@ -360,6 +365,7 @@ class VAE: ddconfig["conv3d"] = True ddconfig["time_compress"] = 4 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (lambda a: max(0, a * 4), 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.latent_dim = 3 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] @@ -426,7 +432,7 @@ class VAE: def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() - return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) + return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) @@ -479,7 +485,7 @@ class VAE: pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples - def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None): + def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used) dims = samples.ndim - 2 @@ -497,6 +503,12 @@ class VAE: elif dims == 2: output = self.decode_tiled_(samples, **args) elif dims == 3: + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (overlap_t, overlap, overlap) + if tile_t is not None: + args["tile_t"] = tile_t output = self.decode_tiled_3d(samples, **args) return output.movedim(1, -1) @@ -575,6 +587,12 @@ class VAE: except: return self.downscale_ratio + def temporal_compression_decode(self): + try: + return round(self.upscale_ratio[0](8192) / 8192) + except: + return None + class StyleModel: def __init__(self, model, device="cpu"): self.model = model diff --git a/comfy/utils.py b/comfy/utils.py index 5fb5418b..7de65933 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -822,7 +822,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): return rows * cols @torch.inference_mode() -def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, pbar=None): +def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None): dims = len(tile) if not (isinstance(upscale_amount, (tuple, list))): @@ -831,6 +831,12 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am if not (isinstance(overlap, (tuple, list))): overlap = [overlap] * dims + if index_formulas is None: + index_formulas = upscale_amount + + if not (isinstance(index_formulas, (tuple, list))): + index_formulas = [index_formulas] * dims + def get_upscale(dim, val): up = upscale_amount[dim] if callable(up): @@ -845,10 +851,26 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am else: return val / up + def get_upscale_pos(dim, val): + up = index_formulas[dim] + if callable(up): + return up(val) + else: + return up * val + + def get_downscale_pos(dim, val): + up = index_formulas[dim] + if callable(up): + return up(val) + else: + return val / up + if downscale: get_scale = get_downscale + get_pos = get_downscale_pos else: get_scale = get_upscale + get_pos = get_upscale_pos def mult_list_upscale(a): out = [] @@ -881,7 +903,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am pos = max(0, min(s.shape[d + 2] - overlap[d], it[d])) l = min(tile[d], s.shape[d + 2] - pos) s_in = s_in.narrow(d + 2, pos, l) - upscaled.append(round(get_scale(d, pos))) + upscaled.append(round(get_pos(d, pos))) ps = function(s_in).to(output_device) mask = torch.ones_like(ps) diff --git a/nodes.py b/nodes.py index bdea7564..d6777df4 100644 --- a/nodes.py +++ b/nodes.py @@ -293,17 +293,27 @@ class VAEDecodeTiled: return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}), + "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}), + "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" CATEGORY = "_for_testing" - def decode(self, vae, samples, tile_size, overlap=64): + def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8): if tile_size < overlap * 4: overlap = tile_size // 4 + temporal_compression = vae.temporal_compression_decode() + if temporal_compression is not None: + temporal_size = max(2, temporal_size // temporal_compression) + temporal_overlap = min(1, temporal_size // 2, temporal_overlap // temporal_compression) + else: + temporal_size = None + temporal_overlap = None + compression = vae.spacial_compression_decode() - images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression) + images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, )