diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index a92fc0db..f54be19e 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -50,7 +50,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
 parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
 parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
 parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
-parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
+parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use.")
 cm_group = parser.add_mutually_exclusive_group()
 cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
 cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 420eb9e8..477bb0f5 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -141,6 +141,12 @@ def get_all_torch_devices(exclude_current=False):
     if is_nvidia():
         for i in range(torch.cuda.device_count()):
             devices.append(torch.device(i))
+    elif is_intel_xpu():
+        for i in range(torch.xpu.device_count()):
+            devices.append(torch.device(i))
+    elif is_ascend_npu():
+        for i in range(torch.npu.device_count()):
+            devices.append(torch.device(i))
     else:
         devices.append(get_torch_device())
     if exclude_current:
@@ -320,10 +326,14 @@ def get_torch_device_name(device):
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 
 try:
-    logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
+    logging.info("Device [X]: {}".format(get_torch_device_name(get_torch_device())))
 except:
     logging.warning("Could not pick default device.")
-
+try:
+    for device in get_all_torch_devices(exclude_current=True):
+        logging.info("Device [ ]: {}".format(get_torch_device_name(device)))
+except:
+    pass
 current_loaded_models = []
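
Note on the --cuda-device change: switching the flag from type=int to type=str lets it carry either a single device id or a comma-separated list of ids. Below is a minimal standalone sketch of how such a value could be consumed; it assumes the common pattern of exporting CUDA_VISIBLE_DEVICES before torch is initialized and is illustrative only, not the repository's exact wiring.

    # Illustrative sketch (assumption: the launcher exports CUDA_VISIBLE_DEVICES
    # from the flag value before torch is imported; not ComfyUI's exact code).
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID",
                        help="Set the ids of cuda devices this instance will use.")
    args = parser.parse_args(["--cuda-device", "0,1"])  # "1" (single id) also works

    if args.cuda_device is not None:
        # A comma-separated string is already in the format CUDA_VISIBLE_DEVICES expects.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_device
        print("Set cuda device to:", args.cuda_device)

With more than one device visible, the new branches in get_all_torch_devices() and the added "Device [ ]" log lines enumerate every secondary device alongside the default one marked "Device [X]".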