diff --git a/comfy/model_management.py b/comfy/model_management.py index f6dfc18b..003a89f5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -128,6 +128,19 @@ def get_torch_device(): else: return torch.device(torch.cuda.current_device()) +def get_all_torch_devices(exclude_current=False): + global cpu_state + devices = [] + if cpu_state == CPUState.GPU: + if is_nvidia(): + for i in range(torch.cuda.device_count()): + devices.append(torch.device(i)) + else: + devices.append(get_torch_device()) + if exclude_current: + devices.remove(get_torch_device()) + return devices + def get_total_memory(dev=None, torch_total_too=False): global directml_enabled if dev is None: