#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI

from typing import List, Optional, Tuple, Union
from functools import partial
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from comfy.ldm.modules.attention import optimized_attention

import comfy.ops
ops = comfy.ops.disable_weight_init

# import mochi_preview.dit.joint_model.context_parallel as cp
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


class GroupNormSpatial(ops.GroupNorm):
    """
    GroupNorm applied per-frame.
    """

    def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
        B, C, T, H, W = x.shape
        x = rearrange(x, "B C T H W -> (B T) C H W")
        # Run group norm in chunks.
        output = torch.empty_like(x)
        for b in range(0, B * T, chunk_size):
            output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
        return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)

class PConv3d(ops.Conv3d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]],
        causal: bool = True,
        context_parallel: bool = True,
        **kwargs,
    ):
        self.causal = causal
        self.context_parallel = context_parallel
        kernel_size = cast_tuple(kernel_size, 3)
        stride = cast_tuple(stride, 3)
        height_pad = (kernel_size[1] - 1) // 2
        width_pad = (kernel_size[2] - 1) // 2

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=(1, 1, 1),
            padding=(0, height_pad, width_pad),
            **kwargs,
        )

    def forward(self, x: torch.Tensor):
        # Compute padding amounts.
        context_size = self.kernel_size[0] - 1
        if self.causal:
            pad_front = context_size
            pad_back = 0
        else:
            pad_front = context_size // 2
            pad_back = context_size - pad_front

        # Apply padding.
        assert self.padding_mode == "replicate"  # DEBUG
        mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
        x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
        return super().forward(x)


class Conv1x1(ops.Linear):
    """*1x1 Conv implemented with a linear layer."""

    def __init__(self, in_features: int, out_features: int, *args, **kwargs):
        super().__init__(in_features, out_features, *args, **kwargs)

    def forward(self, x: torch.Tensor):
        """Forward pass.

        Args:
            x: Input tensor. Shape: [B, C, *] or [B, *, C].

        Returns:
            x: Output tensor. Shape: [B, C', *] or [B, *, C'].
        """
        x = x.movedim(1, -1)
        x = super().forward(x)
        x = x.movedim(-1, 1)
        return x


class DepthToSpaceTime(nn.Module):
    def __init__(
        self,
        temporal_expansion: int,
        spatial_expansion: int,
    ):
        super().__init__()
        self.temporal_expansion = temporal_expansion
        self.spatial_expansion = spatial_expansion

    # When printed, this module should show the temporal and spatial expansion factors.
    def extra_repr(self):
        return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"

    def forward(self, x: torch.Tensor):
        """Forward pass.

        Args:
            x: Input tensor. Shape: [B, C, T, H, W].

        Returns:
            x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
        """
        x = rearrange(
            x,
            "B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
            st=self.temporal_expansion,
            sh=self.spatial_expansion,
            sw=self.spatial_expansion,
        )

        # cp_rank, _ = cp.get_cp_rank_size()
        if self.temporal_expansion > 1: # and cp_rank == 0:
            # Drop the first self.temporal_expansion - 1 frames.
            # This is because we always want the 3x3x3 conv filter to only apply
            # to the first frame, and the first frame doesn't need to be repeated.
            assert all(x.shape)
            x = x[:, :, self.temporal_expansion - 1 :]
            assert all(x.shape)

        return x


def norm_fn(
    in_channels: int,
    affine: bool = True,
):
    return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)


class ResBlock(nn.Module):
    """Residual block that preserves the spatial dimensions."""

    def __init__(
        self,
        channels: int,
        *,
        affine: bool = True,
        attn_block: Optional[nn.Module] = None,
        causal: bool = True,
        prune_bottleneck: bool = False,
        padding_mode: str,
        bias: bool = True,
    ):
        super().__init__()
        self.channels = channels

        assert causal
        self.stack = nn.Sequential(
            norm_fn(channels, affine=affine),
            nn.SiLU(inplace=True),
            PConv3d(
                in_channels=channels,
                out_channels=channels // 2 if prune_bottleneck else channels,
                kernel_size=(3, 3, 3),
                stride=(1, 1, 1),
                padding_mode=padding_mode,
                bias=bias,
                causal=causal,
            ),
            norm_fn(channels, affine=affine),
            nn.SiLU(inplace=True),
            PConv3d(
                in_channels=channels // 2 if prune_bottleneck else channels,
                out_channels=channels,
                kernel_size=(3, 3, 3),
                stride=(1, 1, 1),
                padding_mode=padding_mode,
                bias=bias,
                causal=causal,
            ),
        )

        self.attn_block = attn_block if attn_block else nn.Identity()

    def forward(self, x: torch.Tensor):
        """Forward pass.

        Args:
            x: Input tensor. Shape: [B, C, T, H, W].
        """
        residual = x
        x = self.stack(x)
        x = x + residual
        del residual

        return self.attn_block(x)


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        head_dim: int = 32,
        qkv_bias: bool = False,
        out_bias: bool = True,
        qk_norm: bool = True,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.qk_norm = qk_norm

        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.out = nn.Linear(dim, dim, bias=out_bias)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Compute temporal self-attention.

        Args:
            x: Input tensor. Shape: [B, C, T, H, W].
            chunk_size: Chunk size for large tensors.

        Returns:
            x: Output tensor. Shape: [B, C, T, H, W].
        """
        B, _, T, H, W = x.shape

        if T == 1:
            # No attention for single frame.
            x = x.movedim(1, -1)  # [B, C, T, H, W] -> [B, T, H, W, C]
            qkv = self.qkv(x)
            _, _, x = qkv.chunk(3, dim=-1)  # Throw away queries and keys.
            x = self.out(x)
            return x.movedim(-1, 1)  # [B, T, H, W, C] -> [B, C, T, H, W]

        # 1D temporal attention.
        x = rearrange(x, "B C t h w -> (B h w) t C")
        qkv = self.qkv(x)

        # Input: qkv with shape [B, t, 3 * num_heads * head_dim]
        # Output: x with shape [B, num_heads, t, head_dim]
        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)

        if self.qk_norm:
            q = F.normalize(q, p=2, dim=-1)
            k = F.normalize(k, p=2, dim=-1)

        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)

        assert x.size(0) == q.size(0)

        x = self.out(x)
        x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
        return x


class AttentionBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        **attn_kwargs,
    ) -> None:
        super().__init__()
        self.norm = norm_fn(dim)
        self.attn = Attention(dim, **attn_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.attn(self.norm(x))


class CausalUpsampleBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_res_blocks: int,
        *,
        temporal_expansion: int = 2,
        spatial_expansion: int = 2,
        **block_kwargs,
    ):
        super().__init__()

        blocks = []
        for _ in range(num_res_blocks):
            blocks.append(block_fn(in_channels, **block_kwargs))
        self.blocks = nn.Sequential(*blocks)

        self.temporal_expansion = temporal_expansion
        self.spatial_expansion = spatial_expansion

        # Change channels in the final convolution layer.
        self.proj = Conv1x1(
            in_channels,
            out_channels * temporal_expansion * (spatial_expansion**2),
        )

        self.d2st = DepthToSpaceTime(
            temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
        )

    def forward(self, x):
        x = self.blocks(x)
        x = self.proj(x)
        x = self.d2st(x)
        return x


def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
    attn_block = AttentionBlock(channels) if has_attention else None
    return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)


class DownsampleBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_res_blocks,
        *,
        temporal_reduction=2,
        spatial_reduction=2,
        **block_kwargs,
    ):
        """
        Downsample block for the VAE encoder.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            num_res_blocks: Number of residual blocks.
            temporal_reduction: Temporal reduction factor.
            spatial_reduction: Spatial reduction factor.
        """
        super().__init__()
        layers = []

        # Change the channel count in the strided convolution.
        # This lets the ResBlock have uniform channel count,
        # as in ConvNeXt.
        assert in_channels != out_channels
        layers.append(
            PConv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
                stride=(temporal_reduction, spatial_reduction, spatial_reduction),
                # First layer in each block always uses replicate padding
                padding_mode="replicate",
                bias=block_kwargs["bias"],
            )
        )

        for _ in range(num_res_blocks):
            layers.append(block_fn(out_channels, **block_kwargs))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
    num_freqs = (stop - start) // step
    assert inputs.ndim == 5
    C = inputs.size(1)

    # Create Base 2 Fourier features.
    freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
    assert num_freqs == len(freqs)
    w = torch.pow(2.0, freqs) * (2 * torch.pi)  # [num_freqs]
    C = inputs.shape[1]
    w = w.repeat(C)[None, :, None, None, None]  # [1, C * num_freqs, 1, 1, 1]

    # Interleaved repeat of input channels to match w.
    h = inputs.repeat_interleave(num_freqs, dim=1)  # [B, C * num_freqs, T, H, W]
    # Scale channels by frequency.
    h = w * h

    return torch.cat(
        [
            inputs,
            torch.sin(h),
            torch.cos(h),
        ],
        dim=1,
    )


class FourierFeatures(nn.Module):
    def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
        super().__init__()
        self.start = start
        self.stop = stop
        self.step = step

    def forward(self, inputs):
        """Add Fourier features to inputs.

        Args:
            inputs: Input tensor. Shape: [B, C, T, H, W]

        Returns:
            h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
        """
        return add_fourier_features(inputs, self.start, self.stop, self.step)


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        out_channels: int = 3,
        latent_dim: int,
        base_channels: int,
        channel_multipliers: List[int],
        num_res_blocks: List[int],
        temporal_expansions: Optional[List[int]] = None,
        spatial_expansions: Optional[List[int]] = None,
        has_attention: List[bool],
        output_norm: bool = True,
        nonlinearity: str = "silu",
        output_nonlinearity: str = "silu",
        causal: bool = True,
        **block_kwargs,
    ):
        super().__init__()
        self.input_channels = latent_dim
        self.base_channels = base_channels
        self.channel_multipliers = channel_multipliers
        self.num_res_blocks = num_res_blocks
        self.output_nonlinearity = output_nonlinearity
        assert nonlinearity == "silu"
        assert causal

        ch = [mult * base_channels for mult in channel_multipliers]
        self.num_up_blocks = len(ch) - 1
        assert len(num_res_blocks) == self.num_up_blocks + 2

        blocks = []

        first_block = [
            ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
        ]  # Input layer.
        # First set of blocks preserve channel count.
        for _ in range(num_res_blocks[-1]):
            first_block.append(
                block_fn(
                    ch[-1],
                    has_attention=has_attention[-1],
                    causal=causal,
                    **block_kwargs,
                )
            )
        blocks.append(nn.Sequential(*first_block))

        assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
        assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2

        upsample_block_fn = CausalUpsampleBlock

        for i in range(self.num_up_blocks):
            block = upsample_block_fn(
                ch[-i - 1],
                ch[-i - 2],
                num_res_blocks=num_res_blocks[-i - 2],
                has_attention=has_attention[-i - 2],
                temporal_expansion=temporal_expansions[-i - 1],
                spatial_expansion=spatial_expansions[-i - 1],
                causal=causal,
                **block_kwargs,
            )
            blocks.append(block)

        assert not output_norm

        # Last block. Preserve channel count.
        last_block = []
        for _ in range(num_res_blocks[0]):
            last_block.append(
                block_fn(
                    ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
                )
            )
        blocks.append(nn.Sequential(*last_block))

        self.blocks = nn.ModuleList(blocks)
        self.output_proj = Conv1x1(ch[0], out_channels)

    def forward(self, x):
        """Forward pass.

        Args:
            x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].

        Returns:
            x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
               T + 1 = (t - 1) * 4.
               H = h * 16, W = w * 16.
        """
        for block in self.blocks:
            x = block(x)

        if self.output_nonlinearity == "silu":
            x = F.silu(x, inplace=not self.training)
        else:
            assert (
                not self.output_nonlinearity
            )  # StyleGAN3 omits the to-RGB nonlinearity.

        return self.output_proj(x).contiguous()

class LatentDistribution:
    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        """Initialize latent distribution.

        Args:
            mean: Mean of the distribution. Shape: [B, C, T, H, W].
            logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
        """
        assert mean.shape == logvar.shape
        self.mean = mean
        self.logvar = logvar

    def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
        if temperature == 0.0:
            return self.mean

        if noise is None:
            noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
        else:
            assert noise.device == self.mean.device
            noise = noise.to(self.mean.dtype)

        if temperature != 1.0:
            raise NotImplementedError(f"Temperature {temperature} is not supported.")

        # Just Gaussian sample with no scaling of variance.
        return noise * torch.exp(self.logvar * 0.5) + self.mean

    def mode(self):
        return self.mean

class Encoder(nn.Module):
    def __init__(
        self,
        *,
        in_channels: int,
        base_channels: int,
        channel_multipliers: List[int],
        num_res_blocks: List[int],
        latent_dim: int,
        temporal_reductions: List[int],
        spatial_reductions: List[int],
        prune_bottlenecks: List[bool],
        has_attentions: List[bool],
        affine: bool = True,
        bias: bool = True,
        input_is_conv_1x1: bool = False,
        padding_mode: str,
    ):
        super().__init__()
        self.temporal_reductions = temporal_reductions
        self.spatial_reductions = spatial_reductions
        self.base_channels = base_channels
        self.channel_multipliers = channel_multipliers
        self.num_res_blocks = num_res_blocks
        self.latent_dim = latent_dim

        self.fourier_features = FourierFeatures()
        ch = [mult * base_channels for mult in channel_multipliers]
        num_down_blocks = len(ch) - 1
        assert len(num_res_blocks) == num_down_blocks + 2

        layers = (
            [ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
            if not input_is_conv_1x1
            else [Conv1x1(in_channels, ch[0])]
        )

        assert len(prune_bottlenecks) == num_down_blocks + 2
        assert len(has_attentions) == num_down_blocks + 2
        block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)

        for _ in range(num_res_blocks[0]):
            layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
        prune_bottlenecks = prune_bottlenecks[1:]
        has_attentions = has_attentions[1:]

        assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
        for i in range(num_down_blocks):
            layer = DownsampleBlock(
                ch[i],
                ch[i + 1],
                num_res_blocks=num_res_blocks[i + 1],
                temporal_reduction=temporal_reductions[i],
                spatial_reduction=spatial_reductions[i],
                prune_bottleneck=prune_bottlenecks[i],
                has_attention=has_attentions[i],
                affine=affine,
                bias=bias,
                padding_mode=padding_mode,
            )

            layers.append(layer)

        # Additional blocks.
        for _ in range(num_res_blocks[-1]):
            layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))

        self.layers = nn.Sequential(*layers)

        # Output layers.
        self.output_norm = norm_fn(ch[-1])
        self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)

    @property
    def temporal_downsample(self):
        return math.prod(self.temporal_reductions)

    @property
    def spatial_downsample(self):
        return math.prod(self.spatial_reductions)

    def forward(self, x) -> LatentDistribution:
        """Forward pass.

        Args:
            x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]

        Returns:
            means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
                   h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
            logvar: Shape: [B, latent_dim, t, h, w].
        """
        assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
        x = self.fourier_features(x)

        x = self.layers(x)

        x = self.output_norm(x)
        x = F.silu(x, inplace=True)
        x = self.output_proj(x)

        means, logvar = torch.chunk(x, 2, dim=1)

        assert means.ndim == 5
        assert logvar.shape == means.shape
        assert means.size(1) == self.latent_dim

        return LatentDistribution(means, logvar)


class VideoVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder(
            in_channels=15,
            base_channels=64,
            channel_multipliers=[1, 2, 4, 6],
            num_res_blocks=[3, 3, 4, 6, 3],
            latent_dim=12,
            temporal_reductions=[1, 2, 3],
            spatial_reductions=[2, 2, 2],
            prune_bottlenecks=[False, False, False, False, False],
            has_attentions=[False, True, True, True, True],
            affine=True,
            bias=True,
            input_is_conv_1x1=True,
            padding_mode="replicate"
        )
        self.decoder = Decoder(
            out_channels=3,
            base_channels=128,
            channel_multipliers=[1, 2, 4, 6],
            temporal_expansions=[1, 2, 3],
            spatial_expansions=[2, 2, 2],
            num_res_blocks=[3, 3, 4, 6, 3],
            latent_dim=12,
            has_attention=[False, False, False, False, False],
            padding_mode="replicate",
            output_norm=False,
            nonlinearity="silu",
            output_nonlinearity="silu",
            causal=True,
        )

    def encode(self, x):
        return self.encoder(x).mode()

    def decode(self, x):
        return self.decoder(x)