mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-16 06:27:15 +00:00
Tolerate missing @torch.library.custom_op
(#7234)
This can happen on Pytorch versions older than 2.4.
This commit is contained in:
parent
7aceb9f91c
commit
9c98c6358b
@ -503,6 +503,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||||
@ -513,6 +514,12 @@ def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|||||||
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
|
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
|
||||||
# Output shape is the same as q
|
# Output shape is the same as q
|
||||||
return q.new_empty(q.shape)
|
return q.new_empty(q.shape)
|
||||||
|
except AttributeError as error:
|
||||||
|
FLASH_ATTN_ERROR = error
|
||||||
|
|
||||||
|
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
|
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||||
|
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
||||||
|
|
||||||
|
|
||||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
|
Loading…
Reference in New Issue
Block a user