mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-16 08:33:29 +00:00
Support attention mask in split attention.
This commit is contained in:
parent
3ad0191bfb
commit
0c2c9fbdfa
@ -239,6 +239,12 @@ def attention_split(q, k, v, heads, mask=None):
|
|||||||
else:
|
else:
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
if len(mask.shape) == 2:
|
||||||
|
s1 += mask[i:end]
|
||||||
|
else:
|
||||||
|
s1 += mask[:, i:end]
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||||
del s1
|
del s1
|
||||||
first_op_done = True
|
first_op_done = True
|
||||||
|
Loading…
Reference in New Issue
Block a user