mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-13 14:21:20 +00:00
Initialize more unet weights as the right dtype.
This commit is contained in:
parent
e21d9ad445
commit
7bf89ba923
@ -208,6 +208,7 @@ class ResBlock(TimestepBlock):
|
||||
use_checkpoint=False,
|
||||
up=False,
|
||||
down=False,
|
||||
dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
@ -221,7 +222,7 @@ class ResBlock(TimestepBlock):
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
@ -247,7 +248,7 @@ class ResBlock(TimestepBlock):
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
|
||||
),
|
||||
)
|
||||
|
||||
@ -255,10 +256,10 @@ class ResBlock(TimestepBlock):
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, 3, padding=1
|
||||
dims, channels, self.out_channels, 3, padding=1, dtype=dtype
|
||||
)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
@ -558,9 +559,9 @@ class UNetModel(nn.Module):
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
linear(model_channels, time_embed_dim, dtype=self.dtype),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
@ -573,9 +574,9 @@ class UNetModel(nn.Module):
|
||||
assert adm_in_channels is not None
|
||||
self.label_emb = nn.Sequential(
|
||||
nn.Sequential(
|
||||
linear(adm_in_channels, time_embed_dim),
|
||||
linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -584,7 +585,7 @@ class UNetModel(nn.Module):
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user