From 798c90e1c094d86bff6c1eca51c0ed4dedea1871 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 5 Mar 2023 14:14:54 -0500
Subject: [PATCH] Fix pytorch 2.0 cross attention not working.

---
 comfy/ldm/modules/attention.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 59683f645..692952f32 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -489,6 +489,8 @@ if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
     if "--use-pytorch-cross-attention" in sys.argv:
         print("Using pytorch cross attention")
         torch.backends.cuda.enable_math_sdp(False)
+        torch.backends.cuda.enable_flash_sdp(True)
+        torch.backends.cuda.enable_mem_efficient_sdp(True)
         CrossAttention = CrossAttentionPytorch
     else:
         print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
@@ -497,6 +499,7 @@ else:
     print("Using xformers cross attention")
     CrossAttention = MemoryEfficientCrossAttention
 
+
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                  disable_self_attn=False):
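
Note: the sketch below is not part of the patch; it is a minimal illustration of what the three torch.backends.cuda SDP toggles used above control in PyTorch 2.0. The standalone script structure, tensor shapes, and dtype are assumptions chosen for the example, not code from ComfyUI.

# Illustrative sketch (assumes PyTorch >= 2.0 and a CUDA device; not ComfyUI code).
import torch
import torch.nn.functional as F

# Mirror the patch: forbid the slow math fallback and explicitly allow the fused
# kernels, so scaled_dot_product_attention must pick the FlashAttention or
# memory-efficient backend.
torch.backends.cuda.enable_math_sdp(False)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

# (batch, heads, seq_len, head_dim) -- sizes chosen only for the example.
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# With the math backend disabled, this call dispatches to one of the fused
# kernels; if no enabled backend supports the inputs, PyTorch raises an error
# instead of silently falling back to the math path.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 1024, 64])

In effect, the two added enable_*_sdp(True) calls make the intended kernel selection explicit rather than relying on the backends' default state, while the existing enable_math_sdp(False) line keeps the unfused math fallback disabled.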