diff --git a/comfy/model_management.py b/comfy/model_management.py index 212ce9af2..dd8a2a28f 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -236,6 +236,19 @@ try: except: pass + +try: + if is_amd(): + arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + logging.info("AMD arch: {}".format(arch)) + if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: + if int(torch_version[0]) >= 2 and int(torch_version[2]) >= 7: # works on 2.6 but doesn't actually seem to improve much + if arch in ["gfx1100"]: #TODO: more arches + ENABLE_PYTORCH_ATTENTION = True +except: + pass + + if ENABLE_PYTORCH_ATTENTION: torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True)