mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Refactor.
This commit is contained in:
parent
ca2812bae0
commit
412d3ff57d
24
comfy/ops.py
24
comfy/ops.py
@ -1,29 +1,23 @@
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def forward(self, input):
|
||||
return torch.nn.functional.linear(input, self.weight, self.bias)
|
||||
class Linear(torch.nn.Linear):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
return Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return Conv3d(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user