From 093914a24714ef7264e34062fdeae46bd81964d9 Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Mon, 3 Mar 2025 22:56:13 -0600
Subject: [PATCH] Made MultiGPU Work Units node more robust by forcing
 ModelPatcher clones to match at sample time, reusing loaded MultiGPU clones,
 and finalizing the MultiGPU Work Units node ID and name; also small
 refactors/cleanup of logging and multigpu-related code

---
 comfy/model_management.py      | 14 ++++---
 comfy/model_patcher.py         | 67 +++++++++++++++++++++++++++++-----
 comfy/multigpu.py              | 52 ++++++++++++++++++++++++++
 comfy/patcher_extension.py     |  2 +
 comfy/sampler_helpers.py       | 47 ++++++++++++++++++++++--
 comfy/samplers.py              | 44 ----------------------
 comfy_extras/nodes_multigpu.py | 49 ++++++++++++-------------
 7 files changed, 188 insertions(+), 87 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index dd762bdc..3ee8857c 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -345,16 +345,16 @@ def get_torch_device_name(device):
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 
 try:
-    logging.info("Device [X]: {}".format(get_torch_device_name(get_torch_device())))
+    logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
 except:
     logging.warning("Could not pick default device.")
 try:
     for device in get_all_torch_devices(exclude_current=True):
-        logging.info("Device [ ]: {}".format(get_torch_device_name(device)))
+        logging.info("Device: {}".format(get_torch_device_name(device)))
 except:
     pass
 
-current_loaded_models = []
+current_loaded_models: list[LoadedModel] = []
 
 def module_size(module):
     module_mem = 0
@@ -1198,7 +1198,7 @@ def soft_empty_cache(force=False):
 def unload_all_models():
     free_memory(1e30, get_torch_device())
 
-def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True):
+def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
     'Unload only model and its clones - primarily for multigpu cloning purposes.'
     initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
     additional_models = []
@@ -1218,7 +1218,11 @@ def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True):
         if skip:
             continue
         keep_loaded.append(loaded_model)
-    free_memory(1e30, get_torch_device(), keep_loaded)
+    if not all_devices:
+        free_memory(1e30, get_torch_device(), keep_loaded)
+    else:
+        for device in get_all_torch_devices():
+            free_memory(1e30, device, keep_loaded)
 
 #TODO: might be cleaner to put this somewhere else
 import threading
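[Review note] unload_model_and_clones now takes an all_devices flag, so callers can evict a model and its clones from every torch device instead of only the current one. A minimal usage sketch under the patched API (the `model` variable is a hypothetical, already-created ModelPatcher):

    import comfy.model_management

    # default behavior: free memory on the current torch device only
    comfy.model_management.unload_model_and_clones(model)
    # new behavior: walk get_all_torch_devices() and free memory on each
    comfy.model_management.unload_model_and_clones(model, all_devices=True)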
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index eb21396b..5ede41dd 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -243,7 +243,7 @@ class ModelPatcher:
         self.is_clip = False
         self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
 
-        self.is_multigpu_clone = False
+        self.is_multigpu_base_clone = False
         self.clone_base_uuid = uuid.uuid4()
 
         if not hasattr(self.model, 'model_loaded_weight_memory'):
@@ -324,14 +324,16 @@ class ModelPatcher:
         n.is_clip = self.is_clip
         n.hook_mode = self.hook_mode
 
-        n.is_multigpu_clone = self.is_multigpu_clone
+        n.is_multigpu_base_clone = self.is_multigpu_base_clone
         n.clone_base_uuid = self.clone_base_uuid
 
         for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
             callback(self, n)
         return n
 
-    def multigpu_deepclone(self, new_load_device=None, models_cache: dict[ModelPatcher,ModelPatcher]=None):
+    def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
+        logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
+        comfy.model_management.unload_model_and_clones(self)
         n = self.clone()
         # set load device, if present
         if new_load_device is not None:
@@ -350,19 +352,64 @@
         for key, model_list in n.additional_models.items():
             for i in range(len(model_list)):
                 add_model = n.additional_models[key][i]
-                if i not in models_cache:
-                    models_cache[add_model] = add_model.multigpu_deepclone(new_load_device=new_load_device, models_cache=models_cache)
-                n.additional_models[key][i] = models_cache[add_model]
+                if add_model.clone_base_uuid not in models_cache:
+                    models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
+                n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
+        for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
+            callback(self, n)
         return n
 
+    def match_multigpu_clones(self):
+        multigpu_models = self.get_additional_models_with_key("multigpu")
+        if len(multigpu_models) > 0:
+            new_multigpu_models = []
+            for mm in multigpu_models:
+                # clone main model, but bring over relevant props from existing multigpu clone
+                n = self.clone()
+                n.load_device = mm.load_device
+                n.backup = mm.backup
+                n.object_patches_backup = mm.object_patches_backup
+                n.hook_backup = mm.hook_backup
+                n.model = mm.model
+                n.is_multigpu_base_clone = mm.is_multigpu_base_clone
+                n.remove_additional_models("multigpu")
+                orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
+                n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
+                # figure out which additional models are not present in multigpu clone
+                models_cache = {}
+                for mm_add_model in mm.get_additional_models():
+                    models_cache[mm_add_model.clone_base_uuid] = mm_add_model
+                remove_models_uuids = set(list(models_cache.keys()))
+                for key, model_list in orig_additional_models.items():
+                    for orig_add_model in model_list:
+                        if orig_add_model.clone_base_uuid not in models_cache:
+                            models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
+                            existing_list = n.get_additional_models_with_key(key)
+                            existing_list.append(models_cache[orig_add_model.clone_base_uuid])
+                            n.set_additional_models(key, existing_list)
+                        if orig_add_model.clone_base_uuid in remove_models_uuids:
+                            remove_models_uuids.remove(orig_add_model.clone_base_uuid)
+                # remove duplicate additional models
+                for key, model_list in n.additional_models.items():
+                    new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
+                    n.set_additional_models(key, new_model_list)
+                for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
+                    callback(self, n)
+                new_multigpu_models.append(n)
+            self.set_additional_models("multigpu", new_multigpu_models)
+
     def is_clone(self, other):
         if hasattr(other, 'model') and self.model is other.model:
             return True
         return False
 
-    def clone_has_same_weights(self, clone: 'ModelPatcher'):
-        if not self.is_clone(clone):
-            return False
+    def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
+        if allow_multigpu:
+            if self.clone_base_uuid != clone.clone_base_uuid:
+                return False
+        else:
+            if not self.is_clone(clone):
+                return False
 
         if self.current_hooks != clone.current_hooks:
             return False
@@ -957,7 +1004,7 @@ class ModelPatcher:
         return self.additional_models.get(key, [])
 
     def get_additional_models(self):
-        all_models = []
+        all_models: list[ModelPatcher] = []
         for models in self.additional_models.values():
             all_models.extend(models)
         return all_models
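[Review note] The rename to is_multigpu_base_clone and the uuid-keyed models_cache both lean on clone_base_uuid: every clone of a patcher inherits the same uuid, so clones can be matched across devices even though a deepclone no longer shares the underlying BaseModel and therefore fails is_clone(). A sketch of the intended semantics, assuming `base` is an existing ModelPatcher and `other_device` is a second torch device:

    shallow = base.clone()                                  # shares base.model
    deep = base.deepclone_multigpu(new_load_device=other_device)

    base.is_clone(shallow)                                  # True: same BaseModel object
    base.is_clone(deep)                                     # False: deepcloned BaseModel
    # multigpu-aware check falls back to the shared clone_base_uuid
    base.clone_has_same_weights(deep, allow_multigpu=True)  # True when hooks/patches also match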
diff --git a/comfy/multigpu.py b/comfy/multigpu.py
index 2a1fc29d..9cc8a37f 100644
--- a/comfy/multigpu.py
+++ b/comfy/multigpu.py
@@ -1,10 +1,14 @@
 from __future__ import annotations
 
 import torch
+import logging
 from collections import namedtuple
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
+import comfy.utils
+import comfy.patcher_extension
+import comfy.model_management
 
 
 class GPUOptions:
@@ -53,6 +57,53 @@ class GPUOptionsGroup:
         model.model_options['multigpu_options'] = opts_dict
 
 
+def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
+    'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
+    model = model.clone()
+    # check if multigpu is already prepared - get the load devices from them if possible to exclude
+    skip_devices = set()
+    multigpu_models = model.get_additional_models_with_key("multigpu")
+    if len(multigpu_models) > 0:
+        for mm in multigpu_models:
+            skip_devices.add(mm.load_device)
+    skip_devices = list(skip_devices)
+
+    extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
+    extra_devices = extra_devices[:max_gpus-1]
+    # exclude skipped devices
+    for skip in skip_devices:
+        if skip in extra_devices:
+            extra_devices.remove(skip)
+    # create new deepclones
+    if len(extra_devices) > 0:
+        for device in extra_devices:
+            device_patcher = None
+            if reuse_loaded:
+                # check if there are any ModelPatchers currently loaded that could be referenced here after a clone
+                loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
+                for lm in loaded_models:
+                    if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
+                        device_patcher = lm.clone()
+                        logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
+                        break
+            if device_patcher is None:
+                device_patcher = model.deepclone_multigpu(new_load_device=device)
+                device_patcher.is_multigpu_base_clone = True
+            multigpu_models = model.get_additional_models_with_key("multigpu")
+            multigpu_models.append(device_patcher)
+            model.set_additional_models("multigpu", multigpu_models)
+        model.match_multigpu_clones()
+        if gpu_options is None:
+            gpu_options = GPUOptionsGroup()
+        gpu_options.register(model)
+    else:
+        logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
+    # persist skip_devices for use in sampling code
+    # if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
+    #     model.model_options["multigpu_skip_devices"] = skip_devices
+    return model
+
+
 LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
 def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
     'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
@@ -84,6 +135,7 @@ def load_balance_devices(model_options: dict[str], total_work: int, return_idle_
     completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
     # calculate relative time spent by the devices waiting on each other after their work is completed
     idle_time = abs(min(completion_time) - max(completion_time))
+    # if need to compare work idle time, need to normalize to a common total work
     if work_normalized:
         idle_time *= (work_normalized/total_work)
 
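[Review note] create_multigpu_deepclones is now the single entry point the Work Units node delegates to; with reuse_loaded=True it first scans comfy.model_management.loaded_models() for an already-resident clone with the same clone_base_uuid before paying for a fresh deepclone. A minimal sketch (the `model` variable is a hypothetical ModelPatcher):

    import comfy.multigpu

    # attach per-device deepclones under the "multigpu" additional-models key,
    # reusing any clone already loaded on a target device
    model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus=2, reuse_loaded=True)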
diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py
index 85975824..5145855f 100644
--- a/comfy/patcher_extension.py
+++ b/comfy/patcher_extension.py
@@ -3,6 +3,8 @@ from typing import Callable
 
 class CallbacksMP:
     ON_CLONE = "on_clone"
+    ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
+    ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
     ON_LOAD = "on_load_after"
     ON_DETACH = "on_detach_after"
     ON_CLEANUP = "on_cleanup"
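[Review note] The two new callback names let extensions observe multigpu cloning. Assuming the existing ModelPatcher.add_callback registration path (the same one ON_CLONE callbacks use), a sketch of a hypothetical logging hook:

    from comfy.patcher_extension import CallbacksMP

    def on_deepclone(source_patcher, new_patcher):
        # invoked from deepclone_multigpu with the source and its new deepclone
        print(f"deepcloned {new_patcher.model.__class__.__name__} onto {new_patcher.load_device}")

    model.add_callback(CallbacksMP.ON_DEEPCLONE_MULTIGPU, on_deepclone)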
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index 40b2021f..9a97c855 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -106,16 +106,57 @@ def cleanup_additional_models(models):
         if hasattr(m, 'cleanup'):
             m.cleanup()
 
+def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
+    '''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
+    multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
+    if len(multigpu_models) == 0:
+        return
+    extra_devices = [x.load_device for x in multigpu_models]
+    # handle controlnets
+    controlnets: set[ControlBase] = set()
+    for k in conds:
+        for kk in conds[k]:
+            if 'control' in kk:
+                controlnets.add(kk['control'])
+    if len(controlnets) > 0:
+        # first, unload all controlnet clones
+        for cnet in list(controlnets):
+            cnet_models = cnet.get_models()
+            for cm in cnet_models:
+                comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
+
+        # next, make sure each controlnet has a deepclone for all relevant devices
+        for cnet in controlnets:
+            curr_cnet = cnet
+            while curr_cnet is not None:
+                for device in extra_devices:
+                    if device not in curr_cnet.multigpu_clones:
+                        curr_cnet.deepclone_multigpu(device, autoregister=True)
+                curr_cnet = curr_cnet.previous_controlnet
+        # since all device clones are now present, recreate the linked list for cloned cnets per device
+        for cnet in controlnets:
+            curr_cnet = cnet
+            while curr_cnet is not None:
+                prev_cnet = curr_cnet.previous_controlnet
+                for device in extra_devices:
+                    device_cnet = curr_cnet.get_instance_for_device(device)
+                    prev_device_cnet = None
+                    if prev_cnet is not None:
+                        prev_device_cnet = prev_cnet.get_instance_for_device(device)
+                    device_cnet.set_previous_controlnet(prev_device_cnet)
+                curr_cnet = prev_cnet
+    # potentially handle gligen - since not widely used, ignored for now
+
 def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
-    real_model: BaseModel = None
+    model.match_multigpu_clones()
+    preprocess_multigpu_conds(conds, model, model_options)
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models()  # TODO: does this require inference_memory update?
     memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
     minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
     comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
-    real_model = model.model
+    real_model: BaseModel = model.model
 
     return real_model, conds, models
 
@@ -166,7 +207,7 @@ def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_mo
     '''
     In case multigpu acceleration is enabled, prep ModelPatchers for each device.
     '''
-    multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_clone]
+    multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
     if len(multigpu_patchers) > 0:
         multigpu_dict: dict[torch.device, ModelPatcher] = {}
         multigpu_dict[model_patcher.load_device] = model_patcher
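[Review note] The trickiest part of preprocess_multigpu_conds is preserving the previous_controlnet linked list on every device: after each ControlNet in a chain has a per-device deepclone, a second pass walks the chain and links each device's clones to each other. A self-contained toy of that relinking idea (ToyCnet is illustrative only, not the real ControlBase API):

    class ToyCnet:
        def __init__(self, name: str):
            self.name = name
            self.previous = None      # mirrors previous_controlnet
            self.clones = {}          # device -> ToyCnet, mirrors multigpu_clones

    a, b = ToyCnet("a"), ToyCnet("b")
    a.previous = b                    # chain: a -> b
    for node in (a, b):               # first pass: clone every chain member
        node.clones["cuda:1"] = ToyCnet(node.name + "@cuda:1")
    node = a                          # second pass: relink the clones per device
    while node is not None:
        prev = node.previous
        node.clones["cuda:1"].previous = prev.clones["cuda:1"] if prev else None
        node = prev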
diff --git a/comfy/samplers.py b/comfy/samplers.py
index beef0b7e..d02627d8 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -1088,49 +1088,6 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
             for cast in casts:
                 wc_list[i] = wc_list[i].to(cast)
 
-
-def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model_options: dict[str], model: ModelPatcher):
-    '''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
-    multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
-    if len(multigpu_models) == 0:
-        return
-    extra_devices = [x.load_device for x in multigpu_models]
-    # handle controlnets
-    controlnets: set[ControlBase] = set()
-    for k in conds:
-        for kk in conds[k]:
-            if 'control' in kk:
-                controlnets.add(kk['control'])
-    if len(controlnets) > 0:
-        # first, unload all controlnet clones
-        for cnet in list(controlnets):
-            cnet_models = cnet.get_models()
-            for cm in cnet_models:
-                comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
-
-        # next, make sure each controlnet has a deepclone for all relevant devices
-        for cnet in controlnets:
-            curr_cnet = cnet
-            while curr_cnet is not None:
-                for device in extra_devices:
-                    if device not in curr_cnet.multigpu_clones:
-                        curr_cnet.deepclone_multigpu(device, autoregister=True)
-                curr_cnet = curr_cnet.previous_controlnet
-        # since all device clones are now present, recreate the linked list for cloned cnets per device
-        for cnet in controlnets:
-            curr_cnet = cnet
-            while curr_cnet is not None:
-                prev_cnet = curr_cnet.previous_controlnet
-                for device in extra_devices:
-                    device_cnet = curr_cnet.get_instance_for_device(device)
-                    prev_device_cnet = None
-                    if prev_cnet is not None:
-                        prev_device_cnet = prev_cnet.get_instance_for_device(device)
-                    device_cnet.set_previous_controlnet(prev_device_cnet)
-                curr_cnet = prev_cnet
-    # TODO: handle gligen
-
-
 class CFGGuider:
     def __init__(self, model_patcher: ModelPatcher):
         self.model_patcher = model_patcher
@@ -1173,7 +1130,6 @@ class CFGGuider:
         return self.inner_model.process_latent_out(samples.to(torch.float32))
 
     def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
-        preprocess_multigpu_conds(self.conds, self.model_options, self.model_patcher)
         self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
 
         device = self.model_patcher.load_device
diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py
index 54f68182..d1e458b7 100644
--- a/comfy_extras/nodes_multigpu.py
+++ b/comfy_extras/nodes_multigpu.py
@@ -1,15 +1,24 @@
 from __future__ import annotations
+import logging
+from inspect import cleandoc
 
-from comfy.model_patcher import ModelPatcher
-import comfy.utils
-import comfy.patcher_extension
-import comfy.model_management
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from comfy.model_patcher import ModelPatcher
 import comfy.multigpu
 
 
-class MultiGPUInitialize:
-    NodeId = "MultiGPU_Initialize"
-    NodeName = "MultiGPU Initialize"
+class MultiGPUWorkUnitsNode:
+    """
+    Prepares model to have sampling accelerated via splitting work units.
+
+    Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
+
+    Other than those exceptions, this node can be placed in any order.
+    """
+
+    NodeId = "MultiGPU_WorkUnits"
+    NodeName = "MultiGPU Work Units"
     @classmethod
     def INPUT_TYPES(cls):
         return {
@@ -25,25 +34,17 @@ class MultiGPUInitialize:
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "init_multigpu"
     CATEGORY = "advanced/multigpu"
+    DESCRIPTION = cleandoc(__doc__)
 
     def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
-        extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
-        extra_devices = extra_devices[:max_gpus-1]
-        if len(extra_devices) > 0:
-            model = model.clone()
-            comfy.model_management.unload_model_and_clones(model)
-            for device in extra_devices:
-                device_patcher = model.multigpu_deepclone(new_load_device=device)
-                device_patcher.is_multigpu_clone = True
-                multigpu_models = model.get_additional_models_with_key("multigpu")
-                multigpu_models.append(device_patcher)
-                model.set_additional_models("multigpu", multigpu_models)
-            if gpu_options is None:
-                gpu_options = comfy.multigpu.GPUOptionsGroup()
-            gpu_options.register(model)
+        model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True)
         return (model,)
 
 class MultiGPUOptionsNode:
+    """
+    Select the relative speed of GPUs in the special case they have significantly different performance from one another.
+    """
+
     NodeId = "MultiGPU_Options"
     NodeName = "MultiGPU Options"
     @classmethod
@@ -61,6 +62,7 @@ class MultiGPUOptionsNode:
     RETURN_TYPES = ("GPU_OPTIONS",)
     FUNCTION = "create_gpu_options"
     CATEGORY = "advanced/multigpu"
+    DESCRIPTION = cleandoc(__doc__)
 
     def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
         if not gpu_options:
@@ -74,7 +76,7 @@ class MultiGPUOptionsNode:
 
 
 node_list = [
-    MultiGPUInitialize,
+    MultiGPUWorkUnitsNode,
     MultiGPUOptionsNode
 ]
 NODE_CLASS_MAPPINGS = {}
@@ -83,6 +85,3 @@ NODE_DISPLAY_NAME_MAPPINGS = {}
 for node in node_list:
     NODE_CLASS_MAPPINGS[node.NodeId] = node
     NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
-
-# TODO: remove
-NODE_CLASS_MAPPINGS["test_multigpuinit"] = MultiGPUInitialize
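[Review note] The relative_speed value collected by MultiGPU Options feeds the proportional split that load_balance_devices performs at sample time. The core arithmetic, as a self-contained illustration (split_work is hypothetical; the real function also quantizes work to splittable units and reports idle time):

    def split_work(total_work: int, speeds: list[float]) -> list[int]:
        # each device receives work proportional to its relative speed
        work = [round(total_work * s / sum(speeds)) for s in speeds]
        work[0] += total_work - sum(work)  # absorb rounding drift on device 0
        return work

    print(split_work(20, [1.0, 1.0]))  # [10, 10]
    print(split_work(20, [2.0, 1.0]))  # [13, 7] - the faster GPU gets more batches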