mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-18 01:53:31 +00:00
Initial MultiGPU support for controlnets
This commit is contained in:
parent
5db4277449
commit
46969c380a
@ -15,13 +15,14 @@
|
|||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
import copy
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_detection
|
import comfy.model_detection
|
||||||
@ -36,7 +37,7 @@ import comfy.cldm.mmdit
|
|||||||
import comfy.ldm.hydit.controlnet
|
import comfy.ldm.hydit.controlnet
|
||||||
import comfy.ldm.flux.controlnet
|
import comfy.ldm.flux.controlnet
|
||||||
import comfy.cldm.dit_embedder
|
import comfy.cldm.dit_embedder
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Union
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from comfy.hooks import HookGroup
|
from comfy.hooks import HookGroup
|
||||||
|
|
||||||
@ -76,7 +77,7 @@ class ControlBase:
|
|||||||
self.compression_ratio = 8
|
self.compression_ratio = 8
|
||||||
self.upscale_algorithm = 'nearest-exact'
|
self.upscale_algorithm = 'nearest-exact'
|
||||||
self.extra_args = {}
|
self.extra_args = {}
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet: Union[ControlBase, None] = None
|
||||||
self.extra_conds = []
|
self.extra_conds = []
|
||||||
self.strength_type = StrengthType.CONSTANT
|
self.strength_type = StrengthType.CONSTANT
|
||||||
self.concat_mask = False
|
self.concat_mask = False
|
||||||
@ -84,6 +85,7 @@ class ControlBase:
|
|||||||
self.extra_concat = None
|
self.extra_concat = None
|
||||||
self.extra_hooks: HookGroup = None
|
self.extra_hooks: HookGroup = None
|
||||||
self.preprocess_image = lambda a: a
|
self.preprocess_image = lambda a: a
|
||||||
|
self.multigpu_clones: dict[torch.device, ControlBase] = {}
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
@ -117,10 +119,33 @@ class ControlBase:
|
|||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
out = []
|
out = []
|
||||||
|
for device_cnet in self.multigpu_clones.values():
|
||||||
|
out += device_cnet.get_models()
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
out += self.previous_controlnet.get_models()
|
out += self.previous_controlnet.get_models()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def get_models_only_self(self):
|
||||||
|
'Calls get_models, but temporarily sets previous_controlnet to None.'
|
||||||
|
try:
|
||||||
|
orig_previous_controlnet = self.previous_controlnet
|
||||||
|
self.previous_controlnet = None
|
||||||
|
return self.get_models()
|
||||||
|
finally:
|
||||||
|
self.previous_controlnet = orig_previous_controlnet
|
||||||
|
|
||||||
|
def get_instance_for_device(self, device):
|
||||||
|
'Returns instance of this Control object intended for selected device.'
|
||||||
|
return self.multigpu_clones.get(device, self)
|
||||||
|
|
||||||
|
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||||
|
'''
|
||||||
|
Create deep clone of Control object where model(s) is set to other devices.
|
||||||
|
|
||||||
|
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
|
||||||
|
'''
|
||||||
|
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
|
||||||
|
|
||||||
def get_extra_hooks(self):
|
def get_extra_hooks(self):
|
||||||
out = []
|
out = []
|
||||||
if self.extra_hooks is not None:
|
if self.extra_hooks is not None:
|
||||||
@ -129,7 +154,7 @@ class ControlBase:
|
|||||||
out += self.previous_controlnet.get_extra_hooks()
|
out += self.previous_controlnet.get_extra_hooks()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def copy_to(self, c):
|
def copy_to(self, c: ControlBase):
|
||||||
c.cond_hint_original = self.cond_hint_original
|
c.cond_hint_original = self.cond_hint_original
|
||||||
c.strength = self.strength
|
c.strength = self.strength
|
||||||
c.timestep_percent_range = self.timestep_percent_range
|
c.timestep_percent_range = self.timestep_percent_range
|
||||||
@ -280,6 +305,14 @@ class ControlNet(ControlBase):
|
|||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||||
|
c = self.copy()
|
||||||
|
c.control_model = copy.deepcopy(c.control_model)
|
||||||
|
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||||
|
if autoregister:
|
||||||
|
self.multigpu_clones[load_device] = c
|
||||||
|
return c
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
out = super().get_models()
|
out = super().get_models()
|
||||||
out.append(self.control_model_wrapped)
|
out.append(self.control_model_wrapped)
|
||||||
@ -810,6 +843,14 @@ class T2IAdapter(ControlBase):
|
|||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||||
|
c = self.copy()
|
||||||
|
c.t2i_model = copy.deepcopy(c.t2i_model)
|
||||||
|
c.device = load_device
|
||||||
|
if autoregister:
|
||||||
|
self.multigpu_clones[load_device] = c
|
||||||
|
return c
|
||||||
|
|
||||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||||
compression_ratio = 8
|
compression_ratio = 8
|
||||||
upscale_algorithm = 'nearest-exact'
|
upscale_algorithm = 'nearest-exact'
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
from .k_diffusion import sampling as k_diffusion_sampling
|
from .k_diffusion import sampling as k_diffusion_sampling
|
||||||
from .extra_samplers import uni_pc
|
from .extra_samplers import uni_pc
|
||||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||||
@ -427,7 +429,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
|
|||||||
cond_or_uncond = []
|
cond_or_uncond = []
|
||||||
uuids = []
|
uuids = []
|
||||||
area = []
|
area = []
|
||||||
control = None
|
control: ControlBase = None
|
||||||
patches = None
|
patches = None
|
||||||
for x in to_batch:
|
for x in to_batch:
|
||||||
o = x
|
o = x
|
||||||
@ -473,7 +475,8 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
|
|||||||
c['transformer_options'] = transformer_options
|
c['transformer_options'] = transformer_options
|
||||||
|
|
||||||
if control is not None:
|
if control is not None:
|
||||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
device_control = control.get_instance_for_device(device)
|
||||||
|
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||||
|
|
||||||
if 'model_function_wrapper' in model_options:
|
if 'model_function_wrapper' in model_options:
|
||||||
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
||||||
@ -799,6 +802,8 @@ def pre_run_control(model, conds):
|
|||||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||||
if 'control' in x:
|
if 'control' in x:
|
||||||
x['control'].pre_run(model, percent_to_timestep_function)
|
x['control'].pre_run(model, percent_to_timestep_function)
|
||||||
|
for device_cnet in x['control'].multigpu_clones.values():
|
||||||
|
device_cnet.pre_run(model, percent_to_timestep_function)
|
||||||
|
|
||||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||||
cond_cnets = []
|
cond_cnets = []
|
||||||
@ -1080,6 +1085,48 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
|||||||
wc_list[i] = wc_list[i].to(cast)
|
wc_list[i] = wc_list[i].to(cast)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model_options: dict[str], model: ModelPatcher):
|
||||||
|
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
|
||||||
|
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
|
||||||
|
if len(multigpu_models) == 0:
|
||||||
|
return
|
||||||
|
extra_devices = [x.load_device for x in multigpu_models]
|
||||||
|
# handle controlnets
|
||||||
|
controlnets: set[ControlBase] = set()
|
||||||
|
for k in conds:
|
||||||
|
for kk in conds[k]:
|
||||||
|
if 'control' in kk:
|
||||||
|
controlnets.add(kk['control'])
|
||||||
|
if len(controlnets) > 0:
|
||||||
|
# first, unload all controlnet clones
|
||||||
|
for cnet in list(controlnets):
|
||||||
|
cnet_models = cnet.get_models()
|
||||||
|
for cm in cnet_models:
|
||||||
|
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
|
||||||
|
|
||||||
|
# next, make sure each controlnet has a deepclone for all relevant devices
|
||||||
|
for cnet in controlnets:
|
||||||
|
curr_cnet = cnet
|
||||||
|
while curr_cnet is not None:
|
||||||
|
for device in extra_devices:
|
||||||
|
if device not in curr_cnet.multigpu_clones:
|
||||||
|
curr_cnet.deepclone_multigpu(device, autoregister=True)
|
||||||
|
curr_cnet = curr_cnet.previous_controlnet
|
||||||
|
# since all device clones are now present, recreate the linked list for cloned cnets per device
|
||||||
|
for cnet in controlnets:
|
||||||
|
curr_cnet = cnet
|
||||||
|
while curr_cnet is not None:
|
||||||
|
prev_cnet = curr_cnet.previous_controlnet
|
||||||
|
for device in extra_devices:
|
||||||
|
device_cnet = curr_cnet.get_instance_for_device(device)
|
||||||
|
prev_device_cnet = None
|
||||||
|
if prev_cnet is not None:
|
||||||
|
prev_device_cnet = prev_cnet.get_instance_for_device(device)
|
||||||
|
device_cnet.set_previous_controlnet(prev_device_cnet)
|
||||||
|
curr_cnet = prev_cnet
|
||||||
|
# TODO: handle gligen
|
||||||
|
|
||||||
|
|
||||||
class CFGGuider:
|
class CFGGuider:
|
||||||
def __init__(self, model_patcher: ModelPatcher):
|
def __init__(self, model_patcher: ModelPatcher):
|
||||||
self.model_patcher = model_patcher
|
self.model_patcher = model_patcher
|
||||||
@ -1122,6 +1169,7 @@ class CFGGuider:
|
|||||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||||
|
|
||||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||||
|
preprocess_multigpu_conds(self.conds, self.model_options, self.model_patcher)
|
||||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user