Fix regression in VAE code on old pytorch versions.

This commit is contained in:
comfyanonymous 2024-12-18 03:08:28 -05:00
parent 79badea452
commit 4c5c4ddeda

View File

@ -91,7 +91,7 @@ class Upsample(nn.Module):
def forward(self, x):
scale_factor = self.scale_factor
if not isinstance(scale_factor, tuple):
if isinstance(scale_factor, (int, float)):
scale_factor = (scale_factor,) * (x.ndim - 2)
if x.ndim == 5 and scale_factor[0] > 1.0:
@ -109,7 +109,7 @@ class Upsample(nn.Module):
else:
x = a
else:
x = interpolate_up(x, self.scale_factor)
x = interpolate_up(x, scale_factor)
if self.with_conv:
x = self.conv(x)
return x