diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 63f1f92e..46779397 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -84,12 +84,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
 def create_model_options_clone(orig_model_options: dict):
     return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
 
-def create_hook_patches_clone(orig_hook_patches):
+def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
     new_hook_patches = {}
     for hook_ref in orig_hook_patches:
         new_hook_patches[hook_ref] = {}
         for k in orig_hook_patches[hook_ref]:
             new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
+            if copy_tuples:
+                for i in range(len(new_hook_patches[hook_ref][k])):
+                    new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
     return new_hook_patches
 
 def wipe_lowvram_weight(m):
@@ -303,7 +306,7 @@ class ModelPatcher:
             callback(self, n)
         return n
 
-    def multigpu_clone(self, new_load_device=None, models_cache: dict[ModelPatcher,ModelPatcher]=None):
+    def multigpu_deepclone(self, new_load_device=None, models_cache: dict[ModelPatcher,ModelPatcher]=None):
         n = self.clone()
         # set load device, if present
         if new_load_device is not None:
@@ -312,6 +315,7 @@ class ModelPatcher:
         # otherwise, patchers that have deep copies of base models will erroneously influence each other.
         n.backup = copy.deepcopy(n.backup)
         n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
+        n.hook_backup = copy.deepcopy(n.hook_backup)
         n.model = copy.deepcopy(n.model)
         # multigpu clone should not have multigpu additional_models entry
         n.remove_additional_models("multigpu")
@@ -322,7 +326,7 @@ class ModelPatcher:
             for i in range(len(model_list)):
                 add_model = n.additional_models[key][i]
                 if add_model not in models_cache:
-                    models_cache[add_model] = add_model.multigpu_clone(new_load_device=new_load_device, models_cache=models_cache)
+                    models_cache[add_model] = add_model.multigpu_deepclone(new_load_device=new_load_device, models_cache=models_cache)
                 n.additional_models[key][i] = models_cache[add_model]
         return n
 
@@ -952,9 +956,13 @@ class ModelPatcher:
         for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
             callback(self)
 
-    def prepare_state(self, timestep):
+    def prepare_state(self, timestep, model_options, ignore_multigpu=False):
         for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
-            callback(self, timestep)
+            callback(self, timestep, model_options, ignore_multigpu)
+        if not ignore_multigpu and "multigpu_clones" in model_options:
+            for p in model_options["multigpu_clones"].values():
+                p: ModelPatcher
+                p.prepare_state(timestep, model_options, ignore_multigpu=True)
 
     def restore_hook_patches(self):
         if self.hook_patches_backup is not None:
@@ -967,12 +975,18 @@
     def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
         curr_t = t[0]
         reset_current_hooks = False
+        multigpu_kf_changed_cache = None
         transformer_options = model_options.get("transformer_options", {})
         for hook in hook_group.hooks:
             changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
             # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
             # this will cause the weights to be recalculated when sampling
             if changed:
+                # cache changed for multigpu usage
+                if "multigpu_clones" in model_options:
+                    if multigpu_kf_changed_cache is None:
+                        multigpu_kf_changed_cache = []
+                    multigpu_kf_changed_cache.append(hook)
                 # reset current_hooks if contains hook that changed
                 if self.current_hooks is not None:
                     for current_hook in self.current_hooks.hooks:
@@ -984,6 +998,28 @@ class ModelPatcher:
                         self.cached_hook_patches.pop(cached_group)
         if reset_current_hooks:
             self.patch_hooks(None)
+        if "multigpu_clones" in model_options:
+            for p in model_options["multigpu_clones"].values():
+                p: ModelPatcher
+                p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
+
+    def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
+        'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
+        if kf_changed_cache is None:
+            return
+        reset_current_hooks = False
+        # reset current_hooks if contains hook that changed
+        for hook in kf_changed_cache:
+            if self.current_hooks is not None:
+                for current_hook in self.current_hooks.hooks:
+                    if current_hook == hook:
+                        reset_current_hooks = True
+                        break
+            for cached_group in list(self.cached_hook_patches.keys()):
+                if cached_group.contains(hook):
+                    self.cached_hook_patches.pop(cached_group)
+        if reset_current_hooks:
+            self.patch_hooks(None)
 
     def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
                                   registered: comfy.hooks.HookGroup = None):
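The copy_tuples flag above is easiest to see in isolation. Below is a minimal, self-contained sketch (plain strings and lists stand in for the real hook refs and weight-patch entries, so the data is illustrative only): the per-key lists are still shallow-copied, but each entry is rebuilt as a tuple, which keeps the per-device copies made later by prepare_model_patcher_multigpu_clones from aliasing mutable entries of the source patcher.

    # Sketch only: dummy data, not ComfyUI's real hook/patch structures.
    def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
        new_hook_patches = {}
        for hook_ref in orig_hook_patches:
            new_hook_patches[hook_ref] = {}
            for k in orig_hook_patches[hook_ref]:
                new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
                if copy_tuples:
                    for i in range(len(new_hook_patches[hook_ref][k])):
                        new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
        return new_hook_patches

    orig = {"hook_ref_a": {"diffusion_model.w1": [["strength_patch", 1.0]]}}
    clone = create_hook_patches_clone(orig, copy_tuples=True)
    assert clone["hook_ref_a"]["diffusion_model.w1"][0] == ("strength_patch", 1.0)
    assert clone["hook_ref_a"]["diffusion_model.w1"][0] is not orig["hook_ref_a"]["diffusion_model.w1"][0]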
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index b70e5e63..a95231ff 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -1,7 +1,9 @@
 from __future__ import annotations
+import torch
 import uuid
 import comfy.model_management
 import comfy.conds
+import comfy.model_patcher
 import comfy.utils
 import comfy.hooks
 import comfy.patcher_extension
@@ -127,7 +129,7 @@ def cleanup_models(conds, models):
 
     cleanup_additional_models(set(control_cleanup))
 
-def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
+def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
     '''
     Registers hooks from conds.
     '''
@@ -160,3 +162,17 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
             comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
                                                        copy_dict1=False)
     return to_load_options
+
+def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
+    '''
+    In case multigpu acceleration is enabled, prep ModelPatchers for each device.
+    '''
+    multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_clone]
+    if len(multigpu_patchers) > 0:
+        multigpu_dict: dict[torch.device, ModelPatcher] = {}
+        multigpu_dict[model_patcher.load_device] = model_patcher
+        for x in multigpu_patchers:
+            x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
+            multigpu_dict[x.load_device] = x
+        model_options["multigpu_clones"] = multigpu_dict
+    return multigpu_patchers
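prepare_model_patcher_multigpu_clones pulls the multigpu bookkeeping out of CFGGuider (see the samplers.py hunk below): every loaded patcher flagged as is_multigpu_clone gets its own copy of the main patcher's hook_patches, and model_options["multigpu_clones"] ends up mapping each load device, including the main one, to the patcher that will run on it. A rough sketch of the resulting mapping, using a stand-in class and string device names in place of real ModelPatcher objects and torch.device keys:

    # Stand-in sketch; _StubPatcher and the device strings are illustrative only.
    class _StubPatcher:
        def __init__(self, load_device, is_multigpu_clone=False):
            self.load_device = load_device
            self.is_multigpu_clone = is_multigpu_clone
            self.hook_patches = {}

    main = _StubPatcher("cuda:0")
    loaded_models = [main, _StubPatcher("cuda:1", True), _StubPatcher("cuda:2", True)]

    model_options = {}
    multigpu_patchers = [x for x in loaded_models if x.is_multigpu_clone]
    if len(multigpu_patchers) > 0:
        multigpu_dict = {main.load_device: main}
        for x in multigpu_patchers:
            x.hook_patches = dict(main.hook_patches)  # real code: create_hook_patches_clone(..., copy_tuples=True)
            multigpu_dict[x.load_device] = x
        model_options["multigpu_clones"] = multigpu_dict

    assert sorted(model_options["multigpu_clones"]) == ["cuda:0", "cuda:1", "cuda:2"]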
diff --git a/comfy/samplers.py b/comfy/samplers.py
index e9cd076e..dde0b652 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -232,7 +232,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
     if has_default_conds:
         finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
 
-    model.current_patcher.prepare_state(timestep)
+    model.current_patcher.prepare_state(timestep, model_options)
 
     # run every hooked_to_run separately
     for hooks, to_run in hooked_to_run.items():
@@ -368,39 +368,53 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
     if has_default_conds:
         finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
 
-    model.current_patcher.prepare_state(timestep)
+    model.current_patcher.prepare_state(timestep, model_options)
 
-    devices = [dev_m for dev_m in model_options["multigpu_clones"].keys()]
+    devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
     device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
-    count = 0
+
+    total_conds = 0
+    for to_run in hooked_to_run.values():
+        total_conds += len(to_run)
+    conds_per_device = max(1, math.ceil(total_conds//len(devices)))
+    index_device = 0
+    current_device = devices[index_device]
     # run every hooked_to_run separately
     for hooks, to_run in hooked_to_run.items():
         while len(to_run) > 0:
+            current_device = devices[index_device % len(devices)]
+            batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
+            # keep track of conds currently scheduled onto this device
+            batched_to_run_length = 0
+            for btr in batched_to_run:
+                batched_to_run_length += len(btr[1])
+
             first = to_run[0]
             first_shape = first[0][0].shape
             to_batch_temp = []
+            # make sure not over conds_per_device limit when creating temp batch
            for x in range(len(to_run)):
-                if can_concat_cond(to_run[x][0], first[0]):
+                if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
                    to_batch_temp += [x]
 
            to_batch_temp.reverse()
            to_batch = to_batch_temp[:1]
 
-            current_device = devices[count % len(devices)]
            free_memory = model_management.get_free_memory(current_device)
            for i in range(1, len(to_batch_temp) + 1):
                batch_amount = to_batch_temp[:len(to_batch_temp)//i]
                input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
-                # if model.memory_required(input_shape) * 1.5 < free_memory:
-                #     to_batch = batch_amount
-                #     break
+                if model.memory_required(input_shape) * 1.5 < free_memory:
+                    to_batch = batch_amount
+                    break
            conds_to_batch = []
            for x in to_batch:
                conds_to_batch.append(to_run.pop(x))
-
-            batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
+            batched_to_run_length += len(conds_to_batch)
+
            batched_to_run.append((hooks, conds_to_batch))
-            count += 1
+            if batched_to_run_length >= conds_per_device:
+                index_device += 1
 
     thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond'])
     def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
@@ -1112,13 +1126,7 @@ class CFGGuider:
         self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
         device = self.model_patcher.load_device
 
-        multigpu_patchers: list[ModelPatcher] = [x for x in self.loaded_models if x.is_multigpu_clone]
-        if len(multigpu_patchers) > 0:
-            multigpu_dict: dict[torch.device, ModelPatcher] = {}
-            multigpu_dict[device] = self.model_patcher
-            for x in multigpu_patchers:
-                multigpu_dict[x.load_device] = x
-            self.model_options["multigpu_clones"] = multigpu_dict
+        multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
 
         if denoise_mask is not None:
             denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
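The main behavioral change in _calc_cond_batch_multigpu is that conds are now spread across devices against a per-device quota (conds_per_device) instead of advancing the device on every batch, and the free-memory batching check is re-enabled. Stripped of hooks, can_concat_cond grouping, and memory handling, the quota-based round-robin behaves roughly like the sketch below (names and the plain-string conds are placeholders, not the real (hooks, cond) tuples):

    import math

    # Simplified illustration of the per-device quota round-robin.
    def split_across_devices(conds, devices):
        conds_per_device = max(1, math.ceil(len(conds) / len(devices)))
        assignments = {d: [] for d in devices}
        index_device = 0
        for cond in conds:
            device = devices[index_device % len(devices)]
            assignments[device].append(cond)
            if len(assignments[device]) >= conds_per_device:
                index_device += 1
        return assignments

    print(split_across_devices(["cond_pos", "cond_neg", "cond_pos2"], ["cuda:0", "cuda:1"]))
    # {'cuda:0': ['cond_pos', 'cond_neg'], 'cuda:1': ['cond_pos2']}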
diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py
index dec395fb..b3c8635b 100644
--- a/comfy_extras/nodes_multigpu.py
+++ b/comfy_extras/nodes_multigpu.py
@@ -28,7 +28,7 @@ class MultiGPUInitialize:
             model = model.clone()
             comfy.model_management.unload_all_models()
             for device in extra_devices:
-                device_patcher = model.multigpu_clone(new_load_device=device)
+                device_patcher = model.multigpu_deepclone(new_load_device=device)
                 device_patcher.is_multigpu_clone = True
                 multigpu_models = model.get_additional_models_with_key("multigpu")
                 multigpu_models.append(device_patcher)