Update symmetric_patchifier.py

fix torch.compile error with meshgrid being called without indexing kwarg
This commit is contained in:
Miklos Nagy 2024-12-20 19:29:20 +01:00 committed by GitHub
parent 418eb7062d
commit ca9b1b18e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -53,7 +53,7 @@ class Patchifier(ABC):
grid_h = torch.arange(h, dtype=torch.float32, device=device)
grid_w = torch.arange(w, dtype=torch.float32, device=device)
grid_f = torch.arange(f, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing=None)
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)