Refactor torch version checks to be more future-proof.

comfyanonymous 2025-02-17 04:36:45 -05:00
parent 61c8c70c6e
commit 530412cb9d

comfy/model_management.py

@@ -50,7 +50,9 @@ xpu_available = False
 torch_version = ""
 try:
     torch_version = torch.version.__version__
-    xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
+    temp = torch_version.split(".")
+    torch_version_numeric = (int(temp[0]), int(temp[1]))
+    xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
 except:
     pass

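The motivation for the first hunk: the old check indexed individual characters of the version string, so torch_version[0] is the first character and torch_version[2] the third. That happens to work while both components are single digits, but breaks as soon as a release has a two-digit minor version. A minimal sketch of the failure mode, using a hypothetical future version string:

torch_version = "2.10.0"   # hypothetical future release
int(torch_version[0])      # 2 -- major version, still fine
int(torch_version[2])      # 1 -- wrong: the minor version is 10, not 1

# The new parsing converts whole components, so digit count no longer matters:
temp = torch_version.split(".")
torch_version_numeric = (int(temp[0]), int(temp[1]))   # (2, 10)

Build suffixes such as "2.4.1+cu121" are also unaffected, since only the first two components are converted.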
@@ -227,7 +229,7 @@ if args.use_pytorch_cross_attention:

 try:
     if is_nvidia():
-        if int(torch_version[0]) >= 2:
+        if torch_version_numeric[0] >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
     if is_intel_xpu() or is_ascend_npu():
@@ -242,7 +244,7 @@ try:
         arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
         logging.info("AMD arch: {}".format(arch))
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
-            if int(torch_version[0]) >= 2 and int(torch_version[2]) >= 7: # works on 2.6 but doesn't actually seem to improve much
+            if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
                 if arch in ["gfx1100"]: #TODO: more arches
                     ENABLE_PYTORCH_ATTENTION = True
 except:
@@ -261,7 +263,7 @@
     pass

 try:
-    if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
+    if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
         torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
 except:
     logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
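The version gate in this hunk suggests allow_fp16_bf16_reduction_math_sdp only exists from PyTorch 2.5 onward, which is also why the call sits inside a try/except. An alternative pattern (not used by the commit, shown only as a sketch) is to probe for the attribute instead of the version number:

# Sketch: feature-probe instead of version-gate.
fn = getattr(torch.backends.cuda, "allow_fp16_bf16_reduction_math_sdp", None)
if fn is not None:
    fn(True)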
@@ -1136,11 +1138,11 @@ def supports_fp8_compute(device=None):
     if props.minor < 9:
         return False

-    if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
+    if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
         return False

     if WINDOWS:
-        if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
+        if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
             return False

     return True
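A closing observation on the pattern, not part of the commit: once the version is a (major, minor) tuple, Python's lexicographic tuple comparison can express the same range check more compactly. A sketch:

# The chained check in supports_fp8_compute():
#   t[0] < 2 or (t[0] == 2 and t[1] < 3)
# is exactly lexicographic "less than" on (major, minor):
t = (2, 3)             # example value of torch_version_numeric
too_old = t < (2, 3)   # False for 2.3+, True for 2.2, 1.13, ...

Note, however, that element-wise checks like t[0] >= 2 and t[1] >= 7 are not equivalent to t >= (2, 7): a hypothetical (3, 0) passes the tuple comparison but fails the element-wise one.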