Mochi VAE encoder.

This commit is contained in:
comfyanonymous 2024-11-01 17:33:09 -04:00
parent cc9cf6d1bd
commit fabf449feb
3 changed files with 283 additions and 28 deletions

View File

@ -2,12 +2,16 @@
#adapted to ComfyUI #adapted to ComfyUI
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
from functools import partial
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
@ -158,8 +162,10 @@ class ResBlock(nn.Module):
*, *,
affine: bool = True, affine: bool = True,
attn_block: Optional[nn.Module] = None, attn_block: Optional[nn.Module] = None,
padding_mode: str = "replicate",
causal: bool = True, causal: bool = True,
prune_bottleneck: bool = False,
padding_mode: str,
bias: bool = True,
): ):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
@ -170,23 +176,23 @@ class ResBlock(nn.Module):
nn.SiLU(inplace=True), nn.SiLU(inplace=True),
PConv3d( PConv3d(
in_channels=channels, in_channels=channels,
out_channels=channels, out_channels=channels // 2 if prune_bottleneck else channels,
kernel_size=(3, 3, 3), kernel_size=(3, 3, 3),
stride=(1, 1, 1), stride=(1, 1, 1),
padding_mode=padding_mode, padding_mode=padding_mode,
bias=True, bias=bias,
# causal=causal, causal=causal,
), ),
norm_fn(channels, affine=affine), norm_fn(channels, affine=affine),
nn.SiLU(inplace=True), nn.SiLU(inplace=True),
PConv3d( PConv3d(
in_channels=channels, in_channels=channels // 2 if prune_bottleneck else channels,
out_channels=channels, out_channels=channels,
kernel_size=(3, 3, 3), kernel_size=(3, 3, 3),
stride=(1, 1, 1), stride=(1, 1, 1),
padding_mode=padding_mode, padding_mode=padding_mode,
bias=True, bias=bias,
# causal=causal, causal=causal,
), ),
) )
@ -206,6 +212,81 @@ class ResBlock(nn.Module):
return self.attn_block(x) return self.attn_block(x)
class Attention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
out_bias: bool = True,
qk_norm: bool = True,
) -> None:
super().__init__()
self.head_dim = head_dim
self.num_heads = dim // head_dim
self.qk_norm = qk_norm
self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
self.out = nn.Linear(dim, dim, bias=out_bias)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""Compute temporal self-attention.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
chunk_size: Chunk size for large tensors.
Returns:
x: Output tensor. Shape: [B, C, T, H, W].
"""
B, _, T, H, W = x.shape
if T == 1:
# No attention for single frame.
x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
qkv = self.qkv(x)
_, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
x = self.out(x)
return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
# 1D temporal attention.
x = rearrange(x, "B C t h w -> (B h w) t C")
qkv = self.qkv(x)
# Input: qkv with shape [B, t, 3 * num_heads * head_dim]
# Output: x with shape [B, num_heads, t, head_dim]
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)
if self.qk_norm:
q = F.normalize(q, p=2, dim=-1)
k = F.normalize(k, p=2, dim=-1)
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
assert x.size(0) == q.size(0)
x = self.out(x)
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
return x
class AttentionBlock(nn.Module):
def __init__(
self,
dim: int,
**attn_kwargs,
) -> None:
super().__init__()
self.norm = norm_fn(dim)
self.attn = Attention(dim, **attn_kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.attn(self.norm(x))
class CausalUpsampleBlock(nn.Module): class CausalUpsampleBlock(nn.Module):
def __init__( def __init__(
self, self,
@ -244,14 +325,9 @@ class CausalUpsampleBlock(nn.Module):
return x return x
def block_fn(channels, *, has_attention: bool = False, **block_kwargs): def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
assert has_attention is False #NOTE: if this is ever true add back the attention code. attn_block = AttentionBlock(channels) if has_attention else None
return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
attn_block = None #AttentionBlock(channels) if has_attention else None
return ResBlock(
channels, affine=True, attn_block=attn_block, **block_kwargs
)
class DownsampleBlock(nn.Module): class DownsampleBlock(nn.Module):
@ -288,8 +364,9 @@ class DownsampleBlock(nn.Module):
out_channels=out_channels, out_channels=out_channels,
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction), kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
stride=(temporal_reduction, spatial_reduction, spatial_reduction), stride=(temporal_reduction, spatial_reduction, spatial_reduction),
# First layer in each block always uses replicate padding
padding_mode="replicate", padding_mode="replicate",
bias=True, bias=block_kwargs["bias"],
) )
) )
@ -382,7 +459,7 @@ class Decoder(nn.Module):
blocks = [] blocks = []
first_block = [ first_block = [
nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1)) ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
] # Input layer. ] # Input layer.
# First set of blocks preserve channel count. # First set of blocks preserve channel count.
for _ in range(num_res_blocks[-1]): for _ in range(num_res_blocks[-1]):
@ -452,11 +529,165 @@ class Decoder(nn.Module):
return self.output_proj(x).contiguous() return self.output_proj(x).contiguous()
class LatentDistribution:
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
"""Initialize latent distribution.
Args:
mean: Mean of the distribution. Shape: [B, C, T, H, W].
logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
"""
assert mean.shape == logvar.shape
self.mean = mean
self.logvar = logvar
def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
if temperature == 0.0:
return self.mean
if noise is None:
noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
else:
assert noise.device == self.mean.device
noise = noise.to(self.mean.dtype)
if temperature != 1.0:
raise NotImplementedError(f"Temperature {temperature} is not supported.")
# Just Gaussian sample with no scaling of variance.
return noise * torch.exp(self.logvar * 0.5) + self.mean
def mode(self):
return self.mean
class Encoder(nn.Module):
def __init__(
self,
*,
in_channels: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
latent_dim: int,
temporal_reductions: List[int],
spatial_reductions: List[int],
prune_bottlenecks: List[bool],
has_attentions: List[bool],
affine: bool = True,
bias: bool = True,
input_is_conv_1x1: bool = False,
padding_mode: str,
):
super().__init__()
self.temporal_reductions = temporal_reductions
self.spatial_reductions = spatial_reductions
self.base_channels = base_channels
self.channel_multipliers = channel_multipliers
self.num_res_blocks = num_res_blocks
self.latent_dim = latent_dim
self.fourier_features = FourierFeatures()
ch = [mult * base_channels for mult in channel_multipliers]
num_down_blocks = len(ch) - 1
assert len(num_res_blocks) == num_down_blocks + 2
layers = (
[ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
if not input_is_conv_1x1
else [Conv1x1(in_channels, ch[0])]
)
assert len(prune_bottlenecks) == num_down_blocks + 2
assert len(has_attentions) == num_down_blocks + 2
block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
for _ in range(num_res_blocks[0]):
layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
prune_bottlenecks = prune_bottlenecks[1:]
has_attentions = has_attentions[1:]
assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
for i in range(num_down_blocks):
layer = DownsampleBlock(
ch[i],
ch[i + 1],
num_res_blocks=num_res_blocks[i + 1],
temporal_reduction=temporal_reductions[i],
spatial_reduction=spatial_reductions[i],
prune_bottleneck=prune_bottlenecks[i],
has_attention=has_attentions[i],
affine=affine,
bias=bias,
padding_mode=padding_mode,
)
layers.append(layer)
# Additional blocks.
for _ in range(num_res_blocks[-1]):
layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
self.layers = nn.Sequential(*layers)
# Output layers.
self.output_norm = norm_fn(ch[-1])
self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
@property
def temporal_downsample(self):
return math.prod(self.temporal_reductions)
@property
def spatial_downsample(self):
return math.prod(self.spatial_reductions)
def forward(self, x) -> LatentDistribution:
"""Forward pass.
Args:
x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
Returns:
means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
logvar: Shape: [B, latent_dim, t, h, w].
"""
assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
x = self.fourier_features(x)
x = self.layers(x)
x = self.output_norm(x)
x = F.silu(x, inplace=True)
x = self.output_proj(x)
means, logvar = torch.chunk(x, 2, dim=1)
assert means.ndim == 5
assert logvar.shape == means.shape
assert means.size(1) == self.latent_dim
return LatentDistribution(means, logvar)
class VideoVAE(nn.Module): class VideoVAE(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.encoder = None #TODO once the model releases self.encoder = Encoder(
in_channels=15,
base_channels=64,
channel_multipliers=[1, 2, 4, 6],
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
temporal_reductions=[1, 2, 3],
spatial_reductions=[2, 2, 2],
prune_bottlenecks=[False, False, False, False, False],
has_attentions=[False, True, True, True, True],
affine=True,
bias=True,
input_is_conv_1x1=True,
padding_mode="replicate"
)
self.decoder = Decoder( self.decoder = Decoder(
out_channels=3, out_channels=3,
base_channels=128, base_channels=128,
@ -474,7 +705,7 @@ class VideoVAE(nn.Module):
) )
def encode(self, x): def encode(self, x):
return self.encoder(x) return self.encoder(x).mode()
def decode(self, x): def decode(self, x):
return self.decoder(x) return self.decoder(x)

View File

@ -393,6 +393,13 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return out return out
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
SDP_BATCH_LIMIT = 2**15
else:
#TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**31
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape: if skip_reshape:
b, _, _, dim_head = q.shape b, _, _, dim_head = q.shape
@ -404,10 +411,15 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
(q, k, v), (q, k, v),
) )
if SDP_BATCH_LIMIT >= q.shape[0]:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = ( out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head) out.transpose(1, 2).reshape(b, -1, heads * dim_head)
) )
else:
out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out return out

View File

@ -171,6 +171,7 @@ class VAE:
self.downscale_ratio = 8 self.downscale_ratio = 8
self.upscale_ratio = 8 self.upscale_ratio = 8
self.latent_channels = 4 self.latent_channels = 4
self.latent_dim = 2
self.output_channels = 3 self.output_channels = 3
self.process_input = lambda image: image * 2.0 - 1.0 self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
@ -240,16 +241,22 @@ class VAE:
self.output_channels = 2 self.output_channels = 2
self.upscale_ratio = 2048 self.upscale_ratio = 2048
self.downscale_ratio = 2048 self.downscale_ratio = 2048
self.latent_dim = 1
self.process_output = lambda audio: audio self.process_output = lambda audio: audio
self.process_input = lambda audio: audio self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd: #genmo mochi vae elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight": #genmo mochi vae
if "blocks.2.blocks.3.stack.5.weight" in sd: if "blocks.2.blocks.3.stack.5.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."}) sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE() self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
self.latent_channels = 12 self.latent_channels = 12
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8) self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
self.working_dtypes = [torch.float16, torch.float32]
else: else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.") logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None self.first_stage_model = None
@ -361,17 +368,22 @@ class VAE:
def encode(self, pixel_samples): def encode(self, pixel_samples):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1, 1)
if self.latent_dim == 3:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
try: try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used) model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used)) batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device) samples = None
for x in range(0, pixel_samples.shape[0], batch_number): for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device) pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x:x + batch_number] = out
except model_management.OOM_EXCEPTION as e: except model_management.OOM_EXCEPTION as e:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")