diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 929151b5..3ba55862 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -11,6 +11,9 @@ class MultiGPUInitialize: return { "required": { "model": ("MODEL",), + }, + "optional": { + "max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}), } } @@ -18,9 +21,10 @@ class MultiGPUInitialize: FUNCTION = "init_multigpu" CATEGORY = "DevTools" - def init_multigpu(self, model: ModelPatcher): + def init_multigpu(self, model: ModelPatcher, max_gpus: int): model = model.clone() extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + extra_devices = extra_devices[:max_gpus-1] if len(extra_devices) > 0: comfy.model_management.unload_all_models() for device in extra_devices: