Initial MultiGPU support for controlnets

This commit is contained in:
Jedrzej Kosinski 2025-01-24 05:39:38 -06:00
parent 5db4277449
commit 46969c380a
2 changed files with 95 additions and 6 deletions

View File

@ -15,13 +15,14 @@
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
from __future__ import annotations
import torch import torch
from enum import Enum from enum import Enum
import math import math
import os import os
import logging import logging
import copy
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
import comfy.model_detection import comfy.model_detection
@ -36,7 +37,7 @@ import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet import comfy.ldm.flux.controlnet
import comfy.cldm.dit_embedder import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Union
if TYPE_CHECKING: if TYPE_CHECKING:
from comfy.hooks import HookGroup from comfy.hooks import HookGroup
@ -76,7 +77,7 @@ class ControlBase:
self.compression_ratio = 8 self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact' self.upscale_algorithm = 'nearest-exact'
self.extra_args = {} self.extra_args = {}
self.previous_controlnet = None self.previous_controlnet: Union[ControlBase, None] = None
self.extra_conds = [] self.extra_conds = []
self.strength_type = StrengthType.CONSTANT self.strength_type = StrengthType.CONSTANT
self.concat_mask = False self.concat_mask = False
@ -84,6 +85,7 @@ class ControlBase:
self.extra_concat = None self.extra_concat = None
self.extra_hooks: HookGroup = None self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a self.preprocess_image = lambda a: a
self.multigpu_clones: dict[torch.device, ControlBase] = {}
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]): def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint self.cond_hint_original = cond_hint
@ -117,10 +119,33 @@ class ControlBase:
def get_models(self):
    """Collect the models of every multigpu clone and of the previous controlnet.

    Models reported by the entries of ``multigpu_clones`` come first (in dict
    order), followed by those of ``previous_controlnet`` when one is set.
    This base implementation contributes no model of its own.
    """
    models = []
    for clone in self.multigpu_clones.values():
        models.extend(clone.get_models())
    previous = self.previous_controlnet
    if previous is not None:
        models.extend(previous.get_models())
    return models
def get_models_only_self(self):
    """Call get_models() with previous_controlnet temporarily set to None.

    Returns whatever get_models() returns while the link to the previous
    controlnet is detached. The original link is restored afterwards, even
    when get_models() raises.
    """
    # Capture and detach the link *before* entering the try block: if reading
    # the attribute fails, the finally clause must not run and hide the real
    # error behind an UnboundLocalError.
    orig_previous_controlnet = self.previous_controlnet
    self.previous_controlnet = None
    try:
        return self.get_models()
    finally:
        self.previous_controlnet = orig_previous_controlnet
def get_instance_for_device(self, device):
    """Return the Control object registered for ``device`` in multigpu_clones.

    Falls back to this object itself when no clone is registered for the device.
    """
    clones = self.multigpu_clones
    return clones[device] if device in clones else self
def deepclone_multigpu(self, load_device, autoregister=False):
    '''
    Create deep clone of Control object where model(s) is set to other devices.
    When autoregister is set to True, the deep clone is also added to multigpu_clones dict.

    This base implementation is a placeholder that subclasses must override.

    Args:
        load_device: device the clone's model(s) should be set to.
        autoregister: when True, store the clone in multigpu_clones under load_device.

    Raises:
        NotImplementedError: always, in this base implementation.
    '''
    raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu function.")
def get_extra_hooks(self): def get_extra_hooks(self):
out = [] out = []
if self.extra_hooks is not None: if self.extra_hooks is not None:
@ -129,7 +154,7 @@ class ControlBase:
out += self.previous_controlnet.get_extra_hooks() out += self.previous_controlnet.get_extra_hooks()
return out return out
def copy_to(self, c): def copy_to(self, c: ControlBase):
c.cond_hint_original = self.cond_hint_original c.cond_hint_original = self.cond_hint_original
c.strength = self.strength c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range c.timestep_percent_range = self.timestep_percent_range
@ -280,6 +305,14 @@ class ControlNet(ControlBase):
self.copy_to(c) self.copy_to(c)
return c return c
def deepclone_multigpu(self, load_device, autoregister=False):
    """Create a copy of this ControlNet whose model is wrapped for ``load_device``.

    The copy comes from self.copy(); its control_model is replaced by a deep
    copy, which is then wrapped in a new ModelPatcher that loads onto
    ``load_device`` and offloads to comfy.model_management.unet_offload_device().

    Args:
        load_device: device passed to the new ModelPatcher as its load device.
        autoregister: when True, also store the clone in self.multigpu_clones
            under ``load_device``.

    Returns:
        The newly created clone.
    """
    clone = self.copy()
    clone.control_model = copy.deepcopy(clone.control_model)
    clone.control_model_wrapped = comfy.model_patcher.ModelPatcher(
        clone.control_model,
        load_device=load_device,
        offload_device=comfy.model_management.unet_offload_device(),
    )
    if autoregister:
        self.multigpu_clones[load_device] = clone
    return clone
def get_models(self): def get_models(self):
out = super().get_models() out = super().get_models()
out.append(self.control_model_wrapped) out.append(self.control_model_wrapped)
@ -810,6 +843,14 @@ class T2IAdapter(ControlBase):
self.copy_to(c) self.copy_to(c)
return c return c
def deepclone_multigpu(self, load_device, autoregister=False):
    """Create a copy of this adapter with its own t2i_model, targeted at ``load_device``.

    The copy comes from self.copy(); its t2i_model is replaced by a deep copy
    and its ``device`` attribute is set to ``load_device``.

    Args:
        load_device: device assigned to the clone's ``device`` attribute.
        autoregister: when True, also store the clone in self.multigpu_clones
            under ``load_device``.

    Returns:
        The newly created clone.
    """
    clone = self.copy()
    clone.t2i_model = copy.deepcopy(clone.t2i_model)
    clone.device = load_device
    if not autoregister:
        return clone
    self.multigpu_clones[load_device] = clone
    return clone
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8 compression_ratio = 8
upscale_algorithm = 'nearest-exact' upscale_algorithm = 'nearest-exact'

View File

@ -1,4 +1,6 @@
from __future__ import annotations from __future__ import annotations
import comfy.model_management
from .k_diffusion import sampling as k_diffusion_sampling from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc from .extra_samplers import uni_pc
from typing import TYPE_CHECKING, Callable, NamedTuple from typing import TYPE_CHECKING, Callable, NamedTuple
@ -427,7 +429,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
cond_or_uncond = [] cond_or_uncond = []
uuids = [] uuids = []
area = [] area = []
control = None control: ControlBase = None
patches = None patches = None
for x in to_batch: for x in to_batch:
o = x o = x
@ -473,7 +475,8 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
c['transformer_options'] = transformer_options c['transformer_options'] = transformer_options
if control is not None: if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) device_control = control.get_instance_for_device(device)
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options: if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
@ -799,6 +802,8 @@ def pre_run_control(model, conds):
percent_to_timestep_function = lambda a: s.percent_to_sigma(a) percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x: if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function) x['control'].pre_run(model, percent_to_timestep_function)
for device_cnet in x['control'].multigpu_clones.values():
device_cnet.pre_run(model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = [] cond_cnets = []
@ -1080,6 +1085,48 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
wc_list[i] = wc_list[i].to(cast) wc_list[i] = wc_list[i].to(cast)
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model_options: dict[str], model: ModelPatcher):
    """Give every ControlNet referenced by conds a deep clone per extra multigpu device.

    Does nothing when the model has no additional models under the "multigpu"
    key. Otherwise, for each controlnet found under a 'control' entry in conds:

    1. every model it reports via get_models() is unloaded,
    2. each controlnet in its previous_controlnet chain gets a deep clone
       registered for every extra device that does not have one yet,
    3. the clones are linked so that each device's clone points at the same
       device's clone of the previous controlnet.

    Args:
        conds: mapping of names to lists of cond dicts.
        model_options: not used by this function at present.
        model: patcher queried for additional "multigpu" models; their
            load_device values are the extra devices.

    GLIGEN is not handled yet (see TODO at the end).
    """
    multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
    if not multigpu_models:
        return
    extra_devices = [patcher.load_device for patcher in multigpu_models]

    # handle controlnets
    controlnets: set[ControlBase] = {
        cond['control']
        for cond_list in conds.values()
        for cond in cond_list
        if 'control' in cond
    }
    if controlnets:
        # first, unload all controlnet clones
        for head in controlnets:
            for cnet_model in head.get_models():
                comfy.model_management.unload_model_and_clones(cnet_model, unload_additional_models=True)

        # next, make sure each controlnet has a deepclone for all relevant devices
        for head in controlnets:
            node = head
            while node is not None:
                for device in extra_devices:
                    # checked per device (not precomputed) so a repeated device is cloned only once
                    if device not in node.multigpu_clones:
                        node.deepclone_multigpu(device, autoregister=True)
                node = node.previous_controlnet

        # since all device clones are now present, recreate the linked list for cloned cnets per device
        for head in controlnets:
            node = head
            while node is not None:
                # read the link before relinking anything for this node
                previous = node.previous_controlnet
                for device in extra_devices:
                    device_node = node.get_instance_for_device(device)
                    device_previous = previous.get_instance_for_device(device) if previous is not None else None
                    device_node.set_previous_controlnet(device_previous)
                node = previous
    # TODO: handle gligen
class CFGGuider: class CFGGuider:
def __init__(self, model_patcher: ModelPatcher): def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher self.model_patcher = model_patcher
@ -1122,6 +1169,7 @@ class CFGGuider:
return self.inner_model.process_latent_out(samples.to(torch.float32)) return self.inner_model.process_latent_out(samples.to(torch.float32))
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
preprocess_multigpu_conds(self.conds, self.model_options, self.model_patcher)
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device device = self.model_patcher.load_device