Add activations_shape info in UNet models (#7482)

* Add activations_shape info in UNet models

* activations_shape should be a list
This commit is contained in:
Raphael Walker 2025-04-05 03:27:54 +02:00 committed by GitHub
parent 3a100b9a55
commit 89e4ea0175
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -847,6 +847,7 @@ class SpatialTransformer(nn.Module):
if not isinstance(context, list):
context = [context] * len(self.transformer_blocks)
b, c, h, w = x.shape
transformer_options["activations_shape"] = list(x.shape)
x_in = x
x = self.norm(x)
if not self.use_linear:
@ -962,6 +963,7 @@ class SpatialVideoTransformer(SpatialTransformer):
transformer_options={}
) -> torch.Tensor:
_, _, h, w = x.shape
transformer_options["activations_shape"] = list(x.shape)
x_in = x
spatial_context = None
if exists(context):