mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Support new LTXV VAE.
This commit is contained in:
parent
cac68ca813
commit
418eb7062d
@ -6,7 +6,9 @@ from einops import rearrange
|
|||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||||
from .pixel_norm import PixelNorm
|
from .pixel_norm import PixelNorm
|
||||||
|
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
class Encoder(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
@ -236,6 +238,7 @@ class Decoder(nn.Module):
|
|||||||
patch_size: int = 1,
|
patch_size: int = 1,
|
||||||
norm_layer: str = "group_norm",
|
norm_layer: str = "group_norm",
|
||||||
causal: bool = True,
|
causal: bool = True,
|
||||||
|
timestep_conditioning: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
@ -250,6 +253,8 @@ class Decoder(nn.Module):
|
|||||||
block_params = block_params if isinstance(block_params, dict) else {}
|
block_params = block_params if isinstance(block_params, dict) else {}
|
||||||
if block_name == "res_x_y":
|
if block_name == "res_x_y":
|
||||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||||
|
if block_name == "compress_all":
|
||||||
|
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||||
|
|
||||||
self.conv_in = make_conv_nd(
|
self.conv_in = make_conv_nd(
|
||||||
dims,
|
dims,
|
||||||
@ -276,6 +281,19 @@ class Decoder(nn.Module):
|
|||||||
resnet_eps=1e-6,
|
resnet_eps=1e-6,
|
||||||
resnet_groups=norm_num_groups,
|
resnet_groups=norm_num_groups,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
|
inject_noise=block_params.get("inject_noise", False),
|
||||||
|
timestep_conditioning=timestep_conditioning,
|
||||||
|
)
|
||||||
|
elif block_name == "attn_res_x":
|
||||||
|
block = UNetMidBlock3D(
|
||||||
|
dims=dims,
|
||||||
|
in_channels=input_channel,
|
||||||
|
num_layers=block_params["num_layers"],
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
inject_noise=block_params.get("inject_noise", False),
|
||||||
|
timestep_conditioning=timestep_conditioning,
|
||||||
|
attention_head_dim=block_params["attention_head_dim"],
|
||||||
)
|
)
|
||||||
elif block_name == "res_x_y":
|
elif block_name == "res_x_y":
|
||||||
output_channel = output_channel // block_params.get("multiplier", 2)
|
output_channel = output_channel // block_params.get("multiplier", 2)
|
||||||
@ -286,6 +304,8 @@ class Decoder(nn.Module):
|
|||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
groups=norm_num_groups,
|
groups=norm_num_groups,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
|
inject_noise=block_params.get("inject_noise", False),
|
||||||
|
timestep_conditioning=False,
|
||||||
)
|
)
|
||||||
elif block_name == "compress_time":
|
elif block_name == "compress_time":
|
||||||
block = DepthToSpaceUpsample(
|
block = DepthToSpaceUpsample(
|
||||||
@ -296,11 +316,13 @@ class Decoder(nn.Module):
|
|||||||
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
||||||
)
|
)
|
||||||
elif block_name == "compress_all":
|
elif block_name == "compress_all":
|
||||||
|
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||||
block = DepthToSpaceUpsample(
|
block = DepthToSpaceUpsample(
|
||||||
dims=dims,
|
dims=dims,
|
||||||
in_channels=input_channel,
|
in_channels=input_channel,
|
||||||
stride=(2, 2, 2),
|
stride=(2, 2, 2),
|
||||||
residual=block_params.get("residual", False),
|
residual=block_params.get("residual", False),
|
||||||
|
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unknown layer: {block_name}")
|
raise ValueError(f"unknown layer: {block_name}")
|
||||||
@ -323,27 +345,75 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
self.timestep_conditioning = timestep_conditioning
|
||||||
|
|
||||||
|
if timestep_conditioning:
|
||||||
|
self.timestep_scale_multiplier = nn.Parameter(
|
||||||
|
torch.tensor(1000.0, dtype=torch.float32)
|
||||||
|
)
|
||||||
|
self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||||
|
output_channel * 2, 0, operations=ops,
|
||||||
|
)
|
||||||
|
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||||
|
|
||||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
r"""The forward method of the `Decoder` class."""
|
r"""The forward method of the `Decoder` class."""
|
||||||
# assert target_shape is not None, "target_shape must be provided"
|
batch_size = sample.shape[0]
|
||||||
|
|
||||||
sample = self.conv_in(sample, causal=self.causal)
|
sample = self.conv_in(sample, causal=self.causal)
|
||||||
|
|
||||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
|
||||||
|
|
||||||
checkpoint_fn = (
|
checkpoint_fn = (
|
||||||
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
||||||
if self.gradient_checkpointing and self.training
|
if self.gradient_checkpointing and self.training
|
||||||
else lambda x: x
|
else lambda x: x
|
||||||
)
|
)
|
||||||
|
|
||||||
sample = sample.to(upscale_dtype)
|
scaled_timestep = None
|
||||||
|
if self.timestep_conditioning:
|
||||||
|
assert (
|
||||||
|
timestep is not None
|
||||||
|
), "should pass timestep with timestep_conditioning=True"
|
||||||
|
scaled_timestep = timestep * self.timestep_scale_multiplier
|
||||||
|
|
||||||
for up_block in self.up_blocks:
|
for up_block in self.up_blocks:
|
||||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||||
|
sample = checkpoint_fn(up_block)(
|
||||||
|
sample, causal=self.causal, timestep=scaled_timestep
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||||
|
|
||||||
sample = self.conv_norm_out(sample)
|
sample = self.conv_norm_out(sample)
|
||||||
|
|
||||||
|
if self.timestep_conditioning:
|
||||||
|
embedded_timestep = self.last_time_embedder(
|
||||||
|
timestep=scaled_timestep.flatten(),
|
||||||
|
resolution=None,
|
||||||
|
aspect_ratio=None,
|
||||||
|
batch_size=sample.shape[0],
|
||||||
|
hidden_dtype=sample.dtype,
|
||||||
|
)
|
||||||
|
embedded_timestep = embedded_timestep.view(
|
||||||
|
batch_size, embedded_timestep.shape[-1], 1, 1, 1
|
||||||
|
)
|
||||||
|
ada_values = self.last_scale_shift_table[
|
||||||
|
None, ..., None, None, None
|
||||||
|
] + embedded_timestep.reshape(
|
||||||
|
batch_size,
|
||||||
|
2,
|
||||||
|
-1,
|
||||||
|
embedded_timestep.shape[-3],
|
||||||
|
embedded_timestep.shape[-2],
|
||||||
|
embedded_timestep.shape[-1],
|
||||||
|
)
|
||||||
|
shift, scale = ada_values.unbind(dim=1)
|
||||||
|
sample = sample * (1 + scale) + shift
|
||||||
|
|
||||||
sample = self.conv_act(sample)
|
sample = self.conv_act(sample)
|
||||||
sample = self.conv_out(sample, causal=self.causal)
|
sample = self.conv_out(sample, causal=self.causal)
|
||||||
|
|
||||||
@ -379,12 +449,21 @@ class UNetMidBlock3D(nn.Module):
|
|||||||
resnet_eps: float = 1e-6,
|
resnet_eps: float = 1e-6,
|
||||||
resnet_groups: int = 32,
|
resnet_groups: int = 32,
|
||||||
norm_layer: str = "group_norm",
|
norm_layer: str = "group_norm",
|
||||||
|
inject_noise: bool = False,
|
||||||
|
timestep_conditioning: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
resnet_groups = (
|
resnet_groups = (
|
||||||
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.timestep_conditioning = timestep_conditioning
|
||||||
|
|
||||||
|
if timestep_conditioning:
|
||||||
|
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||||
|
in_channels * 4, 0, operations=ops,
|
||||||
|
)
|
||||||
|
|
||||||
self.res_blocks = nn.ModuleList(
|
self.res_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
ResnetBlock3D(
|
ResnetBlock3D(
|
||||||
@ -395,25 +474,48 @@ class UNetMidBlock3D(nn.Module):
|
|||||||
groups=resnet_groups,
|
groups=resnet_groups,
|
||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
|
inject_noise=inject_noise,
|
||||||
|
timestep_conditioning=timestep_conditioning,
|
||||||
)
|
)
|
||||||
for _ in range(num_layers)
|
for _ in range(num_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, hidden_states: torch.FloatTensor, causal: bool = True
|
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
|
timestep_embed = None
|
||||||
|
if self.timestep_conditioning:
|
||||||
|
assert (
|
||||||
|
timestep is not None
|
||||||
|
), "should pass timestep with timestep_conditioning=True"
|
||||||
|
batch_size = hidden_states.shape[0]
|
||||||
|
timestep_embed = self.time_embedder(
|
||||||
|
timestep=timestep.flatten(),
|
||||||
|
resolution=None,
|
||||||
|
aspect_ratio=None,
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
timestep_embed = timestep_embed.view(
|
||||||
|
batch_size, timestep_embed.shape[-1], 1, 1, 1
|
||||||
|
)
|
||||||
|
|
||||||
for resnet in self.res_blocks:
|
for resnet in self.res_blocks:
|
||||||
hidden_states = resnet(hidden_states, causal=causal)
|
hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class DepthToSpaceUpsample(nn.Module):
|
class DepthToSpaceUpsample(nn.Module):
|
||||||
def __init__(self, dims, in_channels, stride, residual=False):
|
def __init__(
|
||||||
|
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.out_channels = math.prod(stride) * in_channels
|
self.out_channels = (
|
||||||
|
math.prod(stride) * in_channels // out_channels_reduction_factor
|
||||||
|
)
|
||||||
self.conv = make_conv_nd(
|
self.conv = make_conv_nd(
|
||||||
dims=dims,
|
dims=dims,
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -423,8 +525,9 @@ class DepthToSpaceUpsample(nn.Module):
|
|||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
self.residual = residual
|
self.residual = residual
|
||||||
|
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||||
|
|
||||||
def forward(self, x, causal: bool = True):
|
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
|
||||||
if self.residual:
|
if self.residual:
|
||||||
# Reshape and duplicate the input to match the output shape
|
# Reshape and duplicate the input to match the output shape
|
||||||
x_in = rearrange(
|
x_in = rearrange(
|
||||||
@ -434,7 +537,8 @@ class DepthToSpaceUpsample(nn.Module):
|
|||||||
p2=self.stride[1],
|
p2=self.stride[1],
|
||||||
p3=self.stride[2],
|
p3=self.stride[2],
|
||||||
)
|
)
|
||||||
x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1)
|
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
|
||||||
|
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
|
||||||
if self.stride[0] == 2:
|
if self.stride[0] == 2:
|
||||||
x_in = x_in[:, :, 1:, :, :]
|
x_in = x_in[:, :, 1:, :, :]
|
||||||
x = self.conv(x, causal=causal)
|
x = self.conv(x, causal=causal)
|
||||||
@ -451,7 +555,6 @@ class DepthToSpaceUpsample(nn.Module):
|
|||||||
x = x + x_in
|
x = x + x_in
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
class LayerNorm(nn.Module):
|
||||||
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -486,11 +589,14 @@ class ResnetBlock3D(nn.Module):
|
|||||||
groups: int = 32,
|
groups: int = 32,
|
||||||
eps: float = 1e-6,
|
eps: float = 1e-6,
|
||||||
norm_layer: str = "group_norm",
|
norm_layer: str = "group_norm",
|
||||||
|
inject_noise: bool = False,
|
||||||
|
timestep_conditioning: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
out_channels = in_channels if out_channels is None else out_channels
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
self.inject_noise = inject_noise
|
||||||
|
|
||||||
if norm_layer == "group_norm":
|
if norm_layer == "group_norm":
|
||||||
self.norm1 = nn.GroupNorm(
|
self.norm1 = nn.GroupNorm(
|
||||||
@ -513,6 +619,9 @@ class ResnetBlock3D(nn.Module):
|
|||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if inject_noise:
|
||||||
|
self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
||||||
|
|
||||||
if norm_layer == "group_norm":
|
if norm_layer == "group_norm":
|
||||||
self.norm2 = nn.GroupNorm(
|
self.norm2 = nn.GroupNorm(
|
||||||
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
||||||
@ -534,6 +643,9 @@ class ResnetBlock3D(nn.Module):
|
|||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if inject_noise:
|
||||||
|
self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
||||||
|
|
||||||
self.conv_shortcut = (
|
self.conv_shortcut = (
|
||||||
make_linear_nd(
|
make_linear_nd(
|
||||||
dims=dims, in_channels=in_channels, out_channels=out_channels
|
dims=dims, in_channels=in_channels, out_channels=out_channels
|
||||||
@ -548,29 +660,84 @@ class ResnetBlock3D(nn.Module):
|
|||||||
else nn.Identity()
|
else nn.Identity()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.timestep_conditioning = timestep_conditioning
|
||||||
|
|
||||||
|
if timestep_conditioning:
|
||||||
|
self.scale_shift_table = nn.Parameter(
|
||||||
|
torch.randn(4, in_channels) / in_channels**0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
def _feed_spatial_noise(
|
||||||
|
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
spatial_shape = hidden_states.shape[-2:]
|
||||||
|
device = hidden_states.device
|
||||||
|
dtype = hidden_states.dtype
|
||||||
|
|
||||||
|
# similar to the "explicit noise inputs" method in style-gan
|
||||||
|
spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
|
||||||
|
scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
|
||||||
|
hidden_states = hidden_states + scaled_noise
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_tensor: torch.FloatTensor,
|
input_tensor: torch.FloatTensor,
|
||||||
causal: bool = True,
|
causal: bool = True,
|
||||||
|
timestep: Optional[torch.Tensor] = None,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
hidden_states = input_tensor
|
hidden_states = input_tensor
|
||||||
|
batch_size = hidden_states.shape[0]
|
||||||
|
|
||||||
hidden_states = self.norm1(hidden_states)
|
hidden_states = self.norm1(hidden_states)
|
||||||
|
if self.timestep_conditioning:
|
||||||
|
assert (
|
||||||
|
timestep is not None
|
||||||
|
), "should pass timestep with timestep_conditioning=True"
|
||||||
|
ada_values = self.scale_shift_table[
|
||||||
|
None, ..., None, None, None
|
||||||
|
] + timestep.reshape(
|
||||||
|
batch_size,
|
||||||
|
4,
|
||||||
|
-1,
|
||||||
|
timestep.shape[-3],
|
||||||
|
timestep.shape[-2],
|
||||||
|
timestep.shape[-1],
|
||||||
|
)
|
||||||
|
shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
|
||||||
|
|
||||||
|
hidden_states = hidden_states * (1 + scale1) + shift1
|
||||||
|
|
||||||
hidden_states = self.non_linearity(hidden_states)
|
hidden_states = self.non_linearity(hidden_states)
|
||||||
|
|
||||||
hidden_states = self.conv1(hidden_states, causal=causal)
|
hidden_states = self.conv1(hidden_states, causal=causal)
|
||||||
|
|
||||||
|
if self.inject_noise:
|
||||||
|
hidden_states = self._feed_spatial_noise(
|
||||||
|
hidden_states, self.per_channel_scale1
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = self.norm2(hidden_states)
|
hidden_states = self.norm2(hidden_states)
|
||||||
|
|
||||||
|
if self.timestep_conditioning:
|
||||||
|
hidden_states = hidden_states * (1 + scale2) + shift2
|
||||||
|
|
||||||
hidden_states = self.non_linearity(hidden_states)
|
hidden_states = self.non_linearity(hidden_states)
|
||||||
|
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
|
||||||
hidden_states = self.conv2(hidden_states, causal=causal)
|
hidden_states = self.conv2(hidden_states, causal=causal)
|
||||||
|
|
||||||
|
if self.inject_noise:
|
||||||
|
hidden_states = self._feed_spatial_noise(
|
||||||
|
hidden_states, self.per_channel_scale2
|
||||||
|
)
|
||||||
|
|
||||||
input_tensor = self.norm3(input_tensor)
|
input_tensor = self.norm3(input_tensor)
|
||||||
|
|
||||||
|
batch_size = input_tensor.shape[0]
|
||||||
|
|
||||||
input_tensor = self.conv_shortcut(input_tensor)
|
input_tensor = self.conv_shortcut(input_tensor)
|
||||||
|
|
||||||
output_tensor = input_tensor + hidden_states
|
output_tensor = input_tensor + hidden_states
|
||||||
@ -634,33 +801,71 @@ class processor(nn.Module):
|
|||||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||||
|
|
||||||
class VideoVAE(nn.Module):
|
class VideoVAE(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, version=0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = {
|
|
||||||
"_class_name": "CausalVideoAutoencoder",
|
if version == 0:
|
||||||
"dims": 3,
|
config = {
|
||||||
"in_channels": 3,
|
"_class_name": "CausalVideoAutoencoder",
|
||||||
"out_channels": 3,
|
"dims": 3,
|
||||||
"latent_channels": 128,
|
"in_channels": 3,
|
||||||
"blocks": [
|
"out_channels": 3,
|
||||||
["res_x", 4],
|
"latent_channels": 128,
|
||||||
["compress_all", 1],
|
"blocks": [
|
||||||
["res_x_y", 1],
|
["res_x", 4],
|
||||||
["res_x", 3],
|
["compress_all", 1],
|
||||||
["compress_all", 1],
|
["res_x_y", 1],
|
||||||
["res_x_y", 1],
|
["res_x", 3],
|
||||||
["res_x", 3],
|
["compress_all", 1],
|
||||||
["compress_all", 1],
|
["res_x_y", 1],
|
||||||
["res_x", 3],
|
["res_x", 3],
|
||||||
["res_x", 4],
|
["compress_all", 1],
|
||||||
],
|
["res_x", 3],
|
||||||
"scaling_factor": 1.0,
|
["res_x", 4],
|
||||||
"norm_layer": "pixel_norm",
|
],
|
||||||
"patch_size": 4,
|
"scaling_factor": 1.0,
|
||||||
"latent_log_var": "uniform",
|
"norm_layer": "pixel_norm",
|
||||||
"use_quant_conv": False,
|
"patch_size": 4,
|
||||||
"causal_decoder": False,
|
"latent_log_var": "uniform",
|
||||||
}
|
"use_quant_conv": False,
|
||||||
|
"causal_decoder": False,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
config = {
|
||||||
|
"_class_name": "CausalVideoAutoencoder",
|
||||||
|
"dims": 3,
|
||||||
|
"in_channels": 3,
|
||||||
|
"out_channels": 3,
|
||||||
|
"latent_channels": 128,
|
||||||
|
"decoder_blocks": [
|
||||||
|
["res_x", {"num_layers": 5, "inject_noise": True}],
|
||||||
|
["compress_all", {"residual": True, "multiplier": 2}],
|
||||||
|
["res_x", {"num_layers": 6, "inject_noise": True}],
|
||||||
|
["compress_all", {"residual": True, "multiplier": 2}],
|
||||||
|
["res_x", {"num_layers": 7, "inject_noise": True}],
|
||||||
|
["compress_all", {"residual": True, "multiplier": 2}],
|
||||||
|
["res_x", {"num_layers": 8, "inject_noise": False}]
|
||||||
|
],
|
||||||
|
"encoder_blocks": [
|
||||||
|
["res_x", {"num_layers": 4}],
|
||||||
|
["compress_all", {}],
|
||||||
|
["res_x_y", 1],
|
||||||
|
["res_x", {"num_layers": 3}],
|
||||||
|
["compress_all", {}],
|
||||||
|
["res_x_y", 1],
|
||||||
|
["res_x", {"num_layers": 3}],
|
||||||
|
["compress_all", {}],
|
||||||
|
["res_x", {"num_layers": 3}],
|
||||||
|
["res_x", {"num_layers": 4}]
|
||||||
|
],
|
||||||
|
"scaling_factor": 1.0,
|
||||||
|
"norm_layer": "pixel_norm",
|
||||||
|
"patch_size": 4,
|
||||||
|
"latent_log_var": "uniform",
|
||||||
|
"use_quant_conv": False,
|
||||||
|
"causal_decoder": False,
|
||||||
|
"timestep_conditioning": True,
|
||||||
|
}
|
||||||
|
|
||||||
double_z = config.get("double_z", True)
|
double_z = config.get("double_z", True)
|
||||||
latent_log_var = config.get(
|
latent_log_var = config.get(
|
||||||
@ -671,7 +876,7 @@ class VideoVAE(nn.Module):
|
|||||||
dims=config["dims"],
|
dims=config["dims"],
|
||||||
in_channels=config.get("in_channels", 3),
|
in_channels=config.get("in_channels", 3),
|
||||||
out_channels=config["latent_channels"],
|
out_channels=config["latent_channels"],
|
||||||
blocks=config.get("encoder_blocks", config.get("blocks")),
|
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
|
||||||
patch_size=config.get("patch_size", 1),
|
patch_size=config.get("patch_size", 1),
|
||||||
latent_log_var=latent_log_var,
|
latent_log_var=latent_log_var,
|
||||||
norm_layer=config.get("norm_layer", "group_norm"),
|
norm_layer=config.get("norm_layer", "group_norm"),
|
||||||
@ -681,18 +886,22 @@ class VideoVAE(nn.Module):
|
|||||||
dims=config["dims"],
|
dims=config["dims"],
|
||||||
in_channels=config["latent_channels"],
|
in_channels=config["latent_channels"],
|
||||||
out_channels=config.get("out_channels", 3),
|
out_channels=config.get("out_channels", 3),
|
||||||
blocks=config.get("decoder_blocks", config.get("blocks")),
|
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||||
patch_size=config.get("patch_size", 1),
|
patch_size=config.get("patch_size", 1),
|
||||||
norm_layer=config.get("norm_layer", "group_norm"),
|
norm_layer=config.get("norm_layer", "group_norm"),
|
||||||
causal=config.get("causal_decoder", False),
|
causal=config.get("causal_decoder", False),
|
||||||
|
timestep_conditioning=config.get("timestep_conditioning", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||||
self.per_channel_statistics = processor()
|
self.per_channel_statistics = processor()
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||||
return self.per_channel_statistics.normalize(means)
|
return self.per_channel_statistics.normalize(means)
|
||||||
|
|
||||||
def decode(self, x):
|
def decode(self, x, timestep=0.05, noise_scale=0.025):
|
||||||
return self.decoder(self.per_channel_statistics.un_normalize(x))
|
if self.timestep_conditioning: #TODO: seed
|
||||||
|
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
|
||||||
|
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
|
||||||
|
|
||||||
|
@ -340,7 +340,13 @@ class VAE:
|
|||||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
|
self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
|
||||||
self.working_dtypes = [torch.float16, torch.float32]
|
self.working_dtypes = [torch.float16, torch.float32]
|
||||||
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
|
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
|
||||||
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
|
tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
|
||||||
|
version = 0
|
||||||
|
if tensor_conv1.shape[0] == 512:
|
||||||
|
version = 0
|
||||||
|
elif tensor_conv1.shape[0] == 1024:
|
||||||
|
version = 1
|
||||||
|
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version)
|
||||||
self.latent_channels = 128
|
self.latent_channels = 128
|
||||||
self.latent_dim = 3
|
self.latent_dim = 3
|
||||||
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
|
Loading…
Reference in New Issue
Block a user