Fix for dino lowvram. (#7748)

This commit is contained in:
comfyanonymous 2025-04-23 01:12:52 -07:00 committed by GitHub
parent 0738e4ea5d
commit 552615235d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -116,7 +116,7 @@ class Dino2Embeddings(torch.nn.Module):
def forward(self, pixel_values):
x = self.patch_embeddings(pixel_values)
# TODO: mask_token?
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
return x