Add temporal tiling to VAE Decode (Tiled) node.

You can now do tiled VAE decoding on the temporal direction for videos.
This commit is contained in:
comfyanonymous 2024-12-23 20:03:37 -05:00
parent f18ebbd316
commit bc6dac4327
3 changed files with 56 additions and 6 deletions

View File

@ -259,6 +259,9 @@ class VAE:
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float32]
self.downscale_index_formula = None
self.upscale_index_formula = None
if config is None: if config is None:
if "decoder.mid.block_1.mix_factor" in sd: if "decoder.mid.block_1.mix_factor" in sd:
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@ -338,6 +341,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8) self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
self.upscale_index_formula = (lambda a: max(0, a * 6), 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
self.working_dtypes = [torch.float16, torch.float32] self.working_dtypes = [torch.float16, torch.float32]
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
@ -353,6 +357,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32) self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
self.upscale_index_formula = (lambda a: max(0, a * 8), 32, 32)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd: elif "decoder.conv_in.conv.weight" in sd:
@ -360,6 +365,7 @@ class VAE:
ddconfig["conv3d"] = True ddconfig["conv3d"] = True
ddconfig["time_compress"] = 4 ddconfig["time_compress"] = 4
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (lambda a: max(0, a * 4), 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.latent_dim = 3 self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
@ -426,7 +432,7 @@ class VAE:
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
@ -479,7 +485,7 @@ class VAE:
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples return pixel_samples
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None): def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used) model_management.load_models_gpu([self.patcher], memory_required=memory_used)
dims = samples.ndim - 2 dims = samples.ndim - 2
@ -497,6 +503,12 @@ class VAE:
elif dims == 2: elif dims == 2:
output = self.decode_tiled_(samples, **args) output = self.decode_tiled_(samples, **args)
elif dims == 3: elif dims == 3:
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (overlap_t, overlap, overlap)
if tile_t is not None:
args["tile_t"] = tile_t
output = self.decode_tiled_3d(samples, **args) output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1) return output.movedim(1, -1)
@ -575,6 +587,12 @@ class VAE:
except: except:
return self.downscale_ratio return self.downscale_ratio
def temporal_compression_decode(self):
try:
return round(self.upscale_ratio[0](8192) / 8192)
except:
return None
class StyleModel: class StyleModel:
def __init__(self, model, device="cpu"): def __init__(self, model, device="cpu"):
self.model = model self.model = model

View File

@ -822,7 +822,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return rows * cols return rows * cols
@torch.inference_mode() @torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, pbar=None): def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
dims = len(tile) dims = len(tile)
if not (isinstance(upscale_amount, (tuple, list))): if not (isinstance(upscale_amount, (tuple, list))):
@ -831,6 +831,12 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
if not (isinstance(overlap, (tuple, list))): if not (isinstance(overlap, (tuple, list))):
overlap = [overlap] * dims overlap = [overlap] * dims
if index_formulas is None:
index_formulas = upscale_amount
if not (isinstance(index_formulas, (tuple, list))):
index_formulas = [index_formulas] * dims
def get_upscale(dim, val): def get_upscale(dim, val):
up = upscale_amount[dim] up = upscale_amount[dim]
if callable(up): if callable(up):
@ -845,10 +851,26 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
else: else:
return val / up return val / up
def get_upscale_pos(dim, val):
up = index_formulas[dim]
if callable(up):
return up(val)
else:
return up * val
def get_downscale_pos(dim, val):
up = index_formulas[dim]
if callable(up):
return up(val)
else:
return val / up
if downscale: if downscale:
get_scale = get_downscale get_scale = get_downscale
get_pos = get_downscale_pos
else: else:
get_scale = get_upscale get_scale = get_upscale
get_pos = get_upscale_pos
def mult_list_upscale(a): def mult_list_upscale(a):
out = [] out = []
@ -881,7 +903,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
pos = max(0, min(s.shape[d + 2] - overlap[d], it[d])) pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
l = min(tile[d], s.shape[d + 2] - pos) l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l) s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(get_scale(d, pos))) upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device) ps = function(s_in).to(output_device)
mask = torch.ones_like(ps) mask = torch.ones_like(ps)

View File

@ -293,17 +293,27 @@ class VAEDecodeTiled:
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
"temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}),
"temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
}} }}
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode" FUNCTION = "decode"
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
def decode(self, vae, samples, tile_size, overlap=64): def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
if tile_size < overlap * 4: if tile_size < overlap * 4:
overlap = tile_size // 4 overlap = tile_size // 4
temporal_compression = vae.temporal_compression_decode()
if temporal_compression is not None:
temporal_size = max(2, temporal_size // temporal_compression)
temporal_overlap = min(1, temporal_size // 2, temporal_overlap // temporal_compression)
else:
temporal_size = None
temporal_overlap = None
compression = vae.spacial_compression_decode() compression = vae.spacial_compression_decode()
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression) images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap)
if len(images.shape) == 5: #Combine batches if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return (images, ) return (images, )