Mirror of https://github.com/comfyanonymous/ComfyUI.git
allow passing attention mask in flux attention
commit 8954dafb44
parent 927418464e
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -4,11 +4,11 @@ from torch import Tensor
 
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management
 
 
-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
     q, k = apply_rope(q, k, pe)
 
     heads = q.shape[1]
-    x = optimized_attention(q, k, v, heads, skip_reshape=True)
+    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
     return x
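With this change, callers can mask out padded conditioning tokens before they reach the fused kernel; the mask simply rides along with q/k/v into `optimized_attention`. The snippet below is a minimal standalone sketch of that idea, not code from the repository: it stands in for `optimized_attention` with `torch.nn.functional.scaled_dot_product_attention`, and the `build_padding_mask` helper, shapes, and sequence lengths are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def build_padding_mask(valid_lengths: torch.Tensor, num_keys: int) -> torch.Tensor:
    # Hypothetical helper: True where a key token is real, False where it is padding.
    idx = torch.arange(num_keys)
    return idx[None, :] < valid_lengths[:, None]   # (B, N) boolean

# Toy shapes: batch 2, 8 heads, 16 queries/keys, head dim 64. q/k/v are already
# split into heads, matching the skip_reshape=True layout used by the flux wrapper.
B, H, M, N, D = 2, 8, 16, 16, 64
q = torch.randn(B, H, M, D)
k = torch.randn(B, H, N, D)
v = torch.randn(B, H, N, D)

# Broadcastable (B, 1, 1, N) mask: the second sample only has 10 valid key tokens.
mask = build_padding_mask(torch.tensor([16, 10]), N)[:, None, None, :]

# Stand-in for: optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(x.shape)  # torch.Size([2, 8, 16, 64])
```

Here a boolean mask means True = attend; the various `optimized_attention` backends may instead expect an additive bias, so treat the mask convention above as an assumption of the sketch.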
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -377,7 +377,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     if mask is not None:
         pad = 8 - mask.shape[-1] % 8
         # we assume the mask is either (B, M, N) or (M, N)
-        # this way we avoid allocating a huge
+        # this way we avoid allocating a huge matrix
         mask_batch_size = mask.shape[0] if len(mask.shape) == 3 else 1
         # if skip_reshape, then q, k, v have merged heads and batch size
         if skip_reshape:
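The `pad` computation in this hunk exists because xformers' memory-efficient attention wants the bias tensor's last dimension aligned to a multiple of 8 in storage; writing the mask into a slightly wider buffer and then viewing it back avoids materializing a full per-head bias matrix, which is what the "avoid allocating a huge matrix" comment refers to. Below is a standalone sketch of that alignment trick under those assumptions; `pad_mask_for_xformers` is an illustrative name, not a function from the repository.

```python
import torch

def pad_mask_for_xformers(mask: torch.Tensor) -> torch.Tensor:
    # Illustrative helper (not from the repo): write the mask into a buffer whose
    # last dimension is padded up to a multiple of 8, then slice it back to the
    # original width so the storage stride is 8-aligned without allocating a
    # larger per-head bias matrix.
    pad = 8 - mask.shape[-1] % 8
    padded = torch.zeros(*mask.shape[:-1], mask.shape[-1] + pad,
                         dtype=mask.dtype, device=mask.device)
    padded[..., :mask.shape[-1]] = mask
    return padded[..., :mask.shape[-1]]

bias = torch.zeros(2, 77, 77)   # (B, M, N) additive bias with N = 77
aligned = pad_mask_for_xformers(bias)
print(aligned.shape)            # torch.Size([2, 77, 77])
print(aligned.stride())         # (6160, 80, 1): last dim padded to 80 in storage
```

Note that when the width is already a multiple of 8, this formula pads by a full extra 8 columns, which is harmless since the result is sliced back to the original width.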