Support LTXV 0.9.5.
Credits: Lightricks team.
commit 93fedd92fe
parent 745b13649b
@@ -7,7 +7,7 @@ from einops import rearrange
 import math
 from typing import Dict, Optional, Tuple
 
-from .symmetric_patchifier import SymmetricPatchifier
+from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
 
 
 def get_timestep_embedding(
@@ -377,12 +377,16 @@ class LTXVModel(torch.nn.Module):
 
                  positional_embedding_theta=10000.0,
                  positional_embedding_max_pos=[20, 2048, 2048],
+                 causal_temporal_positioning=False,
+                 vae_scale_factors=(8, 32, 32),
                  dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
         self.generator = None
+        self.vae_scale_factors = vae_scale_factors
         self.dtype = dtype
         self.out_channels = in_channels
         self.inner_dim = num_attention_heads * attention_head_dim
+        self.causal_temporal_positioning = causal_temporal_positioning
 
         self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
 
@@ -416,42 +420,23 @@ class LTXVModel(torch.nn.Module):
 
         self.patchifier = SymmetricPatchifier(1)
 
-    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
         patches_replace = transformer_options.get("patches_replace", {})
 
-        indices_grid = self.patchifier.get_grid(
-            orig_num_frames=x.shape[2],
-            orig_height=x.shape[3],
-            orig_width=x.shape[4],
-            batch_size=x.shape[0],
-            scale_grid=((1 / frame_rate) * 8, 32, 32),
-            device=x.device,
-        )
-
-        if guiding_latent is not None:
-            ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
-            input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
-            ts *= input_ts
-            ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
-            timestep = self.patchifier.patchify(ts)
-            input_x = x.clone()
-            x[:, :, 0] = guiding_latent[:, :, 0]
-            if guiding_latent_noise_scale > 0:
-                if self.generator is None:
-                    self.generator = torch.Generator(device=x.device).manual_seed(42)
-                elif self.generator.device != x.device:
-                    self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
-
-                noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
-                scale = guiding_latent_noise_scale * (input_ts ** 2)
-                guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
-
-                x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
-
-
         orig_shape = list(x.shape)
 
-        x = self.patchifier.patchify(x)
+        x, latent_coords = self.patchifier.patchify(x)
+        pixel_coords = latent_to_pixel_coords(
+            latent_coords=latent_coords,
+            scale_factors=self.vae_scale_factors,
+            causal_fix=self.causal_temporal_positioning,
+        )
+
+        if keyframe_idxs is not None:
+            pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
+
+        fractional_coords = pixel_coords.to(torch.float32)
+        fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
 
         x = self.patchify_proj(x)
         timestep = timestep * 1000.0
@@ -459,7 +444,7 @@ class LTXVModel(torch.nn.Module):
         if attention_mask is not None and not torch.is_floating_point(attention_mask):
             attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
 
-        pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
+        pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
 
         batch_size = x.shape[0]
         timestep, embedded_timestep = self.adaln_single(
@@ -519,8 +504,4 @@ class LTXVModel(torch.nn.Module):
             out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
         )
 
-        if guiding_latent is not None:
-            x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
-
-        # print("res", x)
         return x
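Illustration (not part of this commit): in the new forward path, token positions come from the patchifier, are scaled into pixel space by the VAE scale factors, and the temporal axis is divided by the frame rate before being passed to precompute_freqs_cis. The tensor sizes below are made up for the example.

    import torch

    # hypothetical latent grid: 2 frames, 4x4 spatial tokens, batch of 1
    latent_coords = torch.stack(torch.meshgrid(
        torch.arange(2), torch.arange(4), torch.arange(4), indexing="ij"), dim=0)
    latent_coords = latent_coords.reshape(1, 3, -1)           # [batch, 3, tokens]

    scale = torch.tensor([8, 32, 32])[None, :, None]          # default vae_scale_factors
    pixel_coords = latent_coords * scale                      # latent -> pixel space
    fractional = pixel_coords.float()
    fractional[:, 0] *= 1.0 / 25                              # temporal axis in seconds at 25 fps
    # `fractional` plays the role of fractional_coords fed to precompute_freqs_cis above.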
@@ -6,16 +6,29 @@ from einops import rearrange
 from torch import Tensor
 
 
-def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
-    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
-    dims_to_append = target_dims - x.ndim
-    if dims_to_append < 0:
-        raise ValueError(
-            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
+def latent_to_pixel_coords(
+    latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
+) -> Tensor:
+    """
+    Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
+    configuration.
+    Args:
+        latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
+            containing the latent corner coordinates of each token.
+        scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space.
+        causal_fix (bool): Whether to take into account the different temporal scale
+            of the first frame. Default = False for backwards compatibility.
+    Returns:
+        Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
+    """
+    pixel_coords = (
+        latent_coords
+        * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
     )
-    elif dims_to_append == 0:
-        return x
-    return x[(...,) + (None,) * dims_to_append]
+    if causal_fix:
+        # Fix temporal scale for first frame to 1 due to causality
+        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
+    return pixel_coords
 
 
 class Patchifier(ABC):
@@ -44,29 +57,26 @@ class Patchifier(ABC):
     def patch_size(self):
         return self._patch_size
 
-    def get_grid(
-        self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
+    def get_latent_coords(
+        self, latent_num_frames, latent_height, latent_width, batch_size, device
     ):
-        f = orig_num_frames // self._patch_size[0]
-        h = orig_height // self._patch_size[1]
-        w = orig_width // self._patch_size[2]
-        grid_h = torch.arange(h, dtype=torch.float32, device=device)
-        grid_w = torch.arange(w, dtype=torch.float32, device=device)
-        grid_f = torch.arange(f, dtype=torch.float32, device=device)
-        grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
-        grid = torch.stack(grid, dim=0)
-        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
-
-        if scale_grid is not None:
-            for i in range(3):
-                if isinstance(scale_grid[i], Tensor):
-                    scale = append_dims(scale_grid[i], grid.ndim - 1)
-                else:
-                    scale = scale_grid[i]
-                grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
-
-        grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
-        return grid
+        """
+        Return a tensor of shape [batch_size, 3, num_patches] containing the
+        top-left corner latent coordinates of each latent patch.
+        The tensor is repeated for each batch element.
+        """
+        latent_sample_coords = torch.meshgrid(
+            torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
+            torch.arange(0, latent_height, self._patch_size[1], device=device),
+            torch.arange(0, latent_width, self._patch_size[2], device=device),
+            indexing="ij",
+        )
+        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
+        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+        latent_coords = rearrange(
+            latent_coords, "b c f h w -> b c (f h w)", b=batch_size
+        )
+        return latent_coords
 
 
 class SymmetricPatchifier(Patchifier):
@@ -74,6 +84,8 @@ class SymmetricPatchifier(Patchifier):
         self,
         latents: Tensor,
     ) -> Tuple[Tensor, Tensor]:
+        b, _, f, h, w = latents.shape
+        latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
         latents = rearrange(
             latents,
             "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
@@ -81,7 +93,7 @@ class SymmetricPatchifier(Patchifier):
             p2=self._patch_size[1],
             p3=self._patch_size[2],
         )
-        return latents
+        return latents, latent_coords
 
     def unpatchify(
         self,
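A small numeric check of the causal_fix branch above (illustrative, not from the commit): with a temporal scale factor of 8, latent pixel coordinates 0, 8, 16 become 0, 1, 9, so the first (causal) frame keeps a temporal extent of 1 instead of 8.

    import torch
    t = torch.tensor([0, 8, 16])
    print((t + 1 - 8).clamp(min=0))   # tensor([0, 1, 9])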
@@ -15,6 +15,7 @@ class CausalConv3d(nn.Module):
         stride: Union[int, Tuple[int]] = 1,
         dilation: int = 1,
         groups: int = 1,
+        spatial_padding_mode: str = "zeros",
         **kwargs,
     ):
         super().__init__()
@@ -38,7 +39,7 @@ class CausalConv3d(nn.Module):
             stride=stride,
             dilation=dilation,
             padding=padding,
-            padding_mode="zeros",
+            padding_mode=spatial_padding_mode,
             groups=groups,
         )
 
@@ -1,13 +1,15 @@
+from __future__ import annotations
 import torch
 from torch import nn
 from functools import partial
 import math
 from einops import rearrange
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 from .conv_nd_factory import make_conv_nd, make_linear_nd
 from .pixel_norm import PixelNorm
 from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
 import comfy.ops
+
 ops = comfy.ops.disable_weight_init
 
 class Encoder(nn.Module):
@@ -32,7 +34,7 @@ class Encoder(nn.Module):
         norm_layer (`str`, *optional*, defaults to `group_norm`):
             The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
         latent_log_var (`str`, *optional*, defaults to `per_channel`):
-            The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
+            The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
     """
 
     def __init__(
@@ -40,12 +42,13 @@ class Encoder(nn.Module):
         dims: Union[int, Tuple[int, int]] = 3,
         in_channels: int = 3,
         out_channels: int = 3,
-        blocks=[("res_x", 1)],
+        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
         base_channels: int = 128,
         norm_num_groups: int = 32,
         patch_size: Union[int, Tuple[int]] = 1,
         norm_layer: str = "group_norm",  # group_norm, pixel_norm
         latent_log_var: str = "per_channel",
+        spatial_padding_mode: str = "zeros",
     ):
         super().__init__()
         self.patch_size = patch_size
@@ -65,6 +68,7 @@ class Encoder(nn.Module):
             stride=1,
             padding=1,
             causal=True,
+            spatial_padding_mode=spatial_padding_mode,
         )
 
         self.down_blocks = nn.ModuleList([])
@@ -82,6 +86,7 @@ class Encoder(nn.Module):
                     resnet_eps=1e-6,
                     resnet_groups=norm_num_groups,
                     norm_layer=norm_layer,
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "res_x_y":
                 output_channel = block_params.get("multiplier", 2) * output_channel
@@ -92,6 +97,7 @@ class Encoder(nn.Module):
                     eps=1e-6,
                     groups=norm_num_groups,
                     norm_layer=norm_layer,
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "compress_time":
                 block = make_conv_nd(
@@ -101,6 +107,7 @@ class Encoder(nn.Module):
                    kernel_size=3,
                    stride=(2, 1, 1),
                    causal=True,
+                   spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space":
                block = make_conv_nd(
@@ -110,6 +117,7 @@ class Encoder(nn.Module):
                    kernel_size=3,
                    stride=(1, 2, 2),
                    causal=True,
+                   spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all":
                block = make_conv_nd(
@@ -119,6 +127,7 @@ class Encoder(nn.Module):
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
+                   spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
@@ -129,6 +138,34 @@ class Encoder(nn.Module):
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
+                   spatial_padding_mode=spatial_padding_mode,
+               )
+           elif block_name == "compress_all_res":
+               output_channel = block_params.get("multiplier", 2) * output_channel
+               block = SpaceToDepthDownsample(
+                   dims=dims,
+                   in_channels=input_channel,
+                   out_channels=output_channel,
+                   stride=(2, 2, 2),
+                   spatial_padding_mode=spatial_padding_mode,
+               )
+           elif block_name == "compress_space_res":
+               output_channel = block_params.get("multiplier", 2) * output_channel
+               block = SpaceToDepthDownsample(
+                   dims=dims,
+                   in_channels=input_channel,
+                   out_channels=output_channel,
+                   stride=(1, 2, 2),
+                   spatial_padding_mode=spatial_padding_mode,
+               )
+           elif block_name == "compress_time_res":
+               output_channel = block_params.get("multiplier", 2) * output_channel
+               block = SpaceToDepthDownsample(
+                   dims=dims,
+                   in_channels=input_channel,
+                   out_channels=output_channel,
+                   stride=(2, 1, 1),
+                   spatial_padding_mode=spatial_padding_mode,
                )
            else:
                raise ValueError(f"unknown block: {block_name}")
@@ -152,10 +189,18 @@ class Encoder(nn.Module):
             conv_out_channels *= 2
         elif latent_log_var == "uniform":
             conv_out_channels += 1
+        elif latent_log_var == "constant":
+            conv_out_channels += 1
         elif latent_log_var != "none":
             raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
         self.conv_out = make_conv_nd(
-            dims, output_channel, conv_out_channels, 3, padding=1, causal=True
+            dims,
+            output_channel,
+            conv_out_channels,
+            3,
+            padding=1,
+            causal=True,
+            spatial_padding_mode=spatial_padding_mode,
         )
 
         self.gradient_checkpointing = False
@@ -197,6 +242,15 @@ class Encoder(nn.Module):
                 sample = torch.cat([sample, repeated_last_channel], dim=1)
             else:
                 raise ValueError(f"Invalid input shape: {sample.shape}")
+        elif self.latent_log_var == "constant":
+            sample = sample[:, :-1, ...]
+            approx_ln_0 = (
+                -30
+            )  # this is the minimal clamp value in DiagonalGaussianDistribution objects
+            sample = torch.cat(
+                [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
+                dim=1,
+            )
 
         return sample
 
@@ -231,7 +285,7 @@ class Decoder(nn.Module):
         dims,
         in_channels: int = 3,
         out_channels: int = 3,
-        blocks=[("res_x", 1)],
+        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
         base_channels: int = 128,
         layers_per_block: int = 2,
         norm_num_groups: int = 32,
@@ -239,6 +293,7 @@ class Decoder(nn.Module):
         norm_layer: str = "group_norm",
         causal: bool = True,
         timestep_conditioning: bool = False,
+        spatial_padding_mode: str = "zeros",
     ):
         super().__init__()
         self.patch_size = patch_size
@@ -264,6 +319,7 @@ class Decoder(nn.Module):
             stride=1,
             padding=1,
             causal=True,
+            spatial_padding_mode=spatial_padding_mode,
         )
 
         self.up_blocks = nn.ModuleList([])
@@ -283,6 +339,7 @@ class Decoder(nn.Module):
                     norm_layer=norm_layer,
                     inject_noise=block_params.get("inject_noise", False),
                     timestep_conditioning=timestep_conditioning,
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "attn_res_x":
                 block = UNetMidBlock3D(
@@ -294,6 +351,7 @@ class Decoder(nn.Module):
                     inject_noise=block_params.get("inject_noise", False),
                     timestep_conditioning=timestep_conditioning,
                     attention_head_dim=block_params["attention_head_dim"],
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "res_x_y":
                 output_channel = output_channel // block_params.get("multiplier", 2)
@@ -306,14 +364,21 @@ class Decoder(nn.Module):
                     norm_layer=norm_layer,
                     inject_noise=block_params.get("inject_noise", False),
                     timestep_conditioning=False,
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "compress_time":
                 block = DepthToSpaceUpsample(
-                    dims=dims, in_channels=input_channel, stride=(2, 1, 1)
+                    dims=dims,
+                    in_channels=input_channel,
+                    stride=(2, 1, 1),
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "compress_space":
                 block = DepthToSpaceUpsample(
-                    dims=dims, in_channels=input_channel, stride=(1, 2, 2)
+                    dims=dims,
+                    in_channels=input_channel,
+                    stride=(1, 2, 2),
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "compress_all":
                 output_channel = output_channel // block_params.get("multiplier", 1)
@@ -323,6 +388,7 @@ class Decoder(nn.Module):
                     stride=(2, 2, 2),
                     residual=block_params.get("residual", False),
                     out_channels_reduction_factor=block_params.get("multiplier", 1),
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             else:
                 raise ValueError(f"unknown layer: {block_name}")
@@ -340,7 +406,13 @@ class Decoder(nn.Module):
 
         self.conv_act = nn.SiLU()
         self.conv_out = make_conv_nd(
-            dims, output_channel, out_channels, 3, padding=1, causal=True
+            dims,
+            output_channel,
+            out_channels,
+            3,
+            padding=1,
+            causal=True,
+            spatial_padding_mode=spatial_padding_mode,
         )
 
         self.gradient_checkpointing = False
@@ -433,6 +505,12 @@ class UNetMidBlock3D(nn.Module):
         resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
         resnet_groups (`int`, *optional*, defaults to 32):
             The number of groups to use in the group normalization layers of the resnet blocks.
+        norm_layer (`str`, *optional*, defaults to `group_norm`):
+            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
+        inject_noise (`bool`, *optional*, defaults to `False`):
+            Whether to inject noise into the hidden states.
+        timestep_conditioning (`bool`, *optional*, defaults to `False`):
+            Whether to condition the hidden states on the timestep.
 
     Returns:
         `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
@@ -451,6 +529,7 @@ class UNetMidBlock3D(nn.Module):
         norm_layer: str = "group_norm",
         inject_noise: bool = False,
         timestep_conditioning: bool = False,
+        spatial_padding_mode: str = "zeros",
     ):
         super().__init__()
         resnet_groups = (
@@ -476,13 +555,17 @@ class UNetMidBlock3D(nn.Module):
                     norm_layer=norm_layer,
                     inject_noise=inject_noise,
                     timestep_conditioning=timestep_conditioning,
+                    spatial_padding_mode=spatial_padding_mode,
                 )
                 for _ in range(num_layers)
             ]
         )
 
     def forward(
-        self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
+        self,
+        hidden_states: torch.FloatTensor,
+        causal: bool = True,
+        timestep: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         timestep_embed = None
         if self.timestep_conditioning:
@@ -507,9 +590,62 @@ class UNetMidBlock3D(nn.Module):
         return hidden_states
 
 
+class SpaceToDepthDownsample(nn.Module):
+    def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
+        super().__init__()
+        self.stride = stride
+        self.group_size = in_channels * math.prod(stride) // out_channels
+        self.conv = make_conv_nd(
+            dims=dims,
+            in_channels=in_channels,
+            out_channels=out_channels // math.prod(stride),
+            kernel_size=3,
+            stride=1,
+            causal=True,
+            spatial_padding_mode=spatial_padding_mode,
+        )
+
+    def forward(self, x, causal: bool = True):
+        if self.stride[0] == 2:
+            x = torch.cat(
+                [x[:, :, :1, :, :], x], dim=2
+            )  # duplicate first frames for padding
+
+        # skip connection
+        x_in = rearrange(
+            x,
+            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
+            p1=self.stride[0],
+            p2=self.stride[1],
+            p3=self.stride[2],
+        )
+        x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
+        x_in = x_in.mean(dim=2)
+
+        # conv
+        x = self.conv(x, causal=causal)
+        x = rearrange(
+            x,
+            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
+            p1=self.stride[0],
+            p2=self.stride[1],
+            p3=self.stride[2],
+        )
+
+        x = x + x_in
+
+        return x
+
+
 class DepthToSpaceUpsample(nn.Module):
     def __init__(
-        self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
+        self,
+        dims,
+        in_channels,
+        stride,
+        residual=False,
+        out_channels_reduction_factor=1,
+        spatial_padding_mode="zeros",
     ):
         super().__init__()
         self.stride = stride
@@ -523,6 +659,7 @@ class DepthToSpaceUpsample(nn.Module):
             kernel_size=3,
             stride=1,
             causal=True,
+            spatial_padding_mode=spatial_padding_mode,
         )
         self.residual = residual
         self.out_channels_reduction_factor = out_channels_reduction_factor
@@ -591,6 +728,7 @@ class ResnetBlock3D(nn.Module):
         norm_layer: str = "group_norm",
         inject_noise: bool = False,
         timestep_conditioning: bool = False,
+        spatial_padding_mode: str = "zeros",
     ):
         super().__init__()
         self.in_channels = in_channels
@@ -617,6 +755,7 @@ class ResnetBlock3D(nn.Module):
             stride=1,
             padding=1,
             causal=True,
+            spatial_padding_mode=spatial_padding_mode,
         )
 
         if inject_noise:
@@ -641,6 +780,7 @@ class ResnetBlock3D(nn.Module):
             stride=1,
             padding=1,
             causal=True,
+            spatial_padding_mode=spatial_padding_mode,
        )
 
        if inject_noise:
@@ -801,9 +941,44 @@ class processor(nn.Module):
         return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
 
 class VideoVAE(nn.Module):
-    def __init__(self, version=0):
+    def __init__(self, version=0, config=None):
         super().__init__()
 
+        if config is None:
+            config = self.guess_config(version)
+
+        self.timestep_conditioning = config.get("timestep_conditioning", False)
+        double_z = config.get("double_z", True)
+        latent_log_var = config.get(
+            "latent_log_var", "per_channel" if double_z else "none"
+        )
+
+        self.encoder = Encoder(
+            dims=config["dims"],
+            in_channels=config.get("in_channels", 3),
+            out_channels=config["latent_channels"],
+            blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
+            patch_size=config.get("patch_size", 1),
+            latent_log_var=latent_log_var,
+            norm_layer=config.get("norm_layer", "group_norm"),
+            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
+        )
+
+        self.decoder = Decoder(
+            dims=config["dims"],
+            in_channels=config["latent_channels"],
+            out_channels=config.get("out_channels", 3),
+            blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
+            patch_size=config.get("patch_size", 1),
+            norm_layer=config.get("norm_layer", "group_norm"),
+            causal=config.get("causal_decoder", False),
+            timestep_conditioning=self.timestep_conditioning,
+            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
+        )
+
+        self.per_channel_statistics = processor()
+
+    def guess_config(self, version):
         if version == 0:
             config = {
                 "_class_name": "CausalVideoAutoencoder",
@@ -830,7 +1005,7 @@ class VideoVAE(nn.Module):
                 "use_quant_conv": False,
                 "causal_decoder": False,
             }
-        else:
+        elif version == 1:
             config = {
                 "_class_name": "CausalVideoAutoencoder",
                 "dims": 3,
@@ -866,37 +1041,47 @@ class VideoVAE(nn.Module):
                 "causal_decoder": False,
                 "timestep_conditioning": True,
             }
-
-        double_z = config.get("double_z", True)
-        latent_log_var = config.get(
-            "latent_log_var", "per_channel" if double_z else "none"
-        )
-
-        self.encoder = Encoder(
-            dims=config["dims"],
-            in_channels=config.get("in_channels", 3),
-            out_channels=config["latent_channels"],
-            blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
-            patch_size=config.get("patch_size", 1),
-            latent_log_var=latent_log_var,
-            norm_layer=config.get("norm_layer", "group_norm"),
-        )
-
-        self.decoder = Decoder(
-            dims=config["dims"],
-            in_channels=config["latent_channels"],
-            out_channels=config.get("out_channels", 3),
-            blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
-            patch_size=config.get("patch_size", 1),
-            norm_layer=config.get("norm_layer", "group_norm"),
-            causal=config.get("causal_decoder", False),
-            timestep_conditioning=config.get("timestep_conditioning", False),
-        )
-
-        self.timestep_conditioning = config.get("timestep_conditioning", False)
-        self.per_channel_statistics = processor()
+        else:
+            config = {
+                "_class_name": "CausalVideoAutoencoder",
+                "dims": 3,
+                "in_channels": 3,
+                "out_channels": 3,
+                "latent_channels": 128,
+                "encoder_blocks": [
+                    ["res_x", {"num_layers": 4}],
+                    ["compress_space_res", {"multiplier": 2}],
+                    ["res_x", {"num_layers": 6}],
+                    ["compress_time_res", {"multiplier": 2}],
+                    ["res_x", {"num_layers": 6}],
+                    ["compress_all_res", {"multiplier": 2}],
+                    ["res_x", {"num_layers": 2}],
+                    ["compress_all_res", {"multiplier": 2}],
+                    ["res_x", {"num_layers": 2}]
+                ],
+                "decoder_blocks": [
+                    ["res_x", {"num_layers": 5, "inject_noise": False}],
+                    ["compress_all", {"residual": True, "multiplier": 2}],
+                    ["res_x", {"num_layers": 5, "inject_noise": False}],
+                    ["compress_all", {"residual": True, "multiplier": 2}],
+                    ["res_x", {"num_layers": 5, "inject_noise": False}],
+                    ["compress_all", {"residual": True, "multiplier": 2}],
+                    ["res_x", {"num_layers": 5, "inject_noise": False}]
+                ],
+                "scaling_factor": 1.0,
+                "norm_layer": "pixel_norm",
+                "patch_size": 4,
+                "latent_log_var": "uniform",
+                "use_quant_conv": False,
+                "causal_decoder": False,
+                "timestep_conditioning": True
+            }
+        return config
 
     def encode(self, x):
+        frames_count = x.shape[2]
+        if ((frames_count - 1) % 8) != 0:
+            raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
         means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
         return self.per_channel_statistics.normalize(means)
 
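The new check in VideoVAE.encode enforces the temporal compression contract: inputs must have 1 + 8*k frames. A quick sanity check before calling encode (illustrative helper, not part of the commit):

    def is_valid_ltxv_frame_count(frames: int) -> bool:
        # Valid counts are 1, 9, 17, ... i.e. 1 + 8 * k, matching VideoVAE.encode above.
        return frames >= 1 and (frames - 1) % 8 == 0

    assert is_valid_ltxv_frame_count(97)       # the default length used by the LTXV nodes
    assert not is_valid_ltxv_frame_count(96)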
@@ -17,7 +17,11 @@ def make_conv_nd(
     groups=1,
     bias=True,
     causal=False,
+    spatial_padding_mode="zeros",
+    temporal_padding_mode="zeros",
 ):
+    if not (spatial_padding_mode == temporal_padding_mode or causal):
+        raise NotImplementedError("spatial and temporal padding modes must be equal")
     if dims == 2:
         return ops.Conv2d(
             in_channels=in_channels,
@@ -28,6 +32,7 @@ def make_conv_nd(
             dilation=dilation,
             groups=groups,
             bias=bias,
+            padding_mode=spatial_padding_mode,
         )
     elif dims == 3:
         if causal:
@@ -40,6 +45,7 @@ def make_conv_nd(
                 dilation=dilation,
                 groups=groups,
                 bias=bias,
+                spatial_padding_mode=spatial_padding_mode,
             )
         return ops.Conv3d(
             in_channels=in_channels,
@@ -50,6 +56,7 @@ def make_conv_nd(
             dilation=dilation,
             groups=groups,
             bias=bias,
+            padding_mode=spatial_padding_mode,
         )
     elif dims == (2, 1):
         return DualConv3d(
@@ -59,6 +66,7 @@ def make_conv_nd(
             stride=stride,
             padding=padding,
             bias=bias,
+            padding_mode=spatial_padding_mode,
         )
     else:
         raise ValueError(f"unsupported dimensions: {dims}")
@@ -18,11 +18,13 @@ class DualConv3d(nn.Module):
         dilation: Union[int, Tuple[int, int, int]] = 1,
         groups=1,
         bias=True,
+        padding_mode="zeros",
     ):
         super(DualConv3d, self).__init__()
 
         self.in_channels = in_channels
         self.out_channels = out_channels
+        self.padding_mode = padding_mode
         # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
         if isinstance(kernel_size, int):
             kernel_size = (kernel_size, kernel_size, kernel_size)
@@ -108,6 +110,7 @@ class DualConv3d(nn.Module):
                 self.padding1,
                 self.dilation1,
                 self.groups,
+                padding_mode=self.padding_mode,
             )
 
         if skip_time_conv:
@@ -122,6 +125,7 @@ class DualConv3d(nn.Module):
                 self.padding2,
                 self.dilation2,
                 self.groups,
+                padding_mode=self.padding_mode,
             )
 
         return x
@@ -137,7 +141,16 @@ class DualConv3d(nn.Module):
         stride1 = (self.stride1[1], self.stride1[2])
         padding1 = (self.padding1[1], self.padding1[2])
         dilation1 = (self.dilation1[1], self.dilation1[2])
-        x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
+        x = F.conv2d(
+            x,
+            weight1,
+            self.bias1,
+            stride1,
+            padding1,
+            dilation1,
+            self.groups,
+            padding_mode=self.padding_mode,
+        )
 
         _, _, h, w = x.shape
 
@@ -154,7 +167,16 @@ class DualConv3d(nn.Module):
         stride2 = self.stride2[0]
         padding2 = self.padding2[0]
         dilation2 = self.dilation2[0]
-        x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
+        x = F.conv1d(
+            x,
+            weight2,
+            self.bias2,
+            stride2,
+            padding2,
+            dilation2,
+            self.groups,
+            padding_mode=self.padding_mode,
+        )
         x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
 
         return x
@@ -161,9 +161,13 @@ class BaseModel(torch.nn.Module):
                 extra = extra.to(dtype)
             extra_conds[o] = extra
 
+        t = self.process_timestep(t, x=x, **extra_conds)
         model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
         return self.model_sampling.calculate_denoised(sigma, model_output, x)
 
+    def process_timestep(self, timestep, **kwargs):
+        return timestep
+
     def get_dtype(self):
         return self.diffusion_model.dtype
 
@@ -855,17 +859,26 @@ class LTXV(BaseModel):
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
 
-        guiding_latent = kwargs.get("guiding_latent", None)
-        if guiding_latent is not None:
-            out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
-
-        guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
-        if guiding_latent_noise_scale is not None:
-            out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale)
-
         out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
+
+        denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+        if denoise_mask is not None:
+            out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
+
+        keyframe_idxs = kwargs.get("keyframe_idxs", None)
+        if keyframe_idxs is not None:
+            out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
+
         return out
 
+    def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
+        if denoise_mask is None:
+            return timestep
+        return self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
+
+    def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
+        return latent_image
+
 class HunyuanVideo(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
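Sketch of the idea behind the new process_timestep hook (illustrative only; shapes are hypothetical and the patchifier step is omitted): broadcasting the scalar timestep over the denoise mask yields timestep 0 for frozen tokens such as guide frames, and the full timestep everywhere else.

    import torch

    sigma = torch.tensor([0.7])                    # one sampling step, batch of 1
    denoise_mask = torch.ones(1, 1, 4, 8, 8)       # 1.0 = denoise, 0.0 = keep
    denoise_mask[:, :, 0] = 0.0                    # first latent frame is a clean keyframe

    per_token_t = denoise_mask * sigma.view(-1, 1, 1, 1, 1)
    # LTXV.process_timestep patchifies a tensor like per_token_t into per-token timesteps.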
@@ -1,3 +1,4 @@
+import json
 import comfy.supported_models
 import comfy.supported_models_base
 import comfy.utils
@@ -33,7 +34,7 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
             return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
     return None
 
-def detect_unet_config(state_dict, key_prefix):
+def detect_unet_config(state_dict, key_prefix, metadata=None):
     state_dict_keys = list(state_dict.keys())
 
     if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
@@ -210,6 +211,8 @@ def detect_unet_config(state_dict, key_prefix):
     if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
         dit_config = {}
         dit_config["image_model"] = "ltxv"
+        if metadata is not None and "config" in metadata:
+            dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
         return dit_config
 
     if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
@@ -454,8 +457,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
     logging.error("no match {}".format(unet_config))
     return None
 
-def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
-    unet_config = detect_unet_config(state_dict, unet_key_prefix)
+def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
+    unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
     if unet_config is None:
         return None
     model_config = model_config_from_unet_config(unet_config, state_dict)
comfy/sd.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import json
 import torch
 from enum import Enum
 import logging
@@ -249,7 +250,7 @@ class CLIP:
         return self.patcher.get_key_patches()
 
 class VAE:
-    def __init__(self, sd=None, device=None, config=None, dtype=None):
+    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
             sd = diffusers_convert.convert_vae_state_dict(sd)
 
@@ -357,7 +358,12 @@ class VAE:
                 version = 0
             elif tensor_conv1.shape[0] == 1024:
                 version = 1
-            self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version)
+                if "encoder.down_blocks.1.conv.conv.bias" in sd:
+                    version = 2
+            vae_config = None
+            if metadata is not None and "config" in metadata:
+                vae_config = json.loads(metadata["config"]).get("vae", None)
+            self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
             self.latent_channels = 128
             self.latent_dim = 3
             self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
@@ -873,13 +879,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     return (model, clip, vae)
 
 def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
-    sd = comfy.utils.load_torch_file(ckpt_path)
-    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
+    sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
+    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
     if out is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
     return out
 
-def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
     clip = None
     clipvision = None
     vae = None
@@ -891,7 +897,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
     load_device = model_management.get_torch_device()
 
-    model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
+    model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
     if model_config is None:
         return None
 
@@ -920,7 +926,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     if output_vae:
         vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
         vae_sd = model_config.process_vae_state_dict(vae_sd)
-        vae = VAE(sd=vae_sd)
+        vae = VAE(sd=vae_sd, metadata=metadata)
 
     if output_clip:
         clip_target = model_config.clip_target(state_dict=sd)
@@ -46,12 +46,18 @@ if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in
 else:
     logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
 
-def load_torch_file(ckpt, safe_load=False, device=None):
+def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
     if device is None:
         device = torch.device("cpu")
+    metadata = None
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         try:
-            sd = safetensors.torch.load_file(ckpt, device=device.type)
+            with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
+                sd = {}
+                for k in f.keys():
+                    sd[k] = f.get_tensor(k)
+                if return_metadata:
+                    metadata = f.metadata()
         except Exception as e:
             if len(e.args) > 0:
                 message = e.args[0]
@@ -77,7 +83,7 @@ def load_torch_file(ckpt, safe_load=False, device=None):
             sd = pl_sd
         else:
             sd = pl_sd
-    return sd
+    return (sd, metadata) if return_metadata else sd
 
 def save_torch_file(sd, ckpt, metadata=None):
     if metadata is not None:
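With the new return_metadata flag, the loader surfaces the safetensors header, and the LTXV detection code reads a JSON "config" key from it. A hedged usage sketch (the checkpoint path is a placeholder):

    import json
    import comfy.utils

    sd, metadata = comfy.utils.load_torch_file("ltxv-0.9.5.safetensors", return_metadata=True)
    if metadata is not None and "config" in metadata:
        config = json.loads(metadata["config"])
        # config.get("vae") / config.get("transformer") are what VAE() and
        # detect_unet_config() consume in the hunks above.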
@ -1,9 +1,14 @@
|
|||||||
|
import io
|
||||||
import nodes
|
import nodes
|
||||||
import node_helpers
|
import node_helpers
|
||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_sampling
|
import comfy.model_sampling
|
||||||
|
import comfy.utils
|
||||||
import math
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import av
|
||||||
|
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||||
|
|
||||||
class EmptyLTXVLatentVideo:
|
class EmptyLTXVLatentVideo:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -33,7 +38,6 @@ class LTXVImgToVideo:
|
|||||||
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||||
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||||
"image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 1.0, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||||
@@ -42,16 +46,220 @@ class LTXVImgToVideo:
     CATEGORY = "conditioning/video_models"
     FUNCTION = "generate"

-    def generate(self, positive, negative, image, vae, width, height, length, batch_size, image_noise_scale):
+    def generate(self, positive, negative, image, vae, width, height, length, batch_size):
         pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
         encode_pixels = pixels[:, :, :, :3]
         t = vae.encode(encode_pixels)
-        positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
-        negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})

         latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
         latent[:, :, :t.shape[2]] = t
-        return (positive, negative, {"samples": latent}, )
+
+        conditioning_latent_frames_mask = torch.ones(
+            (batch_size, 1, latent.shape[2], 1, 1),
+            dtype=torch.float32,
+            device=latent.device,
+        )
+        conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0
+
+        return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, )
+
+
+def conditioning_get_any_value(conditioning, key, default=None):
+    for t in conditioning:
+        if key in t[1]:
+            return t[1][key]
+    return default
+
+
+def get_noise_mask(latent):
+    noise_mask = latent.get("noise_mask", None)
+    latent_image = latent["samples"]
+    if noise_mask is None:
+        batch_size, _, latent_length, _, _ = latent_image.shape
+        noise_mask = torch.ones(
+            (batch_size, 1, latent_length, 1, 1),
+            dtype=torch.float32,
+            device=latent_image.device,
+        )
+    else:
+        noise_mask = noise_mask.clone()
+    return noise_mask
+
+
+def get_keyframe_idxs(cond):
+    keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
+    if keyframe_idxs is None:
+        return None, 0
+    num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
+    return keyframe_idxs, num_keyframes
+
+
+class LTXVAddGuide:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE",),
+                             "latent": ("LATENT",),
+                             "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \
+                                 "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
+                             "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
+                                                   "tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \
+                                                       "If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \
+                                                       "Negative values are counted from the end of the video."}),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                             }
+                }
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+
+    CATEGORY = "conditioning/video_models"
+    FUNCTION = "generate"
+
+    def __init__(self):
+        self._num_prefix_frames = 2
+        self._patchifier = SymmetricPatchifier(1)
+
+    def encode(self, vae, latent_width, latent_height, images, scale_factors):
+        time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
+        images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
+        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
+        encode_pixels = pixels[:, :, :, :3]
+        t = vae.encode(encode_pixels)
+        return encode_pixels, t
+
+    def get_latent_index(self, cond, latent_length, frame_idx, scale_factors):
+        time_scale_factor, _, _ = scale_factors
+        _, num_keyframes = get_keyframe_idxs(cond)
+        latent_count = latent_length - num_keyframes
+        frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0)
+        frame_idx = frame_idx // time_scale_factor * time_scale_factor  # frame index must be divisible by 8
+
+        latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
+
+        return frame_idx, latent_idx
+
+    def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
+        keyframe_idxs, _ = get_keyframe_idxs(cond)
+        _, latent_coords = self._patchifier.patchify(guiding_latent)
+        pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True)
+        pixel_coords[:, 0] += frame_idx
+        if keyframe_idxs is None:
+            keyframe_idxs = pixel_coords
+        else:
+            keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
+        return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
+
+    def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
+        positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
+        negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
+
+        mask = torch.full(
+            (noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
+            1.0 - strength,
+            dtype=noise_mask.dtype,
+            device=noise_mask.device,
+        )
+
+        latent_image = torch.cat([latent_image, guiding_latent], dim=2)
+        noise_mask = torch.cat([noise_mask, mask], dim=2)
+        return positive, negative, latent_image, noise_mask
+
+    def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength):
+        cond_length = guiding_latent.shape[2]
+        assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."
+
+        mask = torch.full(
+            (noise_mask.shape[0], 1, cond_length, 1, 1),
+            1.0 - strength,
+            dtype=noise_mask.dtype,
+            device=noise_mask.device,
+        )
+
+        latent_image = latent_image.clone()
+        noise_mask = noise_mask.clone()
+
+        latent_image[:, :, latent_idx : latent_idx + cond_length] = guiding_latent
+        noise_mask[:, :, latent_idx : latent_idx + cond_length] = mask
+
+        return latent_image, noise_mask
+
+    def generate(self, positive, negative, vae, latent, image, frame_idx, strength):
+        scale_factors = vae.downscale_index_formula
+        latent_image = latent["samples"]
+        noise_mask = get_noise_mask(latent)
+
+        _, _, latent_length, latent_height, latent_width = latent_image.shape
+        image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
+
+        frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors)
+        assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
+
+        if frame_idx == 0:
+            latent_image, noise_mask = self.replace_latent_frames(latent_image, noise_mask, t, latent_idx, strength)
+            return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
+
+        num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
+
+        positive, negative, latent_image, noise_mask = self.append_keyframe(
+            positive,
+            negative,
+            frame_idx,
+            latent_image,
+            noise_mask,
+            t[:, :, :num_prefix_frames],
+            strength,
+            scale_factors,
+        )
+
+        latent_idx += num_prefix_frames
+
+        t = t[:, :, num_prefix_frames:]
+        if t.shape[2] == 0:
+            return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
+
+        latent_image, noise_mask = self.replace_latent_frames(
+            latent_image,
+            noise_mask,
+            t,
+            latent_idx,
+            strength,
+        )
+
+        return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
+
+
+class LTXVCropGuides:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "latent": ("LATENT",),
+                             }
+                }
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+
+    CATEGORY = "conditioning/video_models"
+    FUNCTION = "crop"
+
+    def __init__(self):
+        self._patchifier = SymmetricPatchifier(1)
+
+    def crop(self, positive, negative, latent):
+        latent_image = latent["samples"].clone()
+        noise_mask = get_noise_mask(latent)
+
+        _, num_keyframes = get_keyframe_idxs(positive)
+
+        latent_image = latent_image[:, :, :-num_keyframes]
+        noise_mask = noise_mask[:, :, :-num_keyframes]
+
+        positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
+        negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
+
+        return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
+
+
 class LTXVConditioning:
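A worked sketch of the frame-to-latent index math in LTXVAddGuide.get_latent_index, assuming the LTXV temporal scale factor of 8 (the first entry of vae.downscale_index_formula); the helper name below is illustrative only:

    def to_latent_index(frame_idx, latent_count, time_scale_factor=8):
        # negative frame indices count back from the end of the pixel-frame sequence
        if frame_idx < 0:
            frame_idx = max((latent_count - 1) * 8 + 1 + frame_idx, 0)
        # round down to a multiple of the temporal scale factor
        frame_idx = frame_idx // time_scale_factor * time_scale_factor
        latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
        return frame_idx, latent_idx

    print(to_latent_index(0, 13))    # (0, 0): guide replaces the first latent frame
    print(to_latent_index(20, 13))   # (16, 2): frame 20 rounds down to 16, latent index 2
    print(to_latent_index(-1, 13))   # (96, 12): last frame of a 97-frame (13-latent) video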
@@ -174,6 +382,78 @@ class LTXVScheduler:

         return (sigmas,)

+
+def encode_single_frame(output_file, image_array: np.ndarray, crf):
+    container = av.open(output_file, "w", format="mp4")
+    try:
+        stream = container.add_stream(
+            "h264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
+        )
+        stream.height = image_array.shape[0]
+        stream.width = image_array.shape[1]
+        av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
+            format="yuv420p"
+        )
+        container.mux(stream.encode(av_frame))
+        container.mux(stream.encode())
+    finally:
+        container.close()
+
+
+def decode_single_frame(video_file):
+    container = av.open(video_file)
+    try:
+        stream = next(s for s in container.streams if s.type == "video")
+        frame = next(container.decode(stream))
+    finally:
+        container.close()
+    return frame.to_ndarray(format="rgb24")
+
+
+def preprocess(image: torch.Tensor, crf=29):
+    if crf == 0:
+        return image
+
+    image_array = (image * 255.0).byte().cpu().numpy()
+    with io.BytesIO() as output_file:
+        encode_single_frame(output_file, image_array, crf)
+        video_bytes = output_file.getvalue()
+    with io.BytesIO(video_bytes) as video_file:
+        image_array = decode_single_frame(video_file)
+    tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
+    return tensor
+
+
+class LTXVPreprocess:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "img_compression": (
+                    "INT",
+                    {
+                        "default": 35,
+                        "min": 0,
+                        "max": 100,
+                        "tooltip": "Amount of compression to apply on image.",
+                    },
+                ),
+            }
+        }
+
+    FUNCTION = "preprocess"
+    RETURN_TYPES = ("IMAGE",)
+    RETURN_NAMES = ("output_image",)
+    CATEGORY = "image"
+
+    def preprocess(self, image, img_compression):
+        output_image = image
+        if img_compression > 0:
+            output_image = torch.zeros_like(image)
+            for i in range(image.shape[0]):
+                output_image[i] = preprocess(image[i], img_compression)
+        return (output_image,)
+
+
 NODE_CLASS_MAPPINGS = {
     "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
@@ -181,4 +461,7 @@ NODE_CLASS_MAPPINGS = {
     "ModelSamplingLTXV": ModelSamplingLTXV,
     "LTXVConditioning": LTXVConditioning,
     "LTXVScheduler": LTXVScheduler,
+    "LTXVAddGuide": LTXVAddGuide,
+    "LTXVPreprocess": LTXVPreprocess,
+    "LTXVCropGuides": LTXVCropGuides,
 }
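A minimal sketch of calling the single-frame H.264 round-trip on its own, assuming the helpers above live in the LTXV nodes module (the module path is an assumption) and that frame dimensions are even, as yuv420p requires:

    import torch
    from comfy_extras.nodes_lt import preprocess

    frame = torch.rand(512, 768, 3)        # H x W x RGB image in [0, 1]
    degraded = preprocess(frame, crf=35)   # encode one frame to H.264 at CRF 35, decode it back
    print(degraded.shape, degraded.min().item(), degraded.max().item())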