Support attention mask in split attention.

comfyanonymous 2024-01-06 13:16:48 -05:00
parent 3ad0191bfb
commit 0c2c9fbdfa


@@ -239,6 +239,12 @@ def attention_split(q, k, v, heads, mask=None):
                 else:
                     s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
 
+                if mask is not None:
+                    if len(mask.shape) == 2:
+                        s1 += mask[i:end]
+                    else:
+                        s1 += mask[:, i:end]
+
                 s2 = s1.softmax(dim=-1).to(v.dtype)
                 del s1
                 first_op_done = True
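
The change adds the mask before the softmax, slicing it along the query dimension so it lines up with the current chunk: a 2-D mask of shape (seq_q, seq_k) is indexed as mask[i:end], while a batched 3-D mask of shape (batch, seq_q, seq_k) is indexed as mask[:, i:end]. Below is a minimal standalone sketch of that per-chunk masking logic, not ComfyUI's full attention_split (the memory-based slice sizing, OOM retry, and fp32 autocast paths are omitted); split_attention_with_mask and chunk_size are hypothetical names introduced here for illustration.

import torch
from torch import einsum

def split_attention_with_mask(q, k, v, chunk_size, mask=None):
    # q, k, v: (batch * heads, seq, dim_head). mask is additive:
    # use -inf (or a large negative value) at positions to drop.
    scale = q.shape[-1] ** -0.5
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2],
                      device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], chunk_size):
        end = i + chunk_size
        # Attention scores for this chunk of queries against all keys.
        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        if mask is not None:
            if len(mask.shape) == 2:
                # (seq_q, seq_k) mask: slice the query rows for this chunk;
                # it broadcasts over the batch * heads dimension.
                s1 += mask[i:end]
            else:
                # (batch, seq_q, seq_k) mask: slice along the query dim.
                s1 += mask[:, i:end]
        s2 = s1.softmax(dim=-1).to(v.dtype)
        del s1
        out[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
        del s2
    return out

# Usage: a causal mask shared across all batch * heads entries.
q = k = v = torch.randn(2, 8, 16)
causal = torch.full((8, 8), float('-inf')).triu(1)
out = split_attention_with_mask(q, k, v, chunk_size=4, mask=causal)
print(out.shape)  # torch.Size([2, 8, 16])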