From 7aceb9f91c1c2b860c1a65ac93a64b3bad575794 Mon Sep 17 00:00:00 2001
From: FeepingCreature <540727+FeepingCreature@users.noreply.github.com>
Date: Fri, 14 Mar 2025 08:22:41 +0100
Subject: [PATCH] Add --use-flash-attention flag. (#7223)

* Add --use-flash-attention flag.

This is useful on AMD systems, as FlashAttention builds are still 10% faster
than PyTorch cross-attention.
---
 comfy/cli_args.py              |  1 +
 comfy/ldm/modules/attention.py | 60 ++++++++++++++++++++++++++++++++++
 comfy/model_management.py      |  3 ++
 3 files changed, 64 insertions(+)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index a864205b..91c1fe70 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -106,6 +106,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")

 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 2758f950..3e5089a6 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -24,6 +24,13 @@ if model_management.sage_attention_enabled():
         logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
         exit(-1)

+if model_management.flash_attention_enabled():
+    try:
+        from flash_attn import flash_attn_func
+    except ModuleNotFoundError:
+        logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
+        exit(-1)
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
@@ -496,6 +503,56 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     return out


+@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
+def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                       dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+
+
+@flash_attn_wrapper.register_fake
+def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
+    # Output shape is the same as q
+    return q.new_empty(q.shape)
+
+
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+            (q, k, v),
+        )
+
+    if mask is not None:
+        # add a batch dimension if there isn't already one
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a heads dimension if there isn't already one
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+
+    try:
+        assert mask is None
+        out = flash_attn_wrapper(
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+            dropout_p=0.0,
+            causal=False,
+        ).transpose(1, 2)
+    except Exception as e:
+        logging.warning(f"Flash Attention failed, using default SDPA: {e}")
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+    if not skip_output_reshape:
+        out = (
+            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+        )
+    return out
+
+
 optimized_attention = attention_basic

 if model_management.sage_attention_enabled():
@@ -504,6 +561,9 @@ if model_management.sage_attention_enabled():
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
+elif model_management.flash_attention_enabled():
+    logging.info("Using Flash Attention")
+    optimized_attention = attention_flash
 elif model_management.pytorch_attention_enabled():
     logging.info("Using pytorch attention")
     optimized_attention = attention_pytorch
diff --git a/comfy/model_management.py b/comfy/model_management.py
index b6f4e2d1..2a9b022b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -930,6 +930,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
 def sage_attention_enabled():
     return args.use_sage_attention

+def flash_attention_enabled():
+    return args.use_flash_attention
+
 def xformers_enabled():
     global directml_enabled
     global cpu_state
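
A note on the flash_attn_wrapper pattern in the patch: flash_attn_func is an opaque extension kernel, so it is registered as a torch.library custom op together with a "fake" implementation. The fake function is what runs under FakeTensor tracing, which is how torch.compile learns the output shape and dtype without executing the real kernel. Below is a minimal runnable sketch of the same pattern; the op name mylib::toy_attn is hypothetical, SDPA stands in for the flash-attn kernel, and torch.library.custom_op assumes PyTorch >= 2.4.

import torch

# Hypothetical op name; "flash_attention::flash_attn" plays this role in the patch.
@torch.library.custom_op("mylib::toy_attn", mutates_args=())
def toy_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Stand-in for the opaque flash_attn_func kernel call.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

@toy_attn.register_fake
def _(q, k, v):
    # Runs during FakeTensor tracing: only shape/dtype/device matter.
    # Mirrors flash_attn_fake above: the output shape equals q's shape.
    return q.new_empty(q.shape)

compiled = torch.compile(lambda q, k, v: toy_attn(q, k, v))
q = k = v = torch.randn(1, 4, 16, 8)
print(compiled(q, k, v).shape)  # torch.Size([1, 4, 16, 8])

Without the register_fake part, torch.compile would try to trace into the extension call and fail, which is presumably why the patch wraps flash_attn_func rather than calling it directly.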
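
The transpose(1, 2) pairs around the flash_attn_wrapper call exist because ComfyUI's attention functions work on (batch, heads, seq, dim_head) tensors, while flash-attn expects (batch, seq, heads, dim_head). A shape-only sketch of the round trip, again with SDPA standing in for the kernel (flash-attn itself needs a CUDA device and fp16/bf16 inputs); the dimension values are arbitrary:

import torch

b, heads, n, dim_head = 2, 8, 64, 40

# Layout after the reshape at the top of attention_flash: (b, heads, n, dim_head).
q = torch.randn(b, heads, n, dim_head)
k = torch.randn(b, heads, n, dim_head)
v = torch.randn(b, heads, n, dim_head)

# attention_flash transposes to (b, n, heads, dim_head) for the kernel and
# transposes the result back; SDPA replaces flash_attn_wrapper in this sketch.
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

# Final reshape back to ComfyUI's flat (b, seq, heads * dim_head) layout.
out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
assert out.shape == (b, n, heads * dim_head)

The try/assert fallback in attention_flash appears deliberate: flash_attn_func exposes no general attention-mask argument (only causal and windowed variants), so any call with a mask trips the assert inside the try block and falls through to masked SDPA, as does any other kernel failure.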