diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py new file mode 100644 index 00000000..929151b5 --- /dev/null +++ b/comfy_extras/nodes_multigpu.py @@ -0,0 +1,39 @@ +from comfy.model_patcher import ModelPatcher +import comfy.utils +import comfy.patcher_extension +import comfy.model_management +import copy + + +class MultiGPUInitialize: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("MODEL",), + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "init_multigpu" + CATEGORY = "DevTools" + + def init_multigpu(self, model: ModelPatcher): + model = model.clone() + extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + if len(extra_devices) > 0: + comfy.model_management.unload_all_models() + for device in extra_devices: + device_patcher = model.clone() + device_patcher.model = copy.deepcopy(model.model) + device_patcher.load_device = device + device_patcher.is_multigpu_clone = True + multigpu_models = model.get_additional_models_with_key("multigpu") + multigpu_models.append(device_patcher) + model.set_additional_models("multigpu", multigpu_models) + return (model,) + + +NODE_CLASS_MAPPINGS = { + "test_multigpuinit": MultiGPUInitialize, +} \ No newline at end of file