Let --cuda-device take in a string to allow multiple devices (or device order) to be chosen, print available devices on startup, potentially support MultiGPU Intel and Ascend setups

This commit is contained in:
Jedrzej Kosinski 2025-02-06 08:44:07 -06:00
parent 441cfd1a7a
commit 476aa79b64
2 changed files with 13 additions and 3 deletions

View File

@ -50,7 +50,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")

View File

@ -141,6 +141,12 @@ def get_all_torch_devices(exclude_current=False):
if is_nvidia():
for i in range(torch.cuda.device_count()):
devices.append(torch.device(i))
elif is_intel_xpu():
for i in range(torch.xpu.device_count()):
devices.append(torch.device(i))
elif is_ascend_npu():
for i in range(torch.npu.device_count()):
devices.append(torch.device(i))
else:
devices.append(get_torch_device())
if exclude_current:
@ -320,10 +326,14 @@ def get_torch_device_name(device):
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
logging.info("Device [X]: {}".format(get_torch_device_name(get_torch_device())))
except:
logging.warning("Could not pick default device.")
try:
for device in get_all_torch_devices(exclude_current=True):
logging.info("Device [ ]: {}".format(get_torch_device_name(device)))
except:
pass
current_loaded_models = []