diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 98dbda63..c27d032a 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -455,11 +455,7 @@ class CrossAttentionPytorch(nn.Module): b, _, _ = q.shape q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), + lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2), (q, k, v), ) @@ -468,10 +464,7 @@ class CrossAttentionPytorch(nn.Module): if exists(mask): raise NotImplementedError out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) + out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head) ) return self.to_out(out)