allow passing attention mask in flux attention

Raphael Walker 2024-12-05 11:12:54 +01:00
parent 927418464e
commit 8954dafb44
2 changed files with 3 additions and 3 deletions


@@ -4,11 +4,11 @@ from torch import Tensor
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management
-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
     q, k = apply_rope(q, k, pe)
     heads = q.shape[1]
-    x = optimized_attention(q, k, v, heads, skip_reshape=True)
+    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
     return x
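
With this change the flux attention helper simply forwards the new optional mask argument to optimized_attention. The following is a minimal, self-contained sketch of how a caller could use it; the shapes, the boolean mask layout, and the use of torch's scaled_dot_product_attention as a stand-in for optimized_attention (with rope omitted) are illustrative assumptions, not the actual ComfyUI code.

import torch
import torch.nn.functional as F

def attention_sketch(q, k, v, pe=None, mask=None):
    # Stand-in for the patched helper: apply_rope is omitted for brevity;
    # the point is that an optional mask is now forwarded to the attention kernel.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

B, H, L, D = 1, 24, 256, 128          # illustrative sizes
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

# Boolean mask broadcast to (B, H, L, L): True = attend, False = ignore.
# Here the last 56 key tokens (e.g. text padding) are masked out for every query.
mask = torch.ones(L, L, dtype=torch.bool)
mask[:, 200:] = False

out = attention_sketch(q, k, v, mask=mask)
print(out.shape)  # torch.Size([1, 24, 256, 128])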


@@ -377,7 +377,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     if mask is not None:
         pad = 8 - mask.shape[-1] % 8
         # we assume the mask is either (B, M, N) or (M, N)
-        # this way we avoid allocating a huge
+        # this way we avoid allocating a huge matrix
         mask_batch_size = mask.shape[0] if len(mask.shape) == 3 else 1
         # if skip_reshape, then q, k, v have merged heads and batch size
         if skip_reshape:
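
The xformers hunk only completes a comment in the mask-padding path: by assuming the mask is (B, M, N) or (M, N), the code can pad the attention bias to xformers' 8-element alignment without allocating a second full-size matrix. A rough sketch of that padding trick follows; the helper name, dtype handling and shapes are assumptions for illustration, not the exact ComfyUI implementation.

import torch

def pad_attn_bias(mask, dtype=torch.float32, device="cpu"):
    # xformers' memory_efficient_attention wants the attention bias aligned to a
    # multiple of 8 in its last dimension. Allocate one padded buffer, copy the
    # mask in, and hand back a view of the original logical size, so no extra
    # full-size (B, M, N) matrix is materialised.
    pad = 8 - mask.shape[-1] % 8
    padded_shape = list(mask.shape)
    padded_shape[-1] += pad
    mask_out = torch.empty(padded_shape, dtype=dtype, device=device)
    mask_out[..., :mask.shape[-1]] = mask
    return mask_out[..., :mask.shape[-1]]

m = torch.zeros(77, 4098)     # e.g. an (M, N) text-to-image mask
aligned = pad_attn_bias(m)
print(aligned.shape)          # torch.Size([77, 4098]) -- same logical size
print(aligned.stride())       # (4104, 1) -- storage padded to a multiple of 8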