mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add temporal tiling to VAE Decode (Tiled) node.
You can now do tiled VAE decoding on the temporal direction for videos.
This commit is contained in:
parent
f18ebbd316
commit
bc6dac4327
22
comfy/sd.py
22
comfy/sd.py
@ -259,6 +259,9 @@ class VAE:
|
|||||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
|
self.downscale_index_formula = None
|
||||||
|
self.upscale_index_formula = None
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
if "decoder.mid.block_1.mix_factor" in sd:
|
if "decoder.mid.block_1.mix_factor" in sd:
|
||||||
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||||
@ -338,6 +341,7 @@ class VAE:
|
|||||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
||||||
|
self.upscale_index_formula = (lambda a: max(0, a * 6), 8, 8)
|
||||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
|
self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
|
||||||
self.working_dtypes = [torch.float16, torch.float32]
|
self.working_dtypes = [torch.float16, torch.float32]
|
||||||
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
|
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
|
||||||
@ -353,6 +357,7 @@ class VAE:
|
|||||||
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
|
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
|
||||||
|
self.upscale_index_formula = (lambda a: max(0, a * 8), 32, 32)
|
||||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
|
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
|
||||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
elif "decoder.conv_in.conv.weight" in sd:
|
elif "decoder.conv_in.conv.weight" in sd:
|
||||||
@ -360,6 +365,7 @@ class VAE:
|
|||||||
ddconfig["conv3d"] = True
|
ddconfig["conv3d"] = True
|
||||||
ddconfig["time_compress"] = 4
|
ddconfig["time_compress"] = 4
|
||||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||||
|
self.upscale_index_formula = (lambda a: max(0, a * 4), 8, 8)
|
||||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||||
self.latent_dim = 3
|
self.latent_dim = 3
|
||||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||||
@ -426,7 +432,7 @@ class VAE:
|
|||||||
|
|
||||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||||
|
|
||||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||||
@ -479,7 +485,7 @@ class VAE:
|
|||||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||||
return pixel_samples
|
return pixel_samples
|
||||||
|
|
||||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
|
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||||
dims = samples.ndim - 2
|
dims = samples.ndim - 2
|
||||||
@ -497,6 +503,12 @@ class VAE:
|
|||||||
elif dims == 2:
|
elif dims == 2:
|
||||||
output = self.decode_tiled_(samples, **args)
|
output = self.decode_tiled_(samples, **args)
|
||||||
elif dims == 3:
|
elif dims == 3:
|
||||||
|
if overlap_t is None:
|
||||||
|
args["overlap"] = (1, overlap, overlap)
|
||||||
|
else:
|
||||||
|
args["overlap"] = (overlap_t, overlap, overlap)
|
||||||
|
if tile_t is not None:
|
||||||
|
args["tile_t"] = tile_t
|
||||||
output = self.decode_tiled_3d(samples, **args)
|
output = self.decode_tiled_3d(samples, **args)
|
||||||
return output.movedim(1, -1)
|
return output.movedim(1, -1)
|
||||||
|
|
||||||
@ -575,6 +587,12 @@ class VAE:
|
|||||||
except:
|
except:
|
||||||
return self.downscale_ratio
|
return self.downscale_ratio
|
||||||
|
|
||||||
|
def temporal_compression_decode(self):
|
||||||
|
try:
|
||||||
|
return round(self.upscale_ratio[0](8192) / 8192)
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
class StyleModel:
|
class StyleModel:
|
||||||
def __init__(self, model, device="cpu"):
|
def __init__(self, model, device="cpu"):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
@ -822,7 +822,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
|||||||
return rows * cols
|
return rows * cols
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, pbar=None):
|
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
|
||||||
dims = len(tile)
|
dims = len(tile)
|
||||||
|
|
||||||
if not (isinstance(upscale_amount, (tuple, list))):
|
if not (isinstance(upscale_amount, (tuple, list))):
|
||||||
@ -831,6 +831,12 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
|||||||
if not (isinstance(overlap, (tuple, list))):
|
if not (isinstance(overlap, (tuple, list))):
|
||||||
overlap = [overlap] * dims
|
overlap = [overlap] * dims
|
||||||
|
|
||||||
|
if index_formulas is None:
|
||||||
|
index_formulas = upscale_amount
|
||||||
|
|
||||||
|
if not (isinstance(index_formulas, (tuple, list))):
|
||||||
|
index_formulas = [index_formulas] * dims
|
||||||
|
|
||||||
def get_upscale(dim, val):
|
def get_upscale(dim, val):
|
||||||
up = upscale_amount[dim]
|
up = upscale_amount[dim]
|
||||||
if callable(up):
|
if callable(up):
|
||||||
@ -845,10 +851,26 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
|||||||
else:
|
else:
|
||||||
return val / up
|
return val / up
|
||||||
|
|
||||||
|
def get_upscale_pos(dim, val):
|
||||||
|
up = index_formulas[dim]
|
||||||
|
if callable(up):
|
||||||
|
return up(val)
|
||||||
|
else:
|
||||||
|
return up * val
|
||||||
|
|
||||||
|
def get_downscale_pos(dim, val):
|
||||||
|
up = index_formulas[dim]
|
||||||
|
if callable(up):
|
||||||
|
return up(val)
|
||||||
|
else:
|
||||||
|
return val / up
|
||||||
|
|
||||||
if downscale:
|
if downscale:
|
||||||
get_scale = get_downscale
|
get_scale = get_downscale
|
||||||
|
get_pos = get_downscale_pos
|
||||||
else:
|
else:
|
||||||
get_scale = get_upscale
|
get_scale = get_upscale
|
||||||
|
get_pos = get_upscale_pos
|
||||||
|
|
||||||
def mult_list_upscale(a):
|
def mult_list_upscale(a):
|
||||||
out = []
|
out = []
|
||||||
@ -881,7 +903,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
|||||||
pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
|
pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
|
||||||
l = min(tile[d], s.shape[d + 2] - pos)
|
l = min(tile[d], s.shape[d + 2] - pos)
|
||||||
s_in = s_in.narrow(d + 2, pos, l)
|
s_in = s_in.narrow(d + 2, pos, l)
|
||||||
upscaled.append(round(get_scale(d, pos)))
|
upscaled.append(round(get_pos(d, pos)))
|
||||||
|
|
||||||
ps = function(s_in).to(output_device)
|
ps = function(s_in).to(output_device)
|
||||||
mask = torch.ones_like(ps)
|
mask = torch.ones_like(ps)
|
||||||
|
14
nodes.py
14
nodes.py
@ -293,17 +293,27 @@ class VAEDecodeTiled:
|
|||||||
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
|
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
|
||||||
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}),
|
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}),
|
||||||
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
|
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
|
||||||
|
"temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}),
|
||||||
|
"temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "decode"
|
FUNCTION = "decode"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
def decode(self, vae, samples, tile_size, overlap=64):
|
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
|
||||||
if tile_size < overlap * 4:
|
if tile_size < overlap * 4:
|
||||||
overlap = tile_size // 4
|
overlap = tile_size // 4
|
||||||
|
temporal_compression = vae.temporal_compression_decode()
|
||||||
|
if temporal_compression is not None:
|
||||||
|
temporal_size = max(2, temporal_size // temporal_compression)
|
||||||
|
temporal_overlap = min(1, temporal_size // 2, temporal_overlap // temporal_compression)
|
||||||
|
else:
|
||||||
|
temporal_size = None
|
||||||
|
temporal_overlap = None
|
||||||
|
|
||||||
compression = vae.spacial_compression_decode()
|
compression = vae.spacial_compression_decode()
|
||||||
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
|
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap)
|
||||||
if len(images.shape) == 5: #Combine batches
|
if len(images.shape) == 5: #Combine batches
|
||||||
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
|
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
|
||||||
return (images, )
|
return (images, )
|
||||||
|
Loading…
Reference in New Issue
Block a user