diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 2a02acd65..6e8e06181 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -7,7 +7,7 @@ from einops import rearrange import math from typing import Dict, Optional, Tuple -from .symmetric_patchifier import SymmetricPatchifier +from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords def get_timestep_embedding( @@ -377,12 +377,16 @@ class LTXVModel(torch.nn.Module): positional_embedding_theta=10000.0, positional_embedding_max_pos=[20, 2048, 2048], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), dtype=None, device=None, operations=None, **kwargs): super().__init__() self.generator = None + self.vae_scale_factors = vae_scale_factors self.dtype = dtype self.out_channels = in_channels self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device) @@ -416,42 +420,23 @@ class LTXVModel(torch.nn.Module): self.patchifier = SymmetricPatchifier(1) - def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): patches_replace = transformer_options.get("patches_replace", {}) - indices_grid = self.patchifier.get_grid( - orig_num_frames=x.shape[2], - orig_height=x.shape[3], - orig_width=x.shape[4], - batch_size=x.shape[0], - scale_grid=((1 / frame_rate) * 8, 32, 32), - device=x.device, - ) - - if guiding_latent is not None: - ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype) - input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1)) - ts *= input_ts - ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2) - timestep = self.patchifier.patchify(ts) - input_x = x.clone() - x[:, :, 0] = guiding_latent[:, :, 0] - if guiding_latent_noise_scale > 0: - if self.generator is None: - self.generator = torch.Generator(device=x.device).manual_seed(42) - elif self.generator.device != x.device: - self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state()) - - noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]] - scale = guiding_latent_noise_scale * (input_ts ** 2) - guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator) - - x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0]) - - orig_shape = list(x.shape) - x = self.patchifier.patchify(x) + x, latent_coords = self.patchifier.patchify(x) + pixel_coords = latent_to_pixel_coords( + latent_coords=latent_coords, + scale_factors=self.vae_scale_factors, + causal_fix=self.causal_temporal_positioning, + ) + + if keyframe_idxs is not None: + pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs + + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) x = self.patchify_proj(x) timestep = timestep * 1000.0 @@ -459,7 +444,7 @@ class LTXVModel(torch.nn.Module): if attention_mask is not None and not torch.is_floating_point(attention_mask): attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, 
attention_mask.shape[-1])) * torch.finfo(x.dtype).max - pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype) + pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) batch_size = x.shape[0] timestep, embedded_timestep = self.adaln_single( @@ -519,8 +504,4 @@ class LTXVModel(torch.nn.Module): out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size), ) - if guiding_latent is not None: - x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0] - - # print("res", x) return x diff --git a/comfy/ldm/lightricks/symmetric_patchifier.py b/comfy/ldm/lightricks/symmetric_patchifier.py index c58dfb20b..4b9972b9f 100644 --- a/comfy/ldm/lightricks/symmetric_patchifier.py +++ b/comfy/ldm/lightricks/symmetric_patchifier.py @@ -6,16 +6,29 @@ from einops import rearrange from torch import Tensor -def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError( - f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" - ) - elif dims_to_append == 0: - return x - return x[(...,) + (None,) * dims_to_append] +def latent_to_pixel_coords( + latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False +) -> Tensor: + """ + Converts latent coordinates to pixel coordinates by scaling them according to the VAE's + configuration. + Args: + latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents] + containing the latent corner coordinates of each token. + scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space. + causal_fix (bool): Whether to take into account the different temporal scale + of the first frame. Default = False for backwards compatibility. + Returns: + Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. + """ + pixel_coords = ( + latent_coords + * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + ) + if causal_fix: + # Fix temporal scale for first frame to 1 due to causality + pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + return pixel_coords class Patchifier(ABC): @@ -44,29 +57,26 @@ class Patchifier(ABC): def patch_size(self): return self._patch_size - def get_grid( - self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device + def get_latent_coords( + self, latent_num_frames, latent_height, latent_width, batch_size, device ): - f = orig_num_frames // self._patch_size[0] - h = orig_height // self._patch_size[1] - w = orig_width // self._patch_size[2] - grid_h = torch.arange(h, dtype=torch.float32, device=device) - grid_w = torch.arange(w, dtype=torch.float32, device=device) - grid_f = torch.arange(f, dtype=torch.float32, device=device) - grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij') - grid = torch.stack(grid, dim=0) - grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - - if scale_grid is not None: - for i in range(3): - if isinstance(scale_grid[i], Tensor): - scale = append_dims(scale_grid[i], grid.ndim - 1) - else: - scale = scale_grid[i] - grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i] - - grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size) - return grid + """ + Return a tensor of shape [batch_size, 3, num_patches] containing the + top-left corner latent coordinates of each latent patch. 
+ The tensor is repeated for each batch element. + """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, self._patch_size[0], device=device), + torch.arange(0, latent_height, self._patch_size[1], device=device), + torch.arange(0, latent_width, self._patch_size[2], device=device), + indexing="ij", + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = rearrange( + latent_coords, "b c f h w -> b c (f h w)", b=batch_size + ) + return latent_coords class SymmetricPatchifier(Patchifier): @@ -74,6 +84,8 @@ class SymmetricPatchifier(Patchifier): self, latents: Tensor, ) -> Tuple[Tensor, Tensor]: + b, _, f, h, w = latents.shape + latent_coords = self.get_latent_coords(f, h, w, b, latents.device) latents = rearrange( latents, "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", @@ -81,7 +93,7 @@ class SymmetricPatchifier(Patchifier): p2=self._patch_size[1], p3=self._patch_size[2], ) - return latents + return latents, latent_coords def unpatchify( self, diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index c572e7e86..70d612e86 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -15,6 +15,7 @@ class CausalConv3d(nn.Module): stride: Union[int, Tuple[int]] = 1, dilation: int = 1, groups: int = 1, + spatial_padding_mode: str = "zeros", **kwargs, ): super().__init__() @@ -38,7 +39,7 @@ class CausalConv3d(nn.Module): stride=stride, dilation=dilation, padding=padding, - padding_mode="zeros", + padding_mode=spatial_padding_mode, groups=groups, ) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index e0344deec..043ca0496 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -1,13 +1,15 @@ +from __future__ import annotations import torch from torch import nn from functools import partial import math from einops import rearrange -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union from .conv_nd_factory import make_conv_nd, make_linear_nd from .pixel_norm import PixelNorm from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings import comfy.ops + ops = comfy.ops.disable_weight_init class Encoder(nn.Module): @@ -32,7 +34,7 @@ class Encoder(nn.Module): norm_layer (`str`, *optional*, defaults to `group_norm`): The normalization layer to use. Can be either `group_norm` or `pixel_norm`. latent_log_var (`str`, *optional*, defaults to `per_channel`): - The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. + The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`. 
""" def __init__( @@ -40,12 +42,13 @@ class Encoder(nn.Module): dims: Union[int, Tuple[int, int]] = 3, in_channels: int = 3, out_channels: int = 3, - blocks=[("res_x", 1)], + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], base_channels: int = 128, norm_num_groups: int = 32, patch_size: Union[int, Tuple[int]] = 1, norm_layer: str = "group_norm", # group_norm, pixel_norm latent_log_var: str = "per_channel", + spatial_padding_mode: str = "zeros", ): super().__init__() self.patch_size = patch_size @@ -65,6 +68,7 @@ class Encoder(nn.Module): stride=1, padding=1, causal=True, + spatial_padding_mode=spatial_padding_mode, ) self.down_blocks = nn.ModuleList([]) @@ -82,6 +86,7 @@ class Encoder(nn.Module): resnet_eps=1e-6, resnet_groups=norm_num_groups, norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "res_x_y": output_channel = block_params.get("multiplier", 2) * output_channel @@ -92,6 +97,7 @@ class Encoder(nn.Module): eps=1e-6, groups=norm_num_groups, norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_time": block = make_conv_nd( @@ -101,6 +107,7 @@ class Encoder(nn.Module): kernel_size=3, stride=(2, 1, 1), causal=True, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_space": block = make_conv_nd( @@ -110,6 +117,7 @@ class Encoder(nn.Module): kernel_size=3, stride=(1, 2, 2), causal=True, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all": block = make_conv_nd( @@ -119,6 +127,7 @@ class Encoder(nn.Module): kernel_size=3, stride=(2, 2, 2), causal=True, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all_x_y": output_channel = block_params.get("multiplier", 2) * output_channel @@ -129,6 +138,34 @@ class Encoder(nn.Module): kernel_size=3, stride=(2, 2, 2), causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, ) else: raise ValueError(f"unknown block: {block_name}") @@ -152,10 +189,18 @@ class Encoder(nn.Module): conv_out_channels *= 2 elif latent_log_var == "uniform": conv_out_channels += 1 + elif latent_log_var == "constant": + conv_out_channels += 1 elif latent_log_var != "none": raise ValueError(f"Invalid latent_log_var: {latent_log_var}") self.conv_out = make_conv_nd( - dims, output_channel, conv_out_channels, 3, padding=1, causal=True + dims, + output_channel, + conv_out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, ) self.gradient_checkpointing = False @@ -197,6 +242,15 @@ class Encoder(nn.Module): sample = torch.cat([sample, repeated_last_channel], dim=1) else: raise ValueError(f"Invalid input shape: {sample.shape}") + elif 
self.latent_log_var == "constant": + sample = sample[:, :-1, ...] + approx_ln_0 = ( + -30 + ) # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) return sample @@ -231,7 +285,7 @@ class Decoder(nn.Module): dims, in_channels: int = 3, out_channels: int = 3, - blocks=[("res_x", 1)], + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], base_channels: int = 128, layers_per_block: int = 2, norm_num_groups: int = 32, @@ -239,6 +293,7 @@ class Decoder(nn.Module): norm_layer: str = "group_norm", causal: bool = True, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ): super().__init__() self.patch_size = patch_size @@ -264,6 +319,7 @@ class Decoder(nn.Module): stride=1, padding=1, causal=True, + spatial_padding_mode=spatial_padding_mode, ) self.up_blocks = nn.ModuleList([]) @@ -283,6 +339,7 @@ class Decoder(nn.Module): norm_layer=norm_layer, inject_noise=block_params.get("inject_noise", False), timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "attn_res_x": block = UNetMidBlock3D( @@ -294,6 +351,7 @@ class Decoder(nn.Module): inject_noise=block_params.get("inject_noise", False), timestep_conditioning=timestep_conditioning, attention_head_dim=block_params["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "res_x_y": output_channel = output_channel // block_params.get("multiplier", 2) @@ -306,14 +364,21 @@ class Decoder(nn.Module): norm_layer=norm_layer, inject_noise=block_params.get("inject_noise", False), timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_time": block = DepthToSpaceUpsample( - dims=dims, in_channels=input_channel, stride=(2, 1, 1) + dims=dims, + in_channels=input_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_space": block = DepthToSpaceUpsample( - dims=dims, in_channels=input_channel, stride=(1, 2, 2) + dims=dims, + in_channels=input_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all": output_channel = output_channel // block_params.get("multiplier", 1) @@ -323,6 +388,7 @@ class Decoder(nn.Module): stride=(2, 2, 2), residual=block_params.get("residual", False), out_channels_reduction_factor=block_params.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, ) else: raise ValueError(f"unknown layer: {block_name}") @@ -340,7 +406,13 @@ class Decoder(nn.Module): self.conv_act = nn.SiLU() self.conv_out = make_conv_nd( - dims, output_channel, out_channels, 3, padding=1, causal=True + dims, + output_channel, + out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, ) self.gradient_checkpointing = False @@ -433,6 +505,12 @@ class UNetMidBlock3D(nn.Module): resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. resnet_groups (`int`, *optional*, defaults to 32): The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. 
+ timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. Returns: `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, @@ -451,6 +529,7 @@ class UNetMidBlock3D(nn.Module): norm_layer: str = "group_norm", inject_noise: bool = False, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ): super().__init__() resnet_groups = ( @@ -476,13 +555,17 @@ class UNetMidBlock3D(nn.Module): norm_layer=norm_layer, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) for _ in range(num_layers) ] ) def forward( - self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None + self, + hidden_states: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: timestep_embed = None if self.timestep_conditioning: @@ -507,9 +590,62 @@ class UNetMidBlock3D(nn.Module): return hidden_states +class SpaceToDepthDownsample(nn.Module): + def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode): + super().__init__() + self.stride = stride + self.group_size = in_channels * math.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // math.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, x, causal: bool = True): + if self.stride[0] == 2: + x = torch.cat( + [x[:, :, :1, :, :], x], dim=2 + ) # duplicate first frames for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in + + return x + + class DepthToSpaceUpsample(nn.Module): def __init__( - self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1 + self, + dims, + in_channels, + stride, + residual=False, + out_channels_reduction_factor=1, + spatial_padding_mode="zeros", ): super().__init__() self.stride = stride @@ -523,6 +659,7 @@ class DepthToSpaceUpsample(nn.Module): kernel_size=3, stride=1, causal=True, + spatial_padding_mode=spatial_padding_mode, ) self.residual = residual self.out_channels_reduction_factor = out_channels_reduction_factor @@ -591,6 +728,7 @@ class ResnetBlock3D(nn.Module): norm_layer: str = "group_norm", inject_noise: bool = False, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ): super().__init__() self.in_channels = in_channels @@ -617,6 +755,7 @@ class ResnetBlock3D(nn.Module): stride=1, padding=1, causal=True, + spatial_padding_mode=spatial_padding_mode, ) if inject_noise: @@ -641,6 +780,7 @@ class ResnetBlock3D(nn.Module): stride=1, padding=1, causal=True, + spatial_padding_mode=spatial_padding_mode, ) if inject_noise: @@ -801,9 +941,44 @@ class processor(nn.Module): return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x) class VideoVAE(nn.Module): - def __init__(self, version=0): + def __init__(self, version=0, config=None): 
super().__init__() + if config is None: + config = self.guess_config(version) + + self.timestep_conditioning = config.get("timestep_conditioning", False) + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + + self.encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))), + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + self.decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))), + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + causal=config.get("causal_decoder", False), + timestep_conditioning=self.timestep_conditioning, + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + self.per_channel_statistics = processor() + + def guess_config(self, version): if version == 0: config = { "_class_name": "CausalVideoAutoencoder", @@ -830,7 +1005,7 @@ class VideoVAE(nn.Module): "use_quant_conv": False, "causal_decoder": False, } - else: + elif version == 1: config = { "_class_name": "CausalVideoAutoencoder", "dims": 3, @@ -866,37 +1041,47 @@ class VideoVAE(nn.Module): "causal_decoder": False, "timestep_conditioning": True, } - - double_z = config.get("double_z", True) - latent_log_var = config.get( - "latent_log_var", "per_channel" if double_z else "none" - ) - - self.encoder = Encoder( - dims=config["dims"], - in_channels=config.get("in_channels", 3), - out_channels=config["latent_channels"], - blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))), - patch_size=config.get("patch_size", 1), - latent_log_var=latent_log_var, - norm_layer=config.get("norm_layer", "group_norm"), - ) - - self.decoder = Decoder( - dims=config["dims"], - in_channels=config["latent_channels"], - out_channels=config.get("out_channels", 3), - blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))), - patch_size=config.get("patch_size", 1), - norm_layer=config.get("norm_layer", "group_norm"), - causal=config.get("causal_decoder", False), - timestep_conditioning=config.get("timestep_conditioning", False), - ) - - self.timestep_conditioning = config.get("timestep_conditioning", False) - self.per_channel_statistics = processor() + else: + config = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "encoder_blocks": [ + ["res_x", {"num_layers": 4}], + ["compress_space_res", {"multiplier": 2}], + ["res_x", {"num_layers": 6}], + ["compress_time_res", {"multiplier": 2}], + ["res_x", {"num_layers": 6}], + ["compress_all_res", {"multiplier": 2}], + ["res_x", {"num_layers": 2}], + ["compress_all_res", {"multiplier": 2}], + ["res_x", {"num_layers": 2}] + ], + "decoder_blocks": [ + ["res_x", {"num_layers": 5, "inject_noise": False}], + ["compress_all", {"residual": True, "multiplier": 2}], + ["res_x", {"num_layers": 5, "inject_noise": False}], + ["compress_all", {"residual": True, "multiplier": 2}], + ["res_x", {"num_layers": 5, "inject_noise": False}], + ["compress_all", 
{"residual": True, "multiplier": 2}], + ["res_x", {"num_layers": 5, "inject_noise": False}] + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, + "timestep_conditioning": True + } + return config def encode(self, x): + frames_count = x.shape[2] + if ((frames_count - 1) % 8) != 0: + raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.") means, logvar = torch.chunk(self.encoder(x), 2, dim=1) return self.per_channel_statistics.normalize(means) diff --git a/comfy/ldm/lightricks/vae/conv_nd_factory.py b/comfy/ldm/lightricks/vae/conv_nd_factory.py index 52df4ee22..b4026b14f 100644 --- a/comfy/ldm/lightricks/vae/conv_nd_factory.py +++ b/comfy/ldm/lightricks/vae/conv_nd_factory.py @@ -17,7 +17,11 @@ def make_conv_nd( groups=1, bias=True, causal=False, + spatial_padding_mode="zeros", + temporal_padding_mode="zeros", ): + if not (spatial_padding_mode == temporal_padding_mode or causal): + raise NotImplementedError("spatial and temporal padding modes must be equal") if dims == 2: return ops.Conv2d( in_channels=in_channels, @@ -28,6 +32,7 @@ def make_conv_nd( dilation=dilation, groups=groups, bias=bias, + padding_mode=spatial_padding_mode, ) elif dims == 3: if causal: @@ -40,6 +45,7 @@ def make_conv_nd( dilation=dilation, groups=groups, bias=bias, + spatial_padding_mode=spatial_padding_mode, ) return ops.Conv3d( in_channels=in_channels, @@ -50,6 +56,7 @@ def make_conv_nd( dilation=dilation, groups=groups, bias=bias, + padding_mode=spatial_padding_mode, ) elif dims == (2, 1): return DualConv3d( @@ -59,6 +66,7 @@ def make_conv_nd( stride=stride, padding=padding, bias=bias, + padding_mode=spatial_padding_mode, ) else: raise ValueError(f"unsupported dimensions: {dims}") diff --git a/comfy/ldm/lightricks/vae/dual_conv3d.py b/comfy/ldm/lightricks/vae/dual_conv3d.py index 6bd54c0a6..dcf889296 100644 --- a/comfy/ldm/lightricks/vae/dual_conv3d.py +++ b/comfy/ldm/lightricks/vae/dual_conv3d.py @@ -18,11 +18,13 @@ class DualConv3d(nn.Module): dilation: Union[int, Tuple[int, int, int]] = 1, groups=1, bias=True, + padding_mode="zeros", ): super(DualConv3d, self).__init__() self.in_channels = in_channels self.out_channels = out_channels + self.padding_mode = padding_mode # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size, kernel_size) @@ -108,6 +110,7 @@ class DualConv3d(nn.Module): self.padding1, self.dilation1, self.groups, + padding_mode=self.padding_mode, ) if skip_time_conv: @@ -122,6 +125,7 @@ class DualConv3d(nn.Module): self.padding2, self.dilation2, self.groups, + padding_mode=self.padding_mode, ) return x @@ -137,7 +141,16 @@ class DualConv3d(nn.Module): stride1 = (self.stride1[1], self.stride1[2]) padding1 = (self.padding1[1], self.padding1[2]) dilation1 = (self.dilation1[1], self.dilation1[2]) - x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups) + x = F.conv2d( + x, + weight1, + self.bias1, + stride1, + padding1, + dilation1, + self.groups, + padding_mode=self.padding_mode, + ) _, _, h, w = x.shape @@ -154,7 +167,16 @@ class DualConv3d(nn.Module): stride2 = self.stride2[0] padding2 = self.padding2[0] dilation2 = self.dilation2[0] - x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups) + x = F.conv1d( + x, + weight2, + self.bias2, + stride2, + padding2, + dilation2, + 
self.groups, + padding_mode=self.padding_mode, + ) x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) return x diff --git a/comfy/model_base.py b/comfy/model_base.py index cddc4663e..07fd2db43 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -161,9 +161,13 @@ class BaseModel(torch.nn.Module): extra = extra.to(dtype) extra_conds[o] = extra + t = self.process_timestep(t, x=x, **extra_conds) model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() return self.model_sampling.calculate_denoised(sigma, model_output, x) + def process_timestep(self, timestep, **kwargs): + return timestep + def get_dtype(self): return self.diffusion_model.dtype @@ -855,17 +859,26 @@ class LTXV(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) - guiding_latent = kwargs.get("guiding_latent", None) - if guiding_latent is not None: - out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent) - - guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None) - if guiding_latent_noise_scale is not None: - out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale) - out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) + + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + + keyframe_idxs = kwargs.get("keyframe_idxs", None) + if keyframe_idxs is not None: + out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + return out + def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): + if denoise_mask is None: + return timestep + return self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0] + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + return latent_image + class HunyuanVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index f149a4bf7..1aef549f4 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1,3 +1,4 @@ +import json import comfy.supported_models import comfy.supported_models_base import comfy.utils @@ -33,7 +34,7 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict): return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross return None -def detect_unet_config(state_dict, key_prefix): +def detect_unet_config(state_dict, key_prefix, metadata=None): state_dict_keys = list(state_dict.keys()) if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model @@ -210,6 +211,8 @@ def detect_unet_config(state_dict, key_prefix): if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv dit_config = {} dit_config["image_model"] = "ltxv" + if metadata is not None and "config" in metadata: + dit_config.update(json.loads(metadata["config"]).get("transformer", {})) return dit_config if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt @@ -454,8 +457,8 @@ def model_config_from_unet_config(unet_config, state_dict=None): logging.error("no match {}".format(unet_config)) 
return None -def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): - unet_config = detect_unet_config(state_dict, unet_key_prefix) +def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None): + unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata) if unet_config is None: return None model_config = model_config_from_unet_config(unet_config, state_dict) diff --git a/comfy/sd.py b/comfy/sd.py index b866c66c4..fd98585a1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,4 +1,5 @@ from __future__ import annotations +import json import torch from enum import Enum import logging @@ -249,7 +250,7 @@ class CLIP: return self.patcher.get_key_patches() class VAE: - def __init__(self, sd=None, device=None, config=None, dtype=None): + def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) @@ -357,7 +358,12 @@ class VAE: version = 0 elif tensor_conv1.shape[0] == 1024: version = 1 - self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version) + if "encoder.down_blocks.1.conv.conv.bias" in sd: + version = 2 + vae_config = None + if metadata is not None and "config" in metadata: + vae_config = json.loads(metadata["config"]).get("vae", None) + self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config) self.latent_channels = 128 self.latent_dim = 3 self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) @@ -873,13 +879,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}): - sd = comfy.utils.load_torch_file(ckpt_path) - out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options) + sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) + out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) return out -def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}): +def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None): clip = None clipvision = None vae = None @@ -891,7 +897,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) load_device = model_management.get_torch_device() - model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix) + model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: return None @@ -920,7 +926,7 @@ def 
load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) vae_sd = model_config.process_vae_state_dict(vae_sd) - vae = VAE(sd=vae_sd) + vae = VAE(sd=vae_sd, metadata=metadata) if output_clip: clip_target = model_config.clip_target(state_dict=sd) diff --git a/comfy/utils.py b/comfy/utils.py index df7057c6a..a826e41bf 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -46,12 +46,18 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in else: logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") -def load_torch_file(ckpt, safe_load=False, device=None): +def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): if device is None: device = torch.device("cpu") + metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: - sd = safetensors.torch.load_file(ckpt, device=device.type) + with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: + sd = {} + for k in f.keys(): + sd[k] = f.get_tensor(k) + if return_metadata: + metadata = f.metadata() except Exception as e: if len(e.args) > 0: message = e.args[0] @@ -77,7 +83,7 @@ def load_torch_file(ckpt, safe_load=False, device=None): sd = pl_sd else: sd = pl_sd - return sd + return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): if metadata is not None: diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index dec912416..8bd548bcd 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -1,9 +1,14 @@ +import io import nodes import node_helpers import torch import comfy.model_management import comfy.model_sampling +import comfy.utils import math +import numpy as np +import av +from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords class EmptyLTXVLatentVideo: @classmethod @@ -33,7 +38,6 @@ class LTXVImgToVideo: "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 1.0, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."}) }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @@ -42,16 +46,220 @@ class LTXVImgToVideo: CATEGORY = "conditioning/video_models" FUNCTION = "generate" - def generate(self, positive, negative, image, vae, width, height, length, batch_size, image_noise_scale): + def generate(self, positive, negative, image, vae, width, height, length, batch_size): pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) encode_pixels = pixels[:, :, :, :3] t = vae.encode(encode_pixels) - positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale}) - negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale}) latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device()) latent[:, :, :t.shape[2]] = t - return (positive, negative, {"samples": 
latent}, ) + + conditioning_latent_frames_mask = torch.ones( + (batch_size, 1, latent.shape[2], 1, 1), + dtype=torch.float32, + device=latent.device, + ) + conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0 + + return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, ) + + +def conditioning_get_any_value(conditioning, key, default=None): + for t in conditioning: + if key in t[1]: + return t[1][key] + return default + + +def get_noise_mask(latent): + noise_mask = latent.get("noise_mask", None) + latent_image = latent["samples"] + if noise_mask is None: + batch_size, _, latent_length, _, _ = latent_image.shape + noise_mask = torch.ones( + (batch_size, 1, latent_length, 1, 1), + dtype=torch.float32, + device=latent_image.device, + ) + else: + noise_mask = noise_mask.clone() + return noise_mask + +def get_keyframe_idxs(cond): + keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None) + if keyframe_idxs is None: + return None, 0 + num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0] + return keyframe_idxs, num_keyframes + +class LTXVAddGuide: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE",), + "latent": ("LATENT",), + "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \ + "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}), + "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999, + "tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \ + "If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \ + "Negative values are counted from the end of the video."}), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + CATEGORY = "conditioning/video_models" + FUNCTION = "generate" + + def __init__(self): + self._num_prefix_frames = 2 + self._patchifier = SymmetricPatchifier(1) + + def encode(self, vae, latent_width, latent_height, images, scale_factors): + time_scale_factor, width_scale_factor, height_scale_factor = scale_factors + images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] + pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1) + encode_pixels = pixels[:, :, :, :3] + t = vae.encode(encode_pixels) + return encode_pixels, t + + def get_latent_index(self, cond, latent_length, frame_idx, scale_factors): + time_scale_factor, _, _ = scale_factors + _, num_keyframes = get_keyframe_idxs(cond) + latent_count = latent_length - num_keyframes + frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0) + frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8 + + latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor + + return frame_idx, latent_idx + + def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors): + keyframe_idxs, _ = get_keyframe_idxs(cond) + _, latent_coords = self._patchifier.patchify(guiding_latent) + pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True) + pixel_coords[:, 0] += frame_idx + if keyframe_idxs is None: + keyframe_idxs = 
pixel_coords + else: + keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2) + return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) + + def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): + positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) + negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) + + mask = torch.full( + (noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1), + 1.0 - strength, + dtype=noise_mask.dtype, + device=noise_mask.device, + ) + + latent_image = torch.cat([latent_image, guiding_latent], dim=2) + noise_mask = torch.cat([noise_mask, mask], dim=2) + return positive, negative, latent_image, noise_mask + + def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength): + cond_length = guiding_latent.shape[2] + assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence." + + mask = torch.full( + (noise_mask.shape[0], 1, cond_length, 1, 1), + 1.0 - strength, + dtype=noise_mask.dtype, + device=noise_mask.device, + ) + + latent_image = latent_image.clone() + noise_mask = noise_mask.clone() + + latent_image[:, :, latent_idx : latent_idx + cond_length] = guiding_latent + noise_mask[:, :, latent_idx : latent_idx + cond_length] = mask + + return latent_image, noise_mask + + def generate(self, positive, negative, vae, latent, image, frame_idx, strength): + scale_factors = vae.downscale_index_formula + latent_image = latent["samples"] + noise_mask = get_noise_mask(latent) + + _, _, latent_length, latent_height, latent_width = latent_image.shape + image, t = self.encode(vae, latent_width, latent_height, image, scale_factors) + + frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors) + assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." 
+ + if frame_idx == 0: + latent_image, noise_mask = self.replace_latent_frames(latent_image, noise_mask, t, latent_idx, strength) + return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + + + num_prefix_frames = min(self._num_prefix_frames, t.shape[2]) + + positive, negative, latent_image, noise_mask = self.append_keyframe( + positive, + negative, + frame_idx, + latent_image, + noise_mask, + t[:, :, :num_prefix_frames], + strength, + scale_factors, + ) + + latent_idx += num_prefix_frames + + t = t[:, :, num_prefix_frames:] + if t.shape[2] == 0: + return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + + latent_image, noise_mask = self.replace_latent_frames( + latent_image, + noise_mask, + t, + latent_idx, + strength, + ) + + return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + + +class LTXVCropGuides: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "latent": ("LATENT",), + } + } + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + CATEGORY = "conditioning/video_models" + FUNCTION = "crop" + + def __init__(self): + self._patchifier = SymmetricPatchifier(1) + + def crop(self, positive, negative, latent): + latent_image = latent["samples"].clone() + noise_mask = get_noise_mask(latent) + + _, num_keyframes = get_keyframe_idxs(positive) + + latent_image = latent_image[:, :, :-num_keyframes] + noise_mask = noise_mask[:, :, :-num_keyframes] + + positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) + negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) + + return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) class LTXVConditioning: @@ -174,6 +382,78 @@ class LTXVScheduler: return (sigmas,) +def encode_single_frame(output_file, image_array: np.ndarray, crf): + container = av.open(output_file, "w", format="mp4") + try: + stream = container.add_stream( + "h264", rate=1, options={"crf": str(crf), "preset": "veryfast"} + ) + stream.height = image_array.shape[0] + stream.width = image_array.shape[1] + av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat( + format="yuv420p" + ) + container.mux(stream.encode(av_frame)) + container.mux(stream.encode()) + finally: + container.close() + + +def decode_single_frame(video_file): + container = av.open(video_file) + try: + stream = next(s for s in container.streams if s.type == "video") + frame = next(container.decode(stream)) + finally: + container.close() + return frame.to_ndarray(format="rgb24") + + +def preprocess(image: torch.Tensor, crf=29): + if crf == 0: + return image + + image_array = (image * 255.0).byte().cpu().numpy() + with io.BytesIO() as output_file: + encode_single_frame(output_file, image_array, crf) + video_bytes = output_file.getvalue() + with io.BytesIO(video_bytes) as video_file: + image_array = decode_single_frame(video_file) + tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0 + return tensor + + +class LTXVPreprocess: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "img_compression": ( + "INT", + { + "default": 35, + "min": 0, + "max": 100, + "tooltip": "Amount of compression to apply on image.", + }, + ), + } + } + + FUNCTION = "preprocess" + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("output_image",) + CATEGORY = "image" + + 
+    def preprocess(self, image, img_compression):
+        output_image = image
+        if img_compression > 0:
+            output_image = torch.zeros_like(image)
+            for i in range(image.shape[0]):
+                output_image[i] = preprocess(image[i], img_compression)
+        return (output_image,)
+
 
 NODE_CLASS_MAPPINGS = {
     "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
@@ -181,4 +461,7 @@ NODE_CLASS_MAPPINGS = {
     "ModelSamplingLTXV": ModelSamplingLTXV,
     "LTXVConditioning": LTXVConditioning,
     "LTXVScheduler": LTXVScheduler,
+    "LTXVAddGuide": LTXVAddGuide,
+    "LTXVPreprocess": LTXVPreprocess,
+    "LTXVCropGuides": LTXVCropGuides,
 }
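
A minimal standalone sketch of the positional-coordinate pipeline that replaces the old get_grid path in comfy/ldm/lightricks/model.py: latent patch corners (as built by get_latent_coords), mapped to pixel space by latent_to_pixel_coords with the causal first-frame fix, then the temporal axis scaled by 1/frame_rate before precompute_freqs_cis. PyTorch only, no comfy imports; the helper below copies the formula from the diff rather than importing it, and the latent size, the (8, 32, 32) scale factors and the frame rate of 25 are illustrative defaults (patch size is 1, as with SymmetricPatchifier(1)).

import torch

def latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=False):
    # [b, 3, n] corner coordinates, scaled per axis by the VAE scale factors
    pixel_coords = latent_coords * torch.tensor(scale_factors)[None, :, None]
    if causal_fix:
        # the first latent frame only covers one pixel frame, so pull its start back to 0
        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
    return pixel_coords

b, f, h, w = 1, 3, 2, 2                                              # tiny latent grid, patch size 1
grid = torch.stack(torch.meshgrid(
    torch.arange(f), torch.arange(h), torch.arange(w), indexing="ij"), dim=0)
latent_coords = grid.reshape(3, -1).unsqueeze(0).repeat(b, 1, 1)     # [b, 3, f*h*w]

pixel_coords = latent_to_pixel_coords(latent_coords, (8, 32, 32), causal_fix=True)
print(pixel_coords[0, 0])        # temporal starts per token: 0,0,0,0, 1,1,1,1, 9,9,9,9

fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / 25.0)     # what forward() feeds to precompute_freqs_cis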
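
With the guiding_latent path removed, conditioning strength now travels through the denoise mask: LTXV.process_timestep in comfy/model_base.py multiplies the mask by the scalar timestep and patchifies the result, so each token receives its own timestep and guide frames are denoised at a reduced (or zero) timestep. A small sketch of that broadcast with illustrative shapes; patch size is 1, so patchify reduces to flattening the frame/height/width axes.

import torch

batch, frames, height, width = 1, 4, 2, 2
timestep = torch.tensor([0.8])                           # scalar timestep for the batch item
denoise_mask = torch.ones(batch, 1, frames, height, width)
denoise_mask[:, :, 0] = 0.0                              # frame 0: guide at strength 1.0 (mask = 1 - strength)
denoise_mask[:, :, 1] = 0.3                              # frame 1: guide at strength 0.7

per_voxel = denoise_mask * timestep.view(batch, 1, 1, 1, 1)
per_token = per_voxel[:, :1].flatten(2).transpose(1, 2)  # [batch, tokens, 1], like patchify with patch size 1
print(per_token[0, :, 0])    # frame 0 tokens -> 0.0, frame 1 tokens -> 0.24, remaining tokens -> 0.8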
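
LTXVAddGuide.get_latent_index snaps the requested pixel frame to a multiple of the temporal scale factor and converts it to a latent index, counting negative indices from the end of the video. A worked example of that arithmetic, assuming the usual temporal factor of 8 and no keyframes already attached (so latent_count equals the latent length):

def to_latent_index(frame_idx, latent_count, time_scale_factor=8):
    if frame_idx < 0:
        frame_idx = max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
    frame_idx = frame_idx // time_scale_factor * time_scale_factor      # snap down to a multiple of 8
    latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
    return frame_idx, latent_idx

print(to_latent_index(0, 13))     # (0, 0)   overwrite the first latent frame
print(to_latent_index(23, 13))    # (16, 2)  pixel frame 23 is snapped down to 16
print(to_latent_index(-1, 13))    # (96, 12) last pixel frame of a 97-frame video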
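
The new latent_log_var == "constant" branch fills the appended log-variance channel with approx_ln_0 = -30, the clamp floor used by DiagonalGaussianDistribution. A one-line check of what that means in practice:

import math
print(math.exp(0.5 * -30.0))   # ~3.1e-07: the sampled latent is numerically the mean, i.e. a deterministic encoder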
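
detect_unet_config and VAE() can now read a JSON model config from safetensors metadata under the "config" key, split into "transformer" and "vae" sections (see the json.loads(metadata["config"]) calls above). A hedged sketch of how such a checkpoint could be written and read back; only the metadata layout is taken from the diff, and the config values below are placeholders, not a real LTXV configuration.

import json
import torch
from safetensors.torch import save_file, safe_open

sd = {"dummy.weight": torch.zeros(1)}                        # stand-in state dict
metadata = {"config": json.dumps({
    "transformer": {"causal_temporal_positioning": True},    # placeholder values
    "vae": {"dims": 3, "latent_channels": 128},
})}
save_file(sd, "example.safetensors", metadata=metadata)

with safe_open("example.safetensors", framework="pt", device="cpu") as f:
    print(f.metadata()["config"])                            # what load_torch_file(..., return_metadata=True) hands back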