diff --git a/comfy/model_management.py b/comfy/model_management.py
index 87ad290d..2cf792b5 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1146,6 +1146,16 @@ def soft_empty_cache(force=False):
 def unload_all_models():
     free_memory(1e30, get_torch_device())
 
+def unload_model_and_clones(model: ModelPatcher):
+    'Unload only model and its clones - primarily for multigpu cloning purposes.'
+    initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
+    keep_loaded = []
+    for loaded_model in initial_keep_loaded:
+        if loaded_model.model is not None:
+            if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
+                continue
+        keep_loaded.append(loaded_model)
+    free_memory(1e30, get_torch_device(), keep_loaded)
 
 #TODO: might be cleaner to put this somewhere else
 import threading
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 46779397..b4efa8d0 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -222,7 +222,7 @@ class ModelPatcher:
         self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
 
         self.is_multigpu_clone = False
-        self.clone_uuid = uuid.uuid4()
+        self.clone_base_uuid = uuid.uuid4()
 
         if not hasattr(self.model, 'model_loaded_weight_memory'):
             self.model.model_loaded_weight_memory = 0
@@ -300,7 +300,7 @@ class ModelPatcher:
         n.hook_mode = self.hook_mode
 
         n.is_multigpu_clone = self.is_multigpu_clone
-        n.clone_uuid = self.clone_uuid
+        n.clone_base_uuid = self.clone_base_uuid
 
         for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
             callback(self, n)
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index a95231ff..5564b62c 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -173,6 +173,7 @@ def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_mo
     multigpu_dict[model_patcher.load_device] = model_patcher
     for x in multigpu_patchers:
         x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
+        x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
         multigpu_dict[x.load_device] = x
     model_options["multigpu_clones"] = multigpu_dict
     return multigpu_patchers
diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py
index b3c8635b..b5c36c64 100644
--- a/comfy_extras/nodes_multigpu.py
+++ b/comfy_extras/nodes_multigpu.py
@@ -26,7 +26,7 @@ class MultiGPUInitialize:
         extra_devices = extra_devices[:max_gpus-1]
         if len(extra_devices) > 0:
             model = model.clone()
-            comfy.model_management.unload_all_models()
+            comfy.model_management.unload_model_and_clones(model)
             for device in extra_devices:
                 device_patcher = model.multigpu_deepclone(new_load_device=device)
                 device_patcher.is_multigpu_clone = True
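
Note: below is a minimal, self-contained sketch (not ComfyUI code) of the clone_base_uuid filtering that the new unload_model_and_clones() performs. FakePatcher, FakeLoadedModel, and keep_all_but_clone_family are hypothetical stand-ins for ModelPatcher, LoadedModel, and the keep_loaded filter in the diff above.

# clone_base_uuid_sketch.py
import uuid

class FakePatcher:
    def __init__(self):
        # every freshly constructed patcher gets its own base uuid ...
        self.clone_base_uuid = uuid.uuid4()

    def clone(self):
        # ... and every clone inherits it, so a model and all of its clones
        # share one uuid while unrelated models do not
        n = FakePatcher()
        n.clone_base_uuid = self.clone_base_uuid
        return n

class FakeLoadedModel:
    def __init__(self, model):
        self.model = model

def keep_all_but_clone_family(target, loaded_models):
    # keep everything whose base uuid differs from the target's;
    # the entries that get dropped (and thus freed) are exactly
    # the target model and its clones
    return [lm for lm in loaded_models
            if lm.model is None or lm.model.clone_base_uuid != target.clone_base_uuid]

a = FakePatcher()
b = a.clone()          # same clone family as a
c = FakePatcher()      # unrelated model, stays loaded
loaded = [FakeLoadedModel(m) for m in (a, b, c)]
kept = keep_all_but_clone_family(a, loaded)
assert [lm.model for lm in kept] == [c]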