From 19ee5d9d8bc3cc4754ee5c8f63af4b3e37714a5b Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 16 Dec 2024 18:22:50 -0500
Subject: [PATCH] Don't expand mask when not necessary.

Expanding seems to slow down inference.
---
 comfy/ldm/modules/attention.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 54b42a4b..e60d1ab2 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -423,8 +423,6 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
         # add a heads dimension if there isn't already one
         if mask.ndim == 3:
             mask = mask.unsqueeze(1)
-        mask = mask.expand(b, heads, -1, -1)
-
     if SDP_BATCH_LIMIT >= b:
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
@@ -434,11 +432,16 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     else:
         out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
         for i in range(0, b, SDP_BATCH_LIMIT):
+            m = mask
+            if mask is not None:
+                if mask.shape[0] > 1:
+                    m = mask[i : i + SDP_BATCH_LIMIT]
+
             out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
                 q[i : i + SDP_BATCH_LIMIT],
                 k[i : i + SDP_BATCH_LIMIT],
                 v[i : i + SDP_BATCH_LIMIT],
-                attn_mask=None if mask is None else mask[i : i + SDP_BATCH_LIMIT],
+                attn_mask=m,
                 dropout_p=0.0, is_causal=False
             ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
     return out
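
Note (not part of the patch): PyTorch's scaled_dot_product_attention only requires attn_mask to be broadcastable to the attention-weight shape, so the (b, 1, L, S) mask produced by mask.unsqueeze(1) can be passed directly without the removed expand. Below is a minimal standalone sketch, with made-up tensor sizes, illustrating that the broadcast and expanded masks give the same result.

# Sketch only: tensor sizes here are illustrative, not taken from the patch.
import torch

b, heads, L, S, dim_head = 2, 4, 8, 8, 16
q = torch.randn(b, heads, L, dim_head)
k = torch.randn(b, heads, S, dim_head)
v = torch.randn(b, heads, S, dim_head)

# Boolean mask with a singleton heads dimension, as produced by mask.unsqueeze(1).
mask = torch.rand(b, 1, L, S) > 0.5
mask[..., 0] = True  # keep at least one key visible per query to avoid NaN rows

out_broadcast = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_expanded = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=mask.expand(b, heads, -1, -1))

assert torch.allclose(out_broadcast, out_expanded, atol=1e-6)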