Squash depreciation warning on new pytorch.

This commit is contained in:
comfyanonymous 2024-06-16 13:06:23 -04:00
parent ca9d300a80
commit 8ddc151a4c

View File

@ -75,10 +75,16 @@ class SnakeBeta(nn.Module):
return x
def WNConv1d(*args, **kwargs):
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs))
try:
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
except:
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
def WNConvTranspose1d(*args, **kwargs):
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
try:
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
except:
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
if activation == "elu":