mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
This makes pytorch2.0 attention perform a bit faster.
This commit is contained in:
parent
989acd769a
commit
6908f9c949
@ -455,11 +455,7 @@ class CrossAttentionPytorch(nn.Module):
|
|||||||
|
|
||||||
b, _, _ = q.shape
|
b, _, _ = q.shape
|
||||||
q, k, v = map(
|
q, k, v = map(
|
||||||
lambda t: t.unsqueeze(3)
|
lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
|
||||||
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
|
||||||
.permute(0, 2, 1, 3)
|
|
||||||
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
|
||||||
.contiguous(),
|
|
||||||
(q, k, v),
|
(q, k, v),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -468,10 +464,7 @@ class CrossAttentionPytorch(nn.Module):
|
|||||||
if exists(mask):
|
if exists(mask):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
out = (
|
out = (
|
||||||
out.unsqueeze(0)
|
out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
|
||||||
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
|
||||||
.permute(0, 2, 1, 3)
|
|
||||||
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
Loading…
Reference in New Issue
Block a user