Disable default weight values in unet conv2d for faster loading.

This commit is contained in:
comfyanonymous 2023-06-14 19:46:08 -04:00
parent 9d54066ebc
commit 21f04fe632
2 changed files with 7 additions and 3 deletions

View File

@@ -16,7 +16,7 @@ import numpy as np
from einops import repeat
from comfy.ldm.util import instantiate_from_config
import comfy.ops
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
@@ -233,7 +233,7 @@ def conv_nd(dims, *args, **kwargs):
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
+        return comfy.ops.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
@@ -243,7 +243,7 @@ def linear(*args, **kwargs):
"""
Create a linear module.
"""
-    return nn.Linear(*args, **kwargs)
+    return comfy.ops.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):

View File

@@ -15,3 +15,7 @@ class Linear(torch.nn.Module):
def forward(self, input):
return torch.nn.functional.linear(input, self.weight, self.bias)
class Conv2d(torch.nn.Conv2d):
def reset_parameters(self):
return None