Make buggy xformers fall back on pytorch attention.

comfyanonymous 2023-11-24 03:55:35 -05:00
parent 982338b9bb
commit 3e5ea74ad3


@@ -278,9 +278,20 @@ def attention_split(q, k, v, heads, mask=None):
    )
    return r1

BROKEN_XFORMERS = False
try:
    x_vers = xformers.__version__
    #I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
    BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
except:
    pass

def attention_xformers(q, k, v, heads, mask=None):
    b, _, dim_head = q.shape
    dim_head //= heads
    if BROKEN_XFORMERS:
        if b * heads > 65535:
            return attention_pytorch(q, k, v, heads, mask)
    q, k, v = map(
        lambda t: t.unsqueeze(3)
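
For readers outside the ComfyUI codebase, here is a minimal, self-contained sketch of the same fallback pattern. The names attention_xformers_sketch and attention_pytorch_sketch are illustrative only, and torch.nn.functional.scaled_dot_product_attention stands in for ComfyUI's own attention_pytorch; this is not the repository's exact code.

import torch

BROKEN_XFORMERS = False
try:
    import xformers
    import xformers.ops
    # xformers 0.0.21-0.0.23 can fail with a CUDA error once batch * heads exceeds 65535.
    BROKEN_XFORMERS = xformers.__version__.startswith(("0.0.21", "0.0.22", "0.0.23"))
except ImportError:
    pass

def attention_pytorch_sketch(q, k, v, heads, mask=None):
    # PyTorch fallback: reshape (batch, seq, heads * dim_head) -> (batch, heads, seq, dim_head)
    # and use the built-in scaled dot-product attention kernel.
    b, _, inner = q.shape
    dim_head = inner // heads
    q, k, v = (t.reshape(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

def attention_xformers_sketch(q, k, v, heads, mask=None):
    b, _, inner = q.shape
    dim_head = inner // heads
    # Route oversized workloads around the xformers bug.
    if BROKEN_XFORMERS and b * heads > 65535:
        return attention_pytorch_sketch(q, k, v, heads, mask)
    # xformers expects (batch, seq, heads, dim_head).
    q, k, v = (t.reshape(b, -1, heads, dim_head) for t in (q, k, v))
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
    return out.reshape(b, -1, heads * dim_head)

Gating the fallback on b * heads > 65535 at call time, rather than disabling xformers entirely for the affected versions, keeps the memory-efficient kernel available for the smaller workloads that do not hit the CUDA error.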