diff --git a/comfy/sd.py b/comfy/sd.py
index e85f2ed7..2db00fa4 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -341,8 +341,9 @@ class VAE:
             self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
             self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
             self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
-            self.upscale_index_formula = (lambda a: max(0, a * 6), 8, 8)
+            self.upscale_index_formula = (6, 8, 8)
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
+            self.downscale_index_formula = (6, 8, 8)
             self.working_dtypes = [torch.float16, torch.float32]
         elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
             tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
@@ -357,16 +358,18 @@ class VAE:
             self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
             self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
             self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
-            self.upscale_index_formula = (lambda a: max(0, a * 8), 32, 32)
+            self.upscale_index_formula = (8, 32, 32)
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
+            self.downscale_index_formula = (8, 32, 32)
             self.working_dtypes = [torch.bfloat16, torch.float32]
         elif "decoder.conv_in.conv.weight" in sd:
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
             ddconfig["conv3d"] = True
             ddconfig["time_compress"] = 4
             self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
-            self.upscale_index_formula = (lambda a: max(0, a * 4), 8, 8)
+            self.upscale_index_formula = (4, 8, 8)
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+            self.downscale_index_formula = (4, 8, 8)
             self.latent_dim = 3
             self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
             self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
@@ -453,7 +456,7 @@ class VAE:
 
     def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
         encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
-        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, output_device=self.output_device)
+        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
 
     def decode(self, samples_in):
         pixel_samples = None
@@ -544,7 +547,7 @@ class VAE:
 
         return samples
 
-    def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None):
+    def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         dims = self.latent_dim
         pixel_samples = pixel_samples.movedim(-1, 1)
@@ -568,6 +571,12 @@ class VAE:
         elif dims == 2:
             samples = self.encode_tiled_(pixel_samples, **args)
         elif dims == 3:
+            if overlap_t is None:
+                args["overlap"] = (1, overlap, overlap)
+            else:
+                args["overlap"] = (overlap_t, overlap, overlap)
+            if tile_t is not None:
+                args["tile_t"] = tile_t
             samples = self.encode_tiled_3d(pixel_samples, **args)
 
         return samples
diff --git a/nodes.py b/nodes.py
index d6777df4..e95abc40 100644
--- a/nodes.py
+++ b/nodes.py
@@ -337,15 +337,17 @@ class VAEEncodeTiled:
         return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
                              "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
                              "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
+                             "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}),
+                             "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
                              }}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "encode"
 
     CATEGORY = "_for_testing"
 
-    def encode(self, vae, pixels, tile_size, overlap):
-        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap)
-        return ({"samples":t}, )
+    def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
+        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
+        return ({"samples": t}, )
 
 class VAEEncodeForInpaint:
     @classmethod
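
Note (illustrative, not part of the patch): the change replaces the per-axis lambda in
upscale_index_formula (e.g. lambda a: max(0, a * 6)) with a plain tuple of per-axis
multipliers, and adds a matching downscale_index_formula that the encode path now
forwards to comfy.utils.tiled_scale_multidim via its index_formulas parameter. The
ratio lambdas still map lengths between pixel and latent space; the tuples supply the
stride used when placing each tile's output along the temporal and spatial axes.
A minimal, self-contained sketch of the arithmetic for the 6x-temporal / 8x-spatial
branch above; the names below are illustrative, not ComfyUI API:

    import math

    # Ratio lambdas from the first VAE branch in the diff (temporal axis).
    temporal_up = lambda a: max(0, a * 6 - 5)                  # latent frames -> pixel frames
    temporal_down = lambda a: max(0, math.floor((a + 5) / 6))  # pixel frames -> latent frames

    # New tuple-style index formula: one stride per axis (t, h, w),
    # standing in for the old per-axis lambdas like (a -> a * 6).
    index_formula = (6, 8, 8)

    for latent_frames in (1, 2, 5):
        pixel_frames = temporal_up(latent_frames)
        assert temporal_down(pixel_frames) == latent_frames    # the lambdas round-trip
        print(latent_frames, "latent ->", pixel_frames, "pixel frames;",
              "temporal tile stride:", index_formula[0])

On the node side, the new temporal_size/temporal_overlap inputs flow through
encode_tiled as tile_t and overlap_t, so a 3D (video) VAE encodes in chunks of
tile_t frames with an overlap of (overlap_t, overlap, overlap); when overlap_t is
None the temporal overlap falls back to a single frame.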