From 530412cb9da671d1e191ca19b0df86c5bb252a62 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 17 Feb 2025 04:36:45 -0500
Subject: [PATCH] Refactor torch version checks to be more future proof.

---
 comfy/model_management.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 9f6522967..05f66c9e5 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -50,7 +50,9 @@ xpu_available = False
 torch_version = ""
 try:
     torch_version = torch.version.__version__
-    xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
+    temp = torch_version.split(".")
+    torch_version_numeric = (int(temp[0]), int(temp[1]))
+    xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
 except:
     pass
 
@@ -227,7 +229,7 @@ if args.use_pytorch_cross_attention:
 
 try:
     if is_nvidia():
-        if int(torch_version[0]) >= 2:
+        if torch_version_numeric[0] >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
     if is_intel_xpu() or is_ascend_npu():
@@ -242,7 +244,7 @@ try:
         arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
         logging.info("AMD arch: {}".format(arch))
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
-            if int(torch_version[0]) >= 2 and int(torch_version[2]) >= 7: # works on 2.6 but doesn't actually seem to improve much
+            if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
                 if arch in ["gfx1100"]: #TODO: more arches
                     ENABLE_PYTORCH_ATTENTION = True
 except:
@@ -261,7 +263,7 @@ except:
     pass
 
 try:
-    if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
+    if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
         torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
 except:
     logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
@@ -1136,11 +1138,11 @@ def supports_fp8_compute(device=None):
     if props.minor < 9:
         return False
 
-    if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
+    if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
         return False
 
     if WINDOWS:
-        if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
+        if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
             return False
 
     return True