From 9c98c6358be2c7896de1547490bc87c9ad7a1ecb Mon Sep 17 00:00:00 2001
From: FeepingCreature <540727+FeepingCreature@users.noreply.github.com>
Date: Fri, 14 Mar 2025 14:51:26 +0100
Subject: [PATCH] Tolerate missing `@torch.library.custom_op` (#7234)

This can happen on PyTorch versions older than 2.4.
---
 comfy/ldm/modules/attention.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 3e5089a6..7908d131 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -503,16 +503,23 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     return out
 
 
-@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
-def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                       dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
-    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+try:
+    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
 
 
-@flash_attn_wrapper.register_fake
-def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
-    # Output shape is the same as q
-    return q.new_empty(q.shape)
+    @flash_attn_wrapper.register_fake
+    def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
+        # Output shape is the same as q
+        return q.new_empty(q.shape)
+except AttributeError as error:
+    FLASH_ATTN_ERROR = error
+
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
 
 
 def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
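
Editor's note, not part of the committed change: the patch keeps the module importable by catching the AttributeError raised when torch.library.custom_op is absent (PyTorch older than 2.4). The sketch below shows an equivalent, explicit capability check. It assumes flash_attn_func is importable from the flash_attn package, as the patched module already requires, and it mirrors the patch's behavior of only failing when the wrapper is actually called.

import torch
from flash_attn import flash_attn_func  # same dependency the patched module relies on

if hasattr(torch.library, "custom_op"):
    # PyTorch >= 2.4: register the flash-attention call as a custom op so the
    # compiler/fake-tensor machinery can reason about it.
    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)

    @flash_attn_wrapper.register_fake
    def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
        # The fake (meta) kernel only needs to describe the output shape,
        # which matches q.
        return q.new_empty(q.shape)
else:
    # Older PyTorch: keep the symbol defined so the module imports cleanly,
    # but fail with a clear message if flash attention is actually requested.
    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
        raise RuntimeError("torch.library.custom_op is unavailable; "
                           "flash attention here requires PyTorch >= 2.4")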