diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py index 130ed6fd..976f98c6 100644 --- a/comfy/image_encoders/dino2.py +++ b/comfy/image_encoders/dino2.py @@ -116,7 +116,7 @@ class Dino2Embeddings(torch.nn.Module): def forward(self, pixel_values): x = self.patch_embeddings(pixel_values) # TODO: mask_token? - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1) x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype) return x