diff --git a/comfy/model_management.py b/comfy/model_management.py index 003a89f5..87ad290d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -15,6 +15,7 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ +from __future__ import annotations import psutil import logging @@ -26,6 +27,11 @@ import platform import weakref import gc +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.model_base import BaseModel + class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -330,7 +336,7 @@ def module_size(module): return module_mem class LoadedModel: - def __init__(self, model): + def __init__(self, model: ModelPatcher): self._set_model(model) self.device = model.load_device self.real_model = None @@ -338,7 +344,7 @@ class LoadedModel: self.model_finalizer = None self._patcher_finalizer = None - def _set_model(self, model): + def _set_model(self, model: ModelPatcher): self._model = weakref.ref(model) if model.parent is not None: self._parent_model = weakref.ref(model.parent) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0501f7b3..5465dde6 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -218,6 +218,8 @@ class ModelPatcher: self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed + self.is_multigpu_clone = False + if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -293,6 +295,8 @@ class ModelPatcher: n.is_clip = self.is_clip n.hook_mode = self.hook_mode + n.is_multigpu_clone = self.is_multigpu_clone + for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n diff --git a/comfy/samplers.py b/comfy/samplers.py index 5cc33a7d..f3064000 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -19,6 +19,7 @@ import comfy.patcher_extension import comfy.hooks import scipy.stats import numpy +import threading def get_area_and_mult(conds, x_in, timestep_in): dims = tuple(x_in.shape[2:]) @@ -130,7 +131,7 @@ def can_concat_cond(c1, c2): return cond_equal_size(c1.conditioning, c2.conditioning) -def cond_cat(c_list): +def cond_cat(c_list, device=None): temp = {} for x in c_list: for k in x: @@ -142,6 +143,8 @@ def cond_cat(c_list): for k in temp: conds = temp[k] out[k] = conds[0].concat(conds[1:]) + if device is not None: + out[k] = out[k].to(device) return out @@ -195,7 +198,9 @@ def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Ten ) return executor.execute(model, conds, x_in, timestep, model_options) -def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): +def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + if 'multigpu_clones' in model_options: + return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options) out_conds = [] out_counts = [] # separate conds by matching hooks @@ -329,6 +334,173 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te return out_conds +def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + out_conds = [] + out_counts = [] + # separate conds by matching hooks + hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {} + default_conds = [] + has_default_conds = False + + output_device = x_in.device + + for i in range(len(conds)): + out_conds.append(torch.zeros_like(x_in)) + out_counts.append(torch.ones_like(x_in) * 1e-37) + + cond = conds[i] + default_c = [] + if cond is not None: + for x in cond: + if 'default' in x: + default_c.append(x) + has_default_conds = True + continue + p = comfy.samplers.get_area_and_mult(x, x_in, timestep) + if p is None: + continue + if p.hooks is not None: + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options) + hooked_to_run.setdefault(p.hooks, list()) + hooked_to_run[p.hooks] += [(p, i)] + default_conds.append(default_c) + + if has_default_conds: + finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) + + model.current_patcher.prepare_state(timestep) + + devices = [dev_m for dev_m in model_options["multigpu_clones"].keys()] + device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {} + count = 0 + # run every hooked_to_run separately + for hooks, to_run in hooked_to_run.items(): + while len(to_run) > 0: + first = to_run[0] + first_shape = first[0][0].shape + to_batch_temp = [] + for x in range(len(to_run)): + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] + + to_batch_temp.reverse() + to_batch = to_batch_temp[:1] + + current_device = devices[count % len(devices)] + free_memory = model_management.get_free_memory(current_device) + for i in range(1, len(to_batch_temp) + 1): + batch_amount = to_batch_temp[:len(to_batch_temp)//i] + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + # if model.memory_required(input_shape) * 1.5 < free_memory: + # to_batch = batch_amount + # break + conds_to_batch = [] + for x in to_batch: + conds_to_batch.append(to_run.pop(x)) + + batched_to_run = device_batched_hooked_to_run.setdefault(current_device, []) + batched_to_run.append((hooks, conds_to_batch)) + count += 1 + + thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond']) + def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): + model_current: BaseModel = model_options["multigpu_clones"][device].model + # run every hooked_to_run separately + with torch.no_grad(): + for hooks, to_batch in batch_tuple: + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + uuids = [] + area = [] + control = None + patches = None + for x in to_batch: + o = x + p = o[0] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + uuids.append(p.uuid) + control = p.control + patches = p.patches + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x).to(device) + c = cond_cat(c, device=device) + timestep_ = torch.cat([timestep.to(device)] * batch_chunks) + + transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks) + if 'transformer_options' in model_options: + transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, + model_options['transformer_options'], + copy_dict1=False) + + if patches is not None: + # TODO: replace with merge_nested_dicts function + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + transformer_options["patches"] = cur_patches + else: + transformer_options["patches"] = patches + + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["uuids"] = uuids[:] + transformer_options["sigmas"] = timestep + transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) + + c['transformer_options'] = transformer_options + + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) + + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) + else: + output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks) + results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond)) + + + results: list[thread_result] = [] + threads: list[threading.Thread] = [] + for device, batch_tuple in device_batched_hooked_to_run.items(): + new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results)) + threads.append(new_thread) + new_thread.start() + + for thread in threads: + thread.join() + + for output, mult, area, batch_chunks, cond_or_uncond in results: + for o in range(batch_chunks): + cond_index = cond_or_uncond[o] + a = area[o] + if a is None: + out_conds[cond_index] += output[o] * mult[o] + out_counts[cond_index] += mult[o] + else: + out_c = out_conds[cond_index] + out_cts = out_counts[cond_index] + dims = len(a) // 2 + for i in range(dims): + out_c = out_c.narrow(i + 2, a[i + dims], a[i]) + out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) + out_c += output[o] * mult[o] + out_cts += mult[o] + + for i in range(len(out_conds)): + out_conds[i] /= out_counts[i] + + return out_conds + def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) @@ -940,6 +1112,14 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device + multigpu_patchers: list[ModelPatcher] = [x for x in self.loaded_models if x.is_multigpu_clone] + if len(multigpu_patchers) > 0: + multigpu_dict: dict[torch.device, ModelPatcher] = {} + multigpu_dict[device] = self.model_patcher + for x in multigpu_patchers: + multigpu_dict[x.load_device] = x + self.model_options["multigpu_clones"] = multigpu_dict + if denoise_mask is not None: denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) @@ -950,9 +1130,13 @@ class CFGGuider: try: self.model_patcher.pre_run() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.pre_run() output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: self.model_patcher.cleanup() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.cleanup() comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) del self.inner_model