From 2fd9c1308a2c868459e5c9b70ad43b7085974d81 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 22 Nov 2024 02:10:09 -0500
Subject: [PATCH] Fix mask issue in some attention functions.

---
 comfy/ldm/modules/attention.py               | 5 ++++-
 comfy/ldm/modules/sub_quadratic_attention.py | 2 ++
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 7b4ee215..885b2401 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -299,7 +299,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                     if len(mask.shape) == 2:
                         s1 += mask[i:end]
                     else:
-                        s1 += mask[:, i:end]
+                        if mask.shape[1] == 1:
+                            s1 += mask
+                        else:
+                            s1 += mask[:, i:end]
 
                 s2 = s1.softmax(dim=-1).to(v.dtype)
                 del s1
diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py
index 1bc4138c..47b8b151 100644
--- a/comfy/ldm/modules/sub_quadratic_attention.py
+++ b/comfy/ldm/modules/sub_quadratic_attention.py
@@ -234,6 +234,8 @@ def efficient_dot_product_attention(
     def get_mask_chunk(chunk_idx: int) -> Tensor:
         if mask is None:
             return None
+        if mask.shape[1] == 1:
+            return mask
         chunk = min(query_chunk_size, q_tokens)
         return mask[:,chunk_idx:chunk_idx + chunk]
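
Notes:

Both hunks handle attention masks whose query dimension is a singleton
(mask.shape[1] == 1). Such a mask is meant to broadcast across every query
token, so slicing it per query chunk with mask[:, i:end] appears to work for
the first chunk (slicing a size-1 dimension from index 0 keeps it at size 1,
which still broadcasts) and then yields an empty slice for every later chunk,
which fails to broadcast against the score tensor. Returning or adding the
whole mask instead lets broadcasting apply it to each chunk correctly, which
is what the get_mask_chunk change does as well.

Below is a minimal standalone sketch of the failure mode and the fix,
assuming PyTorch; the function chunked_masked_scores and all shapes here are
illustrative only and are not part of the ComfyUI codebase:

    import torch

    def chunked_masked_scores(q, k, mask=None, chunk=64):
        # q: (batch, q_tokens, dim), k: (batch, k_tokens, dim)
        # mask: None, (batch, q_tokens, k_tokens), or (batch, 1, k_tokens)
        out = []
        for i in range(0, q.shape[1], chunk):
            end = i + chunk
            # scores for this chunk of queries: (batch, chunk, k_tokens)
            s1 = q[:, i:end] @ k.transpose(-2, -1)
            if mask is not None:
                if mask.shape[1] == 1:
                    s1 += mask            # broadcast over the query dim
                else:
                    s1 += mask[:, i:end]  # slice to match this query chunk
            out.append(s1.softmax(dim=-1))
        return torch.cat(out, dim=1)

    # A (batch, 1, k_tokens) padding mask broadcasts to every chunk; without
    # the shape[1] == 1 branch, mask[:, i:end] would be empty (shape
    # (2, 0, 96)) for i >= 1 and the addition would raise a RuntimeError.
    q = torch.randn(2, 128, 32)
    k = torch.randn(2, 96, 32)
    mask = torch.zeros(2, 1, 96)
    mask[:, :, 80:] = float("-inf")  # hide padded keys from every query
    probs = chunked_masked_scores(q, k, mask)
    assert probs.shape == (2, 128, 96)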