fix attention OOM in xformers

Raphael Walker 2024-12-05 11:10:07 +01:00
parent 9a616b81c1
commit 927418464e


@@ -376,7 +376,16 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     if mask is not None:
         pad = 8 - mask.shape[-1] % 8
-        mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
+        # we assume the mask is either (B, M, N) or (M, N)
+        # this way we avoid allocating a huge mask when it has no batch dimension
+        mask_batch_size = mask.shape[0] if len(mask.shape) == 3 else 1
+        # if skip_reshape, then q, k, v have merged heads and batch size
+        if skip_reshape:
+            mask_out = torch.empty([mask_batch_size, q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
+        # otherwise, we have separate heads and batch size
+        else:
+            mask_out = torch.empty([mask_batch_size, q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
         mask_out[..., :mask.shape[-1]] = mask
         mask = mask_out[..., :mask.shape[-1]]
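
For context, here is a minimal standalone sketch of the allocation change in the non-skip_reshape branch. It is not the repository's code: the batch, head, and sequence sizes are made up for illustration, and q is assumed to be laid out as (B, L, heads, dim_head) as in that branch.

import torch

# Illustrative sizes only, not taken from the commit.
B, H, L, N, D = 2, 8, 1024, 1024, 64

q = torch.randn(B, L, H, D)   # layout assumed when skip_reshape is False
mask = torch.randn(L, N)      # a 2D (M, N) mask shared across batch and heads

pad = 8 - mask.shape[-1] % 8  # xformers wants the last mask dim padded to a multiple of 8

# Previous behaviour: one padded mask slot per batch element,
# regardless of whether the incoming mask had a batch dimension.
old_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype)

# New behaviour: a 2D mask keeps a singleton batch dimension and is
# broadcast when copied in, so the buffer shrinks by a factor of B.
mask_batch_size = mask.shape[0] if len(mask.shape) == 3 else 1
new_out = torch.empty([mask_batch_size, q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype)
new_out[..., :mask.shape[-1]] = mask  # (M, N) broadcasts across batch and heads

def mib(t):
    return t.numel() * t.element_size() / 2**20

print(f"old allocation: {mib(old_out):.1f} MiB, new allocation: {mib(new_out):.1f} MiB")

With a 2D mask the padded buffer no longer scales with the batch size, which appears to be the source of the out-of-memory condition the commit title refers to; the skip_reshape branch additionally drops the separate heads dimension because q, k, v already have batch and heads merged there.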