Add get_all_torch_devices to get detected devices intended for current torch hardware device

This commit is contained in:
Jedrzej Kosinski 2025-01-07 21:06:03 -06:00
parent 66838ebd39
commit 871258aa72

View File

@ -128,6 +128,19 @@ def get_torch_device():
else:
return torch.device(torch.cuda.current_device())
def get_all_torch_devices(exclude_current=False):
global cpu_state
devices = []
if cpu_state == CPUState.GPU:
if is_nvidia():
for i in range(torch.cuda.device_count()):
devices.append(torch.device(i))
else:
devices.append(get_torch_device())
if exclude_current:
devices.remove(get_torch_device())
return devices
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None: