mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 18:35:17 +00:00
63 lines
1.7 KiB
Python
63 lines
1.7 KiB
Python
|
from typing import Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
|
||
|
class CausalConv3d(nn.Module):
    """3D convolution whose temporal axis is padded by replicating edge frames.

    Height/width padding is symmetric zero padding (``kernel // 2`` per side)
    performed by the wrapped ``nn.Conv3d``. The wrapped conv applies no
    temporal padding; instead :meth:`forward` replicates the first frame (and,
    in non-causal mode, also the last frame) along dim 2 before convolving.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Either a single int used for time, height and width, or a
            ``(time, height, width)`` tuple/list.
        stride: Forwarded unchanged to ``nn.Conv3d``.
        dilation: Temporal dilation. Spatial dilation is always 1.
        groups: Forwarded unchanged to ``nn.Conv3d``.
        **kwargs: Accepted for call-site compatibility and ignored; none of
            them are forwarded to ``nn.Conv3d``.

    Raises:
        ValueError: If ``kernel_size`` is a tuple/list whose length is not 3.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        if isinstance(kernel_size, (tuple, list)):
            kernel_size = tuple(kernel_size)
            if len(kernel_size) != 3:
                raise ValueError(
                    "kernel_size must be an int or a (time, height, width) "
                    f"tuple, got {kernel_size!r}"
                )
        else:
            # A scalar kernel size is used for all three axes.
            kernel_size = (kernel_size, kernel_size, kernel_size)
        self.time_kernel_size = kernel_size[0]
        # Kept so forward() can pad enough frames for the dilated temporal extent.
        self.time_dilation = dilation

        # Only the temporal axis is dilated.
        dilation = (dilation, 1, 1)

        height_pad = kernel_size[1] // 2
        width_pad = kernel_size[2] // 2
        # No temporal padding here: forward() replicates edge frames instead.
        padding = (0, height_pad, width_pad)

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode="zeros",
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        """Pad the time axis (dim 2) with replicated edge frames, then convolve.

        Args:
            x: 5-D input indexed as ``x[:, :, t, :, :]``; dim 2 is time.
            causal: If True, all ``(time_kernel_size - 1) * dilation`` padding
                frames are copies of the first frame, so each output frame only
                depends on the current and earlier input frames. If False, half
                of that amount (rounded down) is taken from the first frame and
                half from the last frame.

        Returns:
            The result of the wrapped ``nn.Conv3d`` on the padded input.
        """
        # A dilated kernel spans (k - 1) * d + 1 frames, so pad (k - 1) * d frames
        # to keep the temporal length (at stride 1) and the causal alignment.
        total_pad = (self.time_kernel_size - 1) * self.time_dilation
        if causal:
            first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, total_pad, 1, 1))
            x = torch.cat((first_frame_pad, x), dim=2)
        else:
            half_pad = total_pad // 2
            first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, half_pad, 1, 1))
            last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, half_pad, 1, 1))
            x = torch.cat((first_frame_pad, x, last_frame_pad), dim=2)
        return self.conv(x)

    @property
    def weight(self):
        """The weight tensor of the wrapped ``nn.Conv3d``."""
        return self.conv.weight
|