mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-13 14:57:43 +00:00
168 lines
7.4 KiB
Python
168 lines
7.4 KiB
Python
from __future__ import annotations
|
|
import torch
|
|
import logging
|
|
|
|
from collections import namedtuple
|
|
from typing import TYPE_CHECKING
|
|
if TYPE_CHECKING:
|
|
from comfy.model_patcher import ModelPatcher
|
|
import comfy.utils
|
|
import comfy.patcher_extension
|
|
import comfy.model_management
|
|
|
|
|
|
class GPUOptions:
    '''Settings attached to a single device, identified by its device index.'''

    def __init__(self, device_index: int, relative_speed: float):
        # index of the device these options belong to
        self.device_index = device_index
        # speed weight of this device; compared against the other devices' values
        self.relative_speed = relative_speed

    def clone(self):
        '''Return a new GPUOptions carrying the same device_index and relative_speed.'''
        return GPUOptions(device_index=self.device_index, relative_speed=self.relative_speed)

    def create_dict(self):
        '''Return a fresh dict of the options; only relative_speed is included.'''
        return dict(relative_speed=self.relative_speed)
|
|
|
|
class GPUOptionsGroup:
    '''Collection of per-device options, keyed by device index.'''

    def __init__(self):
        # device index -> options registered for that device
        self.options: dict[int, GPUOptions] = {}

    def add(self, info: GPUOptions):
        '''Store info under its device_index, replacing any existing entry for that index.'''
        self.options[info.device_index] = info

    def clone(self):
        '''
        Return a new group holding copies of every stored option.

        Each entry is cloned as well, so changing an option on the returned
        group does not alter the options held by this group.
        '''
        c = GPUOptionsGroup()
        for opt in self.options.values():
            c.add(opt.clone())
        return c

    def register(self, model: ModelPatcher):
        '''
        Write normalized per-device options into model.model_options['multigpu_options'].

        The devices considered are model.load_device plus the load_device of every
        additional model stored under the "multigpu" key. Devices without an entry
        in this group fall back to a relative_speed of 1.0. The stored relative_speed
        values are divided by the smallest one, so the slowest device ends up at 1.0.
        '''
        # get devices that are valid for this model
        devices: list[torch.device] = [model.load_device]
        for extra_model in model.get_additional_models_with_key("multigpu"):
            devices.append(extra_model.load_device)
        # create dictionary with actual device mapped to its options dict
        opts_dict = {}
        speeds = []
        for device in devices:
            device_opts = self.options.get(device.index)
            if device_opts is None:
                # no explicit options for this device; use the neutral default
                device_opts = GPUOptions(device_index=device.index, relative_speed=1.0)
            opts_dict[device] = device_opts.create_dict()
            speeds.append(device_opts.relative_speed)
        # make relative_speed relative to 1.0
        min_speed = min(speeds)
        for value in opts_dict.values():
            value['relative_speed'] /= min_speed
        model.model_options['multigpu_options'] = opts_dict
|
|
|
|
|
|
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
    '''
    Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.

    Works on a clone of model and returns that clone. For every extra torch device
    that is within the max_gpus limit and does not already have a "multigpu"
    additional model, a patcher for that device is appended to the clone's
    "multigpu" additional models.

    max_gpus: upper bound on devices to use; one slot is taken by the current
        device, so at most max_gpus-1 extra devices are considered. Values of 1
        or less select no extra devices.
    gpu_options: options group to register on the model once new clones were
        added; a default GPUOptionsGroup is used when None.
    reuse_loaded: when True, a currently loaded ModelPatcher with the same
        clone_base_uuid on the target device is cloned instead of creating a
        new deepclone.
    '''
    model = model.clone()
    # check if multigpu is already prepared - get the load devices from them if possible to exclude
    skip_devices = {mm.load_device for mm in model.get_additional_models_with_key("multigpu")}

    full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
    # clamp the bound: for max_gpus <= 0 a plain max_gpus-1 would be a negative
    # slice end and select all but the last device(s) instead of none
    limit_extra_devices = full_extra_devices[:max(0, max_gpus - 1)]
    # exclude skipped devices
    extra_devices = [device for device in limit_extra_devices if device not in skip_devices]
    # create new deepclones
    if extra_devices:
        for device in extra_devices:
            device_patcher = None
            if reuse_loaded:
                # check if there are any ModelPatchers currently loaded that could be referenced here after a clone
                loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
                for lm in loaded_models:
                    if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
                        device_patcher = lm.clone()
                        logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
                        break
            if device_patcher is None:
                device_patcher = model.deepclone_multigpu(new_load_device=device)
                device_patcher.is_multigpu_base_clone = True
            multigpu_models = model.get_additional_models_with_key("multigpu")
            multigpu_models.append(device_patcher)
            model.set_additional_models("multigpu", multigpu_models)
        model.match_multigpu_clones()
        if gpu_options is None:
            gpu_options = GPUOptionsGroup()
        gpu_options.register(model)
    else:
        logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
    # TODO: only keep model clones that don't go 'past' the intended max_gpu count,
    #       i.e. drop "multigpu" additional models whose load_device is not in limit_extra_devices
    # TODO: persist skip_devices in model.model_options["multigpu_skip_devices"] for use in sampling code
    return model
|
|
|
|
|
|
# Result of a load-balancing pass: the work assigned to each device and the idle time (None when not computed).
LoadBalance = namedtuple('LoadBalance', 'work_per_device idle_time')
|
|
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
    '''
    Optimize work assigned to different devices, accounting for their relative speeds and splittable work.

    model_options: must contain 'multigpu_options' (device -> dict with 'relative_speed')
        and 'multigpu_clones' (its keys are the devices that receive work).
    total_work: number of whole work units to split across the devices.
    return_idle_time: when True, also estimate the relative time devices spend
        waiting on each other; otherwise idle_time in the result is None.
    work_normalized: if given (and total_work is nonzero), idle_time is scaled by
        work_normalized/total_work so results for different total_work are comparable.

    Returns a LoadBalance(work_per_device, idle_time), where work_per_device maps
    each device to its whole-number share of total_work.
    '''
    opts_dict = model_options['multigpu_options']
    devices = list(model_options['multigpu_clones'].keys())
    speed_per_device = [opts_dict[device]['relative_speed'] for device in devices]
    # sum the speeds of the devices that actually receive work; including options of
    # devices without a clone would leave part of total_work unassigned
    total_speed = sum(speed_per_device, 0.0)
    # get relative work for each device;
    # obtained by w = (W*r)/R
    work_per_device = [(total_work * relative_speed) / total_speed for relative_speed in speed_per_device]
    # relative work must be expressed in whole numbers, but likely is a decimal;
    # perform rounding while maintaining total sum equal to total work (sum of relative works)
    work_per_device = round_preserved(work_per_device)
    dict_work_per_device = dict(zip(devices, work_per_device))
    if not return_idle_time:
        return LoadBalance(dict_work_per_device, None)
    # divide relative work by relative speed to get estimated completion time of said work by each device;
    # time here is relative and does not correspond to real-world units
    completion_time = [w / r for w, r in zip(work_per_device, speed_per_device)]
    # calculate relative time spent by the devices waiting on each other after their work is completed
    idle_time = max(completion_time) - min(completion_time)
    # if need to compare work idle time, need to normalize to a common total work;
    # skipped for total_work == 0, where there is nothing to scale and the division would fail
    if work_normalized and total_work:
        idle_time *= (work_normalized / total_work)

    return LoadBalance(dict_work_per_device, idle_time)
|
|
|
|
def round_preserved(values: list[float]):
    '''
    Round all values in a list, preserving the combined sum of values.

    Every value is floored, then the difference between the rounded total and the
    sum of the floors is handed out one unit at a time to the values with the
    largest fractional parts (ties go to the earlier value in the list).

    Returns a new list of ints in the same order as values; an empty list yields
    an empty list.
    '''
    import math  # imported locally so this function does not rely on a module-level import

    # true floor; int() truncates toward zero, which is wrong for negative values
    floored = [math.floor(x) for x in values]
    # get remainder to distribute
    remainder = round(sum(values)) - sum(floored)
    # order indices by fractional part, largest first; sorted() is stable, so ties keep list order
    by_fraction = sorted(range(len(values)), key=lambda i: values[i] - floored[i], reverse=True)
    # distribute the remainder
    for index in by_fraction[:max(remainder, 0)]:
        floored[index] += 1
    return floored
|