fix xformers masks

Raphael Walker 2024-12-05 17:40:41 +01:00
parent 74f8eaf7f0
commit 7739f5f8d9


@ -376,15 +376,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     if mask is not None:
         pad = 8 - mask.shape[-1] % 8
-        # we assume the mask is either (B, M, N) or (M, N)
-        # this way we avoid allocating a huge matrix
-        mask_batch_size = mask.shape[0] if len(mask.shape) == 3 else 1
         # if skip_reshape, then q, k, v have merged heads and batch size
         if skip_reshape:
-            mask_out = torch.empty([mask_batch_size, q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
+            mask_out = torch.empty([q.shape[0], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
         # otherwise, we have separate heads and batch size
         else:
-            mask_out = torch.empty([mask_batch_size, q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
+            mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
         mask_out[..., :mask.shape[-1]] = mask
         mask = mask_out[..., :mask.shape[-1]]
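
The change itself swaps mask_batch_size (derived from the mask's own leading dimension) for q.shape[0], so the allocated bias is sized from q's batch dimension rather than from the incoming mask. The pad-and-slice trick the hunk keeps (allocate the bias with its last dimension rounded up to a multiple of 8, then slice back to the logical width so the storage underneath stays aligned) can be shown standalone. The following is a minimal sketch, not the surrounding ComfyUI function; the shapes B, H, M, N and the use of unsqueeze to broadcast across heads are illustrative assumptions.

import torch

# Illustrative shapes (assumptions, not values from the commit):
# B = batch, H = heads, M = query length, N = key length.
B, H, M, N = 2, 8, 77, 80

mask = torch.zeros(B, M, N)  # additive attention bias, (B, M, N)

# Allocate the bias with the last dimension padded to a multiple of 8,
# copy the real values in, then slice back to the logical width. The
# slice keeps the padded, aligned storage underneath, which is the
# point of the padding in the function above.
pad = 8 - N % 8
mask_out = torch.empty(B, H, M, N + pad, dtype=mask.dtype, device=mask.device)
mask_out[..., :N] = mask.unsqueeze(1)  # broadcast the bias across heads
mask = mask_out[..., :N]               # logical shape (B, H, M, N)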