mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-12 02:45:16 +00:00
83 lines
2.1 KiB
Python
83 lines
2.1 KiB
Python
|
from typing import Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from .dual_conv3d import DualConv3d
|
||
|
from .causal_conv3d import CausalConv3d
|
||
|
|
||
|
|
||
|
def make_conv_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    causal=False,
):
    """Build a convolution module for the requested dimensionality.

    Args:
        dims: ``2`` -> ``torch.nn.Conv2d``; ``3`` -> ``torch.nn.Conv3d`` (or
            ``CausalConv3d`` when ``causal`` is truthy); ``(2, 1)`` ->
            ``DualConv3d`` (presumably a factorized spatial + temporal conv;
            see its own module to confirm).
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded as keyword arguments to the chosen module.
        causal: only consulted when ``dims == 3``.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        ValueError: if ``dims`` is not one of the supported values.
    """
    # Arguments every branch forwards verbatim.
    common = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
    )

    if dims == 2:
        return torch.nn.Conv2d(dilation=dilation, groups=groups, **common)

    if dims == 3:
        conv_cls = CausalConv3d if causal else torch.nn.Conv3d
        return conv_cls(dilation=dilation, groups=groups, **common)

    if dims == (2, 1):
        # NOTE: dilation and groups are intentionally not forwarded here,
        # matching the existing DualConv3d call site.
        return DualConv3d(**common)

    raise ValueError(f"unsupported dimensions: {dims}")
|
||
|
|
||
|
|
||
|
def make_linear_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    bias=True,
):
    """Build a pointwise (``kernel_size=1``) convolution for ``dims``.

    A 1x1(x1) convolution applies the same linear map over channels at every
    spatial/temporal position, so this acts as a per-location linear layer.

    Args:
        dims: ``2`` -> ``torch.nn.Conv2d``; ``3`` or ``(2, 1)`` ->
            ``torch.nn.Conv3d``. The ``(2, 1)`` case uses a plain Conv3d,
            not ``DualConv3d``, because a pointwise kernel has no spatial or
            temporal extent to factorize.
        in_channels: number of input channels.
        out_channels: number of output channels.
        bias: whether the convolution has a learnable bias.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        ValueError: if ``dims`` is not one of the supported values.
    """
    if dims == 2:
        conv_cls = torch.nn.Conv2d
    elif dims == 3 or dims == (2, 1):
        conv_cls = torch.nn.Conv3d
    else:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_cls(
        in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
    )
|