diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 59683f645..692952f32 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -489,6 +489,8 @@ if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
     if "--use-pytorch-cross-attention" in sys.argv:
         print("Using pytorch cross attention")
         torch.backends.cuda.enable_math_sdp(False)
+        torch.backends.cuda.enable_flash_sdp(True)
+        torch.backends.cuda.enable_mem_efficient_sdp(True)
         CrossAttention = CrossAttentionPytorch
     else:
         print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
@@ -497,6 +499,7 @@ else:
     print("Using xformers cross attention")
     CrossAttention = MemoryEfficientCrossAttention
 
+
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                  disable_self_attn=False):
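
For context, a minimal sketch (not part of the patch, assuming PyTorch >= 2.0 with a CUDA device) of what these `torch.backends.cuda` toggles control: they restrict which kernels `torch.nn.functional.scaled_dot_product_attention` may dispatch to. Tensor shapes below are arbitrary illustration.

```python
import torch
import torch.nn.functional as F

# Mirror the patched configuration: disallow the naive math kernel,
# allow the FlashAttention and memory-efficient kernels.
torch.backends.cuda.enable_math_sdp(False)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

# Note: these flags only govern the CUDA dispatch path.
device = "cuda"
# Arbitrary illustrative shapes: (batch, heads, sequence, head_dim).
q = torch.randn(1, 8, 64, 40, device=device, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# With the math kernel disabled, PyTorch picks the flash or mem-efficient
# kernel when the inputs qualify; if neither applies, the call raises a
# RuntimeError instead of silently falling back to the slow math path.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 64, 40])
```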