mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-15 16:13:29 +00:00
Initial proof of concept of giving splitting cond sampling between multiple GPUs
This commit is contained in:
parent
871258aa72
commit
7448f02b7c
@ -15,6 +15,7 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
@ -26,6 +27,11 @@ import platform
|
||||
import weakref
|
||||
import gc
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
@ -330,7 +336,7 @@ def module_size(module):
|
||||
return module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
def __init__(self, model: ModelPatcher):
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
self.real_model = None
|
||||
@ -338,7 +344,7 @@ class LoadedModel:
|
||||
self.model_finalizer = None
|
||||
self._patcher_finalizer = None
|
||||
|
||||
def _set_model(self, model):
|
||||
def _set_model(self, model: ModelPatcher):
|
||||
self._model = weakref.ref(model)
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
|
@ -218,6 +218,8 @@ class ModelPatcher:
|
||||
self.is_clip = False
|
||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||
|
||||
self.is_multigpu_clone = False
|
||||
|
||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
@ -293,6 +295,8 @@ class ModelPatcher:
|
||||
n.is_clip = self.is_clip
|
||||
n.hook_mode = self.hook_mode
|
||||
|
||||
n.is_multigpu_clone = self.is_multigpu_clone
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
@ -19,6 +19,7 @@ import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import scipy.stats
|
||||
import numpy
|
||||
import threading
|
||||
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
dims = tuple(x_in.shape[2:])
|
||||
@ -130,7 +131,7 @@ def can_concat_cond(c1, c2):
|
||||
|
||||
return cond_equal_size(c1.conditioning, c2.conditioning)
|
||||
|
||||
def cond_cat(c_list):
|
||||
def cond_cat(c_list, device=None):
|
||||
temp = {}
|
||||
for x in c_list:
|
||||
for k in x:
|
||||
@ -142,6 +143,8 @@ def cond_cat(c_list):
|
||||
for k in temp:
|
||||
conds = temp[k]
|
||||
out[k] = conds[0].concat(conds[1:])
|
||||
if device is not None:
|
||||
out[k] = out[k].to(device)
|
||||
|
||||
return out
|
||||
|
||||
@ -195,7 +198,9 @@ def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Ten
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
if 'multigpu_clones' in model_options:
|
||||
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@ -329,6 +334,173 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
|
||||
|
||||
return out_conds
|
||||
|
||||
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||
default_conds = []
|
||||
has_default_conds = False
|
||||
|
||||
output_device = x_in.device
|
||||
|
||||
for i in range(len(conds)):
|
||||
out_conds.append(torch.zeros_like(x_in))
|
||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||
|
||||
cond = conds[i]
|
||||
default_c = []
|
||||
if cond is not None:
|
||||
for x in cond:
|
||||
if 'default' in x:
|
||||
default_c.append(x)
|
||||
has_default_conds = True
|
||||
continue
|
||||
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
default_conds.append(default_c)
|
||||
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep)
|
||||
|
||||
devices = [dev_m for dev_m in model_options["multigpu_clones"].keys()]
|
||||
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
||||
count = 0
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
while len(to_run) > 0:
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]):
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
current_device = devices[count % len(devices)]
|
||||
free_memory = model_management.get_free_memory(current_device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
# if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
# to_batch = batch_amount
|
||||
# break
|
||||
conds_to_batch = []
|
||||
for x in to_batch:
|
||||
conds_to_batch.append(to_run.pop(x))
|
||||
|
||||
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
|
||||
batched_to_run.append((hooks, conds_to_batch))
|
||||
count += 1
|
||||
|
||||
thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond'])
|
||||
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||
# run every hooked_to_run separately
|
||||
with torch.no_grad():
|
||||
for hooks, to_batch in batch_tuple:
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
uuids = []
|
||||
area = []
|
||||
control = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = x
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
uuids.append(p.uuid)
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x).to(device)
|
||||
c = cond_cat(c, device=device)
|
||||
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
|
||||
|
||||
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||
model_options['transformer_options'],
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
# TODO: replace with merge_nested_dicts function
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
transformer_options["patches"] = cur_patches
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if control is not None:
|
||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
||||
else:
|
||||
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
|
||||
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
|
||||
|
||||
|
||||
results: list[thread_result] = []
|
||||
threads: list[threading.Thread] = []
|
||||
for device, batch_tuple in device_batched_hooked_to_run.items():
|
||||
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
|
||||
threads.append(new_thread)
|
||||
new_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
for output, mult, area, batch_chunks, cond_or_uncond in results:
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
|
||||
return out_conds
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||
@ -940,6 +1112,14 @@ class CFGGuider:
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
multigpu_patchers: list[ModelPatcher] = [x for x in self.loaded_models if x.is_multigpu_clone]
|
||||
if len(multigpu_patchers) > 0:
|
||||
multigpu_dict: dict[torch.device, ModelPatcher] = {}
|
||||
multigpu_dict[device] = self.model_patcher
|
||||
for x in multigpu_patchers:
|
||||
multigpu_dict[x.load_device] = x
|
||||
self.model_options["multigpu_clones"] = multigpu_dict
|
||||
|
||||
if denoise_mask is not None:
|
||||
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
|
||||
|
||||
@ -950,9 +1130,13 @@ class CFGGuider:
|
||||
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
|
Loading…
Reference in New Issue
Block a user