mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-08 12:35:30 +08:00
Make xformers work with hypertile.
This commit is contained in:
parent
1443caf373
commit
9906e3efe3
@ -253,12 +253,14 @@ def attention_split(q, k, v, heads, mask=None):
|
|||||||
return r2
|
return r2
|
||||||
|
|
||||||
def attention_xformers(q, k, v, heads, mask=None):
|
def attention_xformers(q, k, v, heads, mask=None):
|
||||||
b, _, _ = q.shape
|
b, _, dim_head = q.shape
|
||||||
|
dim_head //= heads
|
||||||
|
|
||||||
q, k, v = map(
|
q, k, v = map(
|
||||||
lambda t: t.unsqueeze(3)
|
lambda t: t.unsqueeze(3)
|
||||||
.reshape(b, t.shape[1], heads, -1)
|
.reshape(b, -1, heads, dim_head)
|
||||||
.permute(0, 2, 1, 3)
|
.permute(0, 2, 1, 3)
|
||||||
.reshape(b * heads, t.shape[1], -1)
|
.reshape(b * heads, -1, dim_head)
|
||||||
.contiguous(),
|
.contiguous(),
|
||||||
(q, k, v),
|
(q, k, v),
|
||||||
)
|
)
|
||||||
@ -270,9 +272,9 @@ def attention_xformers(q, k, v, heads, mask=None):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
out = (
|
out = (
|
||||||
out.unsqueeze(0)
|
out.unsqueeze(0)
|
||||||
.reshape(b, heads, out.shape[1], -1)
|
.reshape(b, heads, -1, dim_head)
|
||||||
.permute(0, 2, 1, 3)
|
.permute(0, 2, 1, 3)
|
||||||
.reshape(b, out.shape[1], -1)
|
.reshape(b, -1, heads * dim_head)
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user