mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-18 01:53:31 +00:00
Initial work on multigpu_clone function, which will account for additional_models getting cloned
This commit is contained in:
parent
31f5458938
commit
bfce723311
@ -219,6 +219,7 @@ class ModelPatcher:
|
|||||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||||
|
|
||||||
self.is_multigpu_clone = False
|
self.is_multigpu_clone = False
|
||||||
|
self.clone_uuid = uuid.uuid4()
|
||||||
|
|
||||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||||
self.model.model_loaded_weight_memory = 0
|
self.model.model_loaded_weight_memory = 0
|
||||||
@ -296,11 +297,35 @@ class ModelPatcher:
|
|||||||
n.hook_mode = self.hook_mode
|
n.hook_mode = self.hook_mode
|
||||||
|
|
||||||
n.is_multigpu_clone = self.is_multigpu_clone
|
n.is_multigpu_clone = self.is_multigpu_clone
|
||||||
|
n.clone_uuid = self.clone_uuid
|
||||||
|
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||||
callback(self, n)
|
callback(self, n)
|
||||||
return n
|
return n
|
||||||
|
|
||||||
|
def multigpu_clone(self, new_load_device=None, models_cache: dict[ModelPatcher,ModelPatcher]=None):
|
||||||
|
n = self.clone()
|
||||||
|
# set load device, if present
|
||||||
|
if new_load_device is not None:
|
||||||
|
n.load_device = new_load_device
|
||||||
|
# unlike for normal clone, backup dicts that shared same ref should not;
|
||||||
|
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
|
||||||
|
n.backup = copy.deepcopy(n.backup)
|
||||||
|
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
|
||||||
|
n.model = copy.deepcopy(n.model)
|
||||||
|
# multigpu clone should not have multigpu additional_models entry
|
||||||
|
n.remove_additional_models("multigpu")
|
||||||
|
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
||||||
|
if models_cache is None:
|
||||||
|
models_cache = {}
|
||||||
|
for key, model_list in n.additional_models.items():
|
||||||
|
for i in range(len(model_list)):
|
||||||
|
add_model = n.additional_models[key][i]
|
||||||
|
if i not in models_cache:
|
||||||
|
models_cache[add_model] = add_model.multigpu_clone(new_load_device=new_load_device, models_cache=models_cache)
|
||||||
|
n.additional_models[key][i] = models_cache[add_model]
|
||||||
|
return n
|
||||||
|
|
||||||
def is_clone(self, other):
|
def is_clone(self, other):
|
||||||
if hasattr(other, 'model') and self.model is other.model:
|
if hasattr(other, 'model') and self.model is other.model:
|
||||||
return True
|
return True
|
||||||
|
@ -22,15 +22,13 @@ class MultiGPUInitialize:
|
|||||||
CATEGORY = "DevTools"
|
CATEGORY = "DevTools"
|
||||||
|
|
||||||
def init_multigpu(self, model: ModelPatcher, max_gpus: int):
|
def init_multigpu(self, model: ModelPatcher, max_gpus: int):
|
||||||
model = model.clone()
|
|
||||||
extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
|
extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
|
||||||
extra_devices = extra_devices[:max_gpus-1]
|
extra_devices = extra_devices[:max_gpus-1]
|
||||||
if len(extra_devices) > 0:
|
if len(extra_devices) > 0:
|
||||||
|
model = model.clone()
|
||||||
comfy.model_management.unload_all_models()
|
comfy.model_management.unload_all_models()
|
||||||
for device in extra_devices:
|
for device in extra_devices:
|
||||||
device_patcher = model.clone()
|
device_patcher = model.multigpu_clone(new_load_device=device)
|
||||||
device_patcher.model = copy.deepcopy(model.model)
|
|
||||||
device_patcher.load_device = device
|
|
||||||
device_patcher.is_multigpu_clone = True
|
device_patcher.is_multigpu_clone = True
|
||||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||||
multigpu_models.append(device_patcher)
|
multigpu_models.append(device_patcher)
|
||||||
|
Loading…
Reference in New Issue
Block a user