Auto-enable mem-efficient attention on gfx1100 on PyTorch nightly 2.7

I'm not sure which arches are supported yet. If you see improvements in
memory usage while using --use-pytorch-cross-attention on your AMD GPU, let
me know and I will add your arch to the list.
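For anyone reporting back, a quick way to see which arch your ROCm build of
PyTorch detects (a minimal sketch; it reads the same gcnArchName device
property the patch below checks):

    import torch

    # Print the arch string ROCm PyTorch reports for the first GPU.
    # This is the same device property the patch reads, e.g. "gfx1100".
    print(torch.cuda.get_device_properties(0).gcnArchName)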
comfyanonymous 2025-02-14 04:17:56 -05:00
parent 042a905c37
commit d7b4bf21a2


@@ -236,6 +236,19 @@ try:
 except:
     pass
 
+try:
+    if is_amd():
+        arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
+        logging.info("AMD arch: {}".format(arch))
+        if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+            if int(torch_version[0]) >= 2 and int(torch_version[2]) >= 7: # works on 2.6 but doesn't actually seem to improve much
+                if arch in ["gfx1100"]: #TODO: more arches
+                    ENABLE_PYTORCH_ATTENTION = True
+except:
+    pass
+
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
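
Two notes on the snippet above. The version gate assumes torch_version is a
raw version string (for example torch.version.__version__), so it indexes
single characters rather than parsing version components; a sketch of what
that check evaluates:

    torch_version = "2.7.0"           # assumed: the raw version string
    major = int(torch_version[0])     # '2' -> 2
    minor = int(torch_version[2])     # '7' -> 7 (character index, not a full
                                      # parse; a two-digit minor like 2.10
                                      # would read as 1 here)
    print(major >= 2 and minor >= 7)  # True on 2.7.x

And to confirm at runtime which scaled-dot-product-attention backends
actually ended up enabled, the PyTorch 2.x query helpers can be used:

    import torch

    # Report the current SDPA backend toggles after the flags above run.
    print("math:", torch.backends.cuda.math_sdp_enabled())
    print("flash:", torch.backends.cuda.flash_sdp_enabled())
    print("mem_efficient:", torch.backends.cuda.mem_efficient_sdp_enabled())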