Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-07-10 17:37:17 +08:00
allow passing attention mask in flux attention

commit 8954dafb44
parent 927418464e
@@ -4,11 +4,11 @@ from torch import Tensor
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management

-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
     q, k = apply_rope(q, k, pe)

     heads = q.shape[1]
-    x = optimized_attention(q, k, v, heads, skip_reshape=True)
+    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
     return x

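For intuition about what the new parameter does, here is a self-contained sketch of a mask flowing into the attention kernel. It uses torch.nn.functional.scaled_dot_product_attention as a stand-in for ComfyUI's optimized_attention, which dispatches to whichever backend is active; the shapes and the additive-bias convention below are assumptions for illustration, not code from the repository.

    import torch
    import torch.nn.functional as F

    # Hypothetical flux-like shapes: batch 1, 24 heads, 256 tokens, 64-dim heads.
    B, H, L, D = 1, 24, 256, 64
    q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

    # Additive bias: 0.0 keeps a key visible, -inf removes it from the softmax.
    mask = torch.zeros(L, L)
    mask[:, -16:] = float("-inf")  # e.g. hide 16 trailing padding tokens

    # Before this commit flux's attention() had no way to forward a mask;
    # now it hands `mask` straight through to the optimized attention call.
    out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    out_plain = F.scaled_dot_product_attention(q, k, v)
    print(out_masked.shape, torch.allclose(out_masked, out_plain))  # differs when masked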
@@ -377,7 +377,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     if mask is not None:
         pad = 8 - mask.shape[-1] % 8
         # we assume the mask is either (B, M, N) or (M, N)
-        # this way we avoid allocating a huge
+        # this way we avoid allocating a huge matrix
         mask_batch_size = mask.shape[0] if len(mask.shape) == 3 else 1
         # if skip_reshape, then q, k, v have merged heads and batch size
         if skip_reshape:
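The pad = 8 - mask.shape[-1] % 8 line in this hunk reflects xformers' alignment requirement: as I understand it, memory_efficient_attention wants the attention bias stored so its last dimension is padded to a multiple of 8 elements. Below is a self-contained sketch of that padding trick, with hypothetical shapes and a mask_out name of my own choosing:

    import torch

    # Hypothetical (B, M, N) additive mask; N = 77 is not a multiple of 8.
    B, M, N = 2, 77, 77
    mask = torch.zeros(B, M, N)

    pad = 8 - mask.shape[-1] % 8  # 3 here; note this pads a full 8 when N is already aligned
    # Allocate one buffer that is already padded, copy the mask into it, then
    # slice back to the original width: the slice is a view, so the data now
    # lives in aligned storage without a second large allocation.
    mask_out = torch.empty(B, M, N + pad, dtype=mask.dtype, device=mask.device)
    mask_out[..., :mask.shape[-1]] = mask
    mask = mask_out[..., :mask.shape[-1]]

    print(mask.shape, mask.stride())  # torch.Size([2, 77, 77]) (6160, 80, 1)

The mask_batch_size line normalizes (M, N) and (B, M, N) masks to a common batch size, presumably so the mask can later be expanded to the merged batch*heads layout that skip_reshape implies.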