Tolerate missing @torch.library.custom_op (#7234)

This can happen on PyTorch versions older than 2.4.
FeepingCreature 2025-03-14 14:51:26 +01:00 committed by GitHub
parent 7aceb9f91c
commit 9c98c6358b


@@ -503,16 +503,23 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     return out
 
 
-@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
-def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                       dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
-    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+try:
+    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
 
 
-@flash_attn_wrapper.register_fake
-def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
-    # Output shape is the same as q
-    return q.new_empty(q.shape)
+    @flash_attn_wrapper.register_fake
+    def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
+        # Output shape is the same as q
+        return q.new_empty(q.shape)
+except AttributeError as error:
+    FLASH_ATTN_ERROR = error
+
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
 
 
 def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
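
For context: torch.library.custom_op was added in PyTorch 2.4, so on older builds the attribute lookup raises AttributeError, and the try/except above replaces the wrapper with a stub that only fails if it is actually called. Below is a minimal standalone sketch of the same pattern, separate from the ComfyUI code; the op name example::noop, the stub, and the variable names are illustrative assumptions, not part of the patched module.

    import torch

    try:
        # torch.library.custom_op exists only on PyTorch >= 2.4; on older
        # versions this attribute lookup raises AttributeError.
        @torch.library.custom_op("example::noop", mutates_args=())
        def noop(x: torch.Tensor) -> torch.Tensor:
            return x.clone()

        @noop.register_fake
        def _(x):
            # Fake (meta) implementation: only shape/dtype of the output matter.
            return x.new_empty(x.shape)
    except AttributeError as error:
        _NOOP_ERROR = error

        def noop(x: torch.Tensor) -> torch.Tensor:
            # Stub keeps the module importable; it fails only when used.
            raise RuntimeError(f"Could not define noop: {_NOOP_ERROR}")

The upside of this approach over an explicit version check is that the module still imports cleanly on old PyTorch, and users who never hit the flash-attention path never see an error.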