mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-16 06:27:15 +00:00
Add --use-flash-attention flag. (#7223)
* Add --use-flash-attention flag. This is useful on AMD systems, as FA builds are still 10% faster than Pytorch cross-attention.
This commit is contained in:
parent
35504e2f93
commit
7aceb9f91c
@ -106,6 +106,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
|
|||||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||||
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
||||||
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
|
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
|
||||||
|
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
|
||||||
|
|
||||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||||
|
|
||||||
|
@ -24,6 +24,13 @@ if model_management.sage_attention_enabled():
|
|||||||
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
|
if model_management.flash_attention_enabled():
|
||||||
|
try:
|
||||||
|
from flash_attn import flash_attn_func
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
@ -496,6 +503,56 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||||
|
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
|
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||||
|
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
|
||||||
|
|
||||||
|
|
||||||
|
@flash_attn_wrapper.register_fake
|
||||||
|
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
|
||||||
|
# Output shape is the same as q
|
||||||
|
return q.new_empty(q.shape)
|
||||||
|
|
||||||
|
|
||||||
|
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
|
if skip_reshape:
|
||||||
|
b, _, _, dim_head = q.shape
|
||||||
|
else:
|
||||||
|
b, _, dim_head = q.shape
|
||||||
|
dim_head //= heads
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
# add a batch dimension if there isn't already one
|
||||||
|
if mask.ndim == 2:
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
# add a heads dimension if there isn't already one
|
||||||
|
if mask.ndim == 3:
|
||||||
|
mask = mask.unsqueeze(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
assert mask is None
|
||||||
|
out = flash_attn_wrapper(
|
||||||
|
q.transpose(1, 2),
|
||||||
|
k.transpose(1, 2),
|
||||||
|
v.transpose(1, 2),
|
||||||
|
dropout_p=0.0,
|
||||||
|
causal=False,
|
||||||
|
).transpose(1, 2)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
|
||||||
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||||
|
if not skip_output_reshape:
|
||||||
|
out = (
|
||||||
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
optimized_attention = attention_basic
|
optimized_attention = attention_basic
|
||||||
|
|
||||||
if model_management.sage_attention_enabled():
|
if model_management.sage_attention_enabled():
|
||||||
@ -504,6 +561,9 @@ if model_management.sage_attention_enabled():
|
|||||||
elif model_management.xformers_enabled():
|
elif model_management.xformers_enabled():
|
||||||
logging.info("Using xformers attention")
|
logging.info("Using xformers attention")
|
||||||
optimized_attention = attention_xformers
|
optimized_attention = attention_xformers
|
||||||
|
elif model_management.flash_attention_enabled():
|
||||||
|
logging.info("Using Flash Attention")
|
||||||
|
optimized_attention = attention_flash
|
||||||
elif model_management.pytorch_attention_enabled():
|
elif model_management.pytorch_attention_enabled():
|
||||||
logging.info("Using pytorch attention")
|
logging.info("Using pytorch attention")
|
||||||
optimized_attention = attention_pytorch
|
optimized_attention = attention_pytorch
|
||||||
|
@ -930,6 +930,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
|||||||
def sage_attention_enabled():
|
def sage_attention_enabled():
|
||||||
return args.use_sage_attention
|
return args.use_sage_attention
|
||||||
|
|
||||||
|
def flash_attention_enabled():
|
||||||
|
return args.use_flash_attention
|
||||||
|
|
||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
global cpu_state
|
global cpu_state
|
||||||
|
Loading…
Reference in New Issue
Block a user