Improve AMD arch detection.

This commit is contained in:
comfyanonymous 2025-02-17 04:53:40 -05:00
parent 8c0bae50c3
commit 31e54b7052

View File

@ -245,7 +245,7 @@ try:
logging.info("AMD arch: {}".format(arch)) logging.info("AMD arch: {}".format(arch))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
if arch in ["gfx1100"]: #TODO: more arches if any((a in arch) for a in ["gfx1100", "gfx1101"]): # TODO: more arches
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
except: except:
pass pass
@ -1110,7 +1110,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_amd(): if is_amd():
arch = torch.cuda.get_device_properties(device).gcnArchName arch = torch.cuda.get_device_properties(device).gcnArchName
if arch in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]: # RDNA2 and older don't support bf16 if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
if manual_cast: if manual_cast:
return True return True
return False return False