From 0c2c9fbdfa53c2ad3b7658a7f2300da831830388 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 6 Jan 2024 13:16:48 -0500 Subject: [PATCH] Support attention mask in split attention. --- comfy/ldm/modules/attention.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 14d41a8c..a18a6929 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -239,6 +239,12 @@ def attention_split(q, k, v, heads, mask=None): else: s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale + if mask is not None: + if len(mask.shape) == 2: + s1 += mask[i:end] + else: + s1 += mask[:, i:end] + s2 = s1.softmax(dim=-1).to(v.dtype) del s1 first_op_done = True