Fix some OOM issues with split attention.

comfyanonymous 2023-10-30 13:14:11 -04:00
parent 41b07ff8d7
commit 125b03eead

@@ -229,7 +229,7 @@ def attention_split(q, k, v, heads, mask=None):
     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
-    modifier = 3 if element_size == 2 else 2.5
+    modifier = 3
     mem_required = tensor_size * modifier
     steps = 1
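
The first hunk drops the dtype-dependent estimate (a modifier of 2.5 for 4-byte elements) and always uses the more conservative factor of 3, so slicing kicks in sooner for fp32 attention, where 2.5 was evidently too optimistic and could still OOM. For context, here is a minimal sketch of how an estimate like mem_required typically drives the slice count in this function; estimate_steps and the mem_free_total parameter are illustrative names, not the file's actual helpers:

import math
import torch

# Minimal sketch (not the file's exact code) of how mem_required drives the
# number of slices: double the step count until the estimated peak for one
# query slice fits in the reported free memory.
def estimate_steps(q: torch.Tensor, k: torch.Tensor, mem_free_total: int) -> int:
    element_size = q.element_size()
    # Bytes for the full attention score tensor of shape
    # (batch * heads, q_len, k_len), before any slicing.
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
    mem_required = tensor_size * 3  # flat safety modifier after this commit
    steps = 1
    if mem_required > mem_free_total:
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))
    return steps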
@@ -257,10 +257,10 @@ def attention_split(q, k, v, heads, mask=None):
                     s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
                 else:
                     s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
-                first_op_done = True
                 s2 = s1.softmax(dim=-1).to(v.dtype)
                 del s1
+                first_op_done = True
                 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                 del s2
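
The second hunk moves first_op_done = True below the softmax and del s1. The flag tells the surrounding OOM handler whether any slice has fully completed; with the old placement, an out-of-memory error raised by the softmax itself arrived with the flag already set, so the handler re-raised instead of retrying with more steps. A rough, self-contained sketch of that retry structure, with the exception type and the slicing simplified relative to the real function:

import torch

# Rough sketch of the retry loop around the hunk above; split_attention_scores
# and max_steps are illustrative names, not the file's actual code.
def split_attention_scores(q, k, v, scale, max_steps=64):
    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    first_op_done = False
    steps = 1
    while True:
        try:
            slice_size = max(q.shape[1] // steps, 1)
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                s1 = torch.einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
                s2 = s1.softmax(dim=-1).to(v.dtype)
                del s1
                # Set the flag only after the softmax has also succeeded, so an
                # OOM raised by the softmax still reaches the retry branch below.
                first_op_done = True
                r1[:, i:end] = torch.einsum('b i j, b j d -> b i d', s2, v)
                del s2
            return r1
        except torch.cuda.OutOfMemoryError:
            if first_op_done or steps * 2 > max_steps:
                raise  # partial results already written, or cannot split further
            steps *= 2  # halve each query slice and start over

With the flag set only after del s1, an OOM during the softmax of the very first slice now doubles the step count and restarts instead of propagating.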