mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Fix lowvram issue with ltxv vae.
This commit is contained in:
parent
57f330caf9
commit
80f07952d2
@ -378,7 +378,7 @@ class Decoder(nn.Module):
|
|||||||
assert (
|
assert (
|
||||||
timestep is not None
|
timestep is not None
|
||||||
), "should pass timestep with timestep_conditioning=True"
|
), "should pass timestep with timestep_conditioning=True"
|
||||||
scaled_timestep = timestep * self.timestep_scale_multiplier
|
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
|
||||||
|
|
||||||
for up_block in self.up_blocks:
|
for up_block in self.up_blocks:
|
||||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||||
@ -403,7 +403,7 @@ class Decoder(nn.Module):
|
|||||||
)
|
)
|
||||||
ada_values = self.last_scale_shift_table[
|
ada_values = self.last_scale_shift_table[
|
||||||
None, ..., None, None, None
|
None, ..., None, None, None
|
||||||
] + embedded_timestep.reshape(
|
].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(
|
||||||
batch_size,
|
batch_size,
|
||||||
2,
|
2,
|
||||||
-1,
|
-1,
|
||||||
@ -697,7 +697,7 @@ class ResnetBlock3D(nn.Module):
|
|||||||
), "should pass timestep with timestep_conditioning=True"
|
), "should pass timestep with timestep_conditioning=True"
|
||||||
ada_values = self.scale_shift_table[
|
ada_values = self.scale_shift_table[
|
||||||
None, ..., None, None, None
|
None, ..., None, None, None
|
||||||
] + timestep.reshape(
|
].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(
|
||||||
batch_size,
|
batch_size,
|
||||||
4,
|
4,
|
||||||
-1,
|
-1,
|
||||||
@ -715,7 +715,7 @@ class ResnetBlock3D(nn.Module):
|
|||||||
|
|
||||||
if self.inject_noise:
|
if self.inject_noise:
|
||||||
hidden_states = self._feed_spatial_noise(
|
hidden_states = self._feed_spatial_noise(
|
||||||
hidden_states, self.per_channel_scale1
|
hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.norm2(hidden_states)
|
hidden_states = self.norm2(hidden_states)
|
||||||
@ -731,7 +731,7 @@ class ResnetBlock3D(nn.Module):
|
|||||||
|
|
||||||
if self.inject_noise:
|
if self.inject_noise:
|
||||||
hidden_states = self._feed_spatial_noise(
|
hidden_states = self._feed_spatial_noise(
|
||||||
hidden_states, self.per_channel_scale2
|
hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
input_tensor = self.norm3(input_tensor)
|
input_tensor = self.norm3(input_tensor)
|
||||||
|
Loading…
Reference in New Issue
Block a user