diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 5465dde6..63f1f92e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -219,6 +219,7 @@ class ModelPatcher: self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed self.is_multigpu_clone = False + self.clone_uuid = uuid.uuid4() if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -296,11 +297,35 @@ class ModelPatcher: n.hook_mode = self.hook_mode n.is_multigpu_clone = self.is_multigpu_clone + n.clone_uuid = self.clone_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n + def multigpu_clone(self, new_load_device=None, models_cache: dict[ModelPatcher,ModelPatcher]=None): + n = self.clone() + # set load device, if present + if new_load_device is not None: + n.load_device = new_load_device + # unlike for normal clone, backup dicts that shared same ref should not; + # otherwise, patchers that have deep copies of base models will erroneously influence each other. + n.backup = copy.deepcopy(n.backup) + n.object_patches_backup = copy.deepcopy(n.object_patches_backup) + n.model = copy.deepcopy(n.model) + # multigpu clone should not have multigpu additional_models entry + n.remove_additional_models("multigpu") + # multigpu_clone all stored additional_models; make sure circular references are properly handled + if models_cache is None: + models_cache = {} + for key, model_list in n.additional_models.items(): + for i in range(len(model_list)): + add_model = n.additional_models[key][i] + if i not in models_cache: + models_cache[add_model] = add_model.multigpu_clone(new_load_device=new_load_device, models_cache=models_cache) + n.additional_models[key][i] = models_cache[add_model] + return n + def is_clone(self, other): if hasattr(other, 'model') and self.model is other.model: return True diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 3ba55862..dec395fb 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -22,15 +22,13 @@ class MultiGPUInitialize: CATEGORY = "DevTools" def init_multigpu(self, model: ModelPatcher, max_gpus: int): - model = model.clone() extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) extra_devices = extra_devices[:max_gpus-1] if len(extra_devices) > 0: + model = model.clone() comfy.model_management.unload_all_models() for device in extra_devices: - device_patcher = model.clone() - device_patcher.model = copy.deepcopy(model.model) - device_patcher.load_device = device + device_patcher = model.multigpu_clone(new_load_device=device) device_patcher.is_multigpu_clone = True multigpu_models = model.get_additional_models_with_key("multigpu") multigpu_models.append(device_patcher)