fix attention OOM in xformers
commit 927418464e (parent 9a616b81c1)
@@ -376,7 +376,16 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     if mask is not None:
         pad = 8 - mask.shape[-1] % 8
-        mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
+        # we assume the mask is either (B, M, N) or (M, N)
+        # this way we avoid allocating a huge
+        mask_batch_size = mask.shape[0] if len(mask.shape) == 3 else 1
+        # if skip_reshape, then q, k, v have merged heads and batch size
+        if skip_reshape:
+            mask_out = torch.empty([mask_batch_size, q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
+        # otherwise, we have separate heads and batch size
+        else:
+            mask_out = torch.empty([mask_batch_size, q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
         mask_out[..., :mask.shape[-1]] = mask
         mask = mask_out[..., :mask.shape[-1]]
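For reference, a minimal, self-contained sketch of the allocation strategy this diff switches to. The helper name pad_mask_for_xformers and the example shapes are illustrative only, not part of ComfyUI; the assumed q layouts ((batch * heads, M, dim_head) when skip_reshape is set, (batch, M, heads, dim_head) otherwise) follow the comments in the diff itself.

    import torch

    def pad_mask_for_xformers(mask, q, skip_reshape=False):
        # xformers wants the last dim of the attention bias padded to a
        # multiple of 8, so allocate a padded buffer and copy the mask in.
        pad = 8 - mask.shape[-1] % 8
        # Size the buffer from the mask's own batch size (1 for an (M, N) mask)
        # instead of q.shape[0]; with merged heads q.shape[0] is batch * heads,
        # which is what made the old q-shaped allocation blow up memory.
        mask_batch_size = mask.shape[0] if len(mask.shape) == 3 else 1
        if skip_reshape:
            # q assumed (batch * heads, M, dim_head): buffer is (mask_B, M, N + pad)
            mask_out = torch.empty([mask_batch_size, q.shape[1], mask.shape[-1] + pad],
                                   dtype=q.dtype, device=q.device)
        else:
            # q assumed (batch, M, heads, dim_head): buffer is (mask_B, heads, M, N + pad)
            mask_out = torch.empty([mask_batch_size, q.shape[2], q.shape[1], mask.shape[-1] + pad],
                                   dtype=q.dtype, device=q.device)
        mask_out[..., :mask.shape[-1]] = mask
        # Slice back to the original N; the padded storage keeps the
        # stride alignment that xformers expects.
        return mask_out[..., :mask.shape[-1]]

    # Illustrative shapes: batch 2, 8 heads, 4096 queries, 77 keys.
    q = torch.randn(2, 4096, 8, 64)          # (batch, M, heads, dim_head)
    mask = torch.zeros(4096, 77)             # (M, N) mask shared across the batch
    padded = pad_mask_for_xformers(mask, q)  # allocates (1, 8, 4096, 80), returns a (1, 8, 4096, 77) view
    print(padded.shape)                      # torch.Size([1, 8, 4096, 77])

The point of the change is that the padded buffer is sized from the mask's own batch dimension rather than from q.shape[0], so a shared (M, N) or single-batch mask no longer forces an allocation proportional to batch * heads.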