Mirror of https://github.com/comfyanonymous/ComfyUI.git
Fix pytorch 2.0 cross attention not working.
commit 798c90e1c0 (parent f9d09c266f)
@@ -489,6 +489,8 @@ if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
     if "--use-pytorch-cross-attention" in sys.argv:
         print("Using pytorch cross attention")
         torch.backends.cuda.enable_math_sdp(False)
+        torch.backends.cuda.enable_flash_sdp(True)
+        torch.backends.cuda.enable_mem_efficient_sdp(True)
         CrossAttention = CrossAttentionPytorch
     else:
         print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
@@ -497,6 +499,7 @@ else:
     print("Using xformers cross attention")
     CrossAttention = MemoryEfficientCrossAttention
 
+
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                  disable_self_attn=False):
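The fix explicitly enables PyTorch 2.0's fused scaled-dot-product-attention backends (flash and memory-efficient), which matters because the line above them disables the math fallback; with no backend enabled, the fused kernel has nothing to dispatch to. Below is a minimal sketch of cross attention running on top of this backend configuration. The cross_attention helper is illustrative only, not ComfyUI's actual CrossAttentionPytorch class, and the head-splitting layout is an assumption.

import torch
import torch.nn.functional as F

# Mirror the backend configuration from the diff: turn off the slow math
# fallback and make sure the two fused CUDA kernels are available.
torch.backends.cuda.enable_math_sdp(False)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

def cross_attention(q, k, v, n_heads):
    # Hypothetical helper (not ComfyUI code). Shapes assumed:
    # q: (batch, q_len, n_heads * d_head), k/v: (batch, kv_len, n_heads * d_head).
    b, _, inner_dim = q.shape
    d_head = inner_dim // n_heads
    # Split heads: (batch, seq, inner) -> (batch, n_heads, seq, d_head).
    q, k, v = (t.view(b, -1, n_heads, d_head).transpose(1, 2) for t in (q, k, v))
    # Dispatches to whichever fused backend was enabled above.
    out = F.scaled_dot_product_attention(q, k, v)
    # Merge heads back: (batch, q_len, n_heads * d_head).
    return out.transpose(1, 2).reshape(b, -1, inner_dim)

Note that these toggles only govern CUDA execution; on CPU, scaled_dot_product_attention falls back to the math implementation regardless of the flags.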