diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 4d43feb2..e0344dee 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -378,7 +378,7 @@ class Decoder(nn.Module): assert ( timestep is not None ), "should pass timestep with timestep_conditioning=True" - scaled_timestep = timestep * self.timestep_scale_multiplier + scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device) for up_block in self.up_blocks: if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): @@ -403,7 +403,7 @@ class Decoder(nn.Module): ) ada_values = self.last_scale_shift_table[ None, ..., None, None, None - ] + embedded_timestep.reshape( + ].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape( batch_size, 2, -1, @@ -697,7 +697,7 @@ class ResnetBlock3D(nn.Module): ), "should pass timestep with timestep_conditioning=True" ada_values = self.scale_shift_table[ None, ..., None, None, None - ] + timestep.reshape( + ].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape( batch_size, 4, -1, @@ -715,7 +715,7 @@ class ResnetBlock3D(nn.Module): if self.inject_noise: hidden_states = self._feed_spatial_noise( - hidden_states, self.per_channel_scale1 + hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype) ) hidden_states = self.norm2(hidden_states) @@ -731,7 +731,7 @@ class ResnetBlock3D(nn.Module): if self.inject_noise: hidden_states = self._feed_spatial_noise( - hidden_states, self.per_channel_scale2 + hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype) ) input_tensor = self.norm3(input_tensor)