mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-17 01:13:34 +00:00
Initial exploration of weight zipper
This commit is contained in:
parent
3b19fc76e3
commit
c8037ab667
@ -16,6 +16,7 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import logging
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
@ -104,7 +105,7 @@ class BaseModel(torch.nn.Module):
|
||||
self.model_config = model_config
|
||||
self.manual_cast_dtype = model_config.manual_cast_dtype
|
||||
self.device = device
|
||||
self.current_patcher: 'ModelPatcher' = None
|
||||
self.current_patcher: ModelPatcher = None
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
@ -128,6 +129,7 @@ class BaseModel(torch.nn.Module):
|
||||
logging.info("model_type {}".format(model_type.name))
|
||||
logging.debug("adm {}".format(self.adm_channels))
|
||||
self.memory_usage_factor = model_config.memory_usage_factor
|
||||
self.zipper_initialized = False
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
@ -137,6 +139,16 @@ class BaseModel(torch.nn.Module):
|
||||
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||
|
||||
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
# handle lowvram zipper initialization, if required
|
||||
if self.model_lowvram and not self.zipper_initialized:
|
||||
if self.current_patcher:
|
||||
self.zipper_initialized = True
|
||||
with self.current_patcher.use_ejected():
|
||||
loading = self.current_patcher._load_list_lowvram_only()
|
||||
|
||||
return self._apply_model_inner(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||
|
||||
def _apply_model_inner(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
xc = self.model_sampling.calculate_input(sigma, x)
|
||||
if c_concat is not None:
|
||||
|
@ -17,7 +17,7 @@
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Callable
|
||||
from typing import Optional, Callable, TYPE_CHECKING
|
||||
import torch
|
||||
import copy
|
||||
import inspect
|
||||
@ -26,6 +26,7 @@ import uuid
|
||||
import collections
|
||||
import math
|
||||
|
||||
import comfy.ops
|
||||
import comfy.utils
|
||||
import comfy.float
|
||||
import comfy.model_management
|
||||
@ -34,6 +35,9 @@ import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
|
||||
|
||||
def string_to_seed(data):
|
||||
crc = 0xFFFFFFFF
|
||||
@ -201,7 +205,7 @@ class MemoryCounter:
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.model: BaseModel = model
|
||||
if not hasattr(self.model, 'device'):
|
||||
logging.debug("Model doesn't have a device attribute.")
|
||||
self.model.device = offload_device
|
||||
@ -568,6 +572,14 @@ class ModelPatcher:
|
||||
else:
|
||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||
|
||||
def _zipper_dict_lowvram_only(self):
|
||||
loading = self._load_list_lowvram_only()
|
||||
|
||||
|
||||
def _load_list_lowvram_only(self):
|
||||
loading = self._load_list()
|
||||
return [x for x in loading if hasattr(x[2], "prev_comfy_cast_weights")]
|
||||
|
||||
def _load_list(self):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
@ -583,6 +595,35 @@ class ModelPatcher:
|
||||
loading.append((comfy.model_management.module_size(m), n, m, params))
|
||||
return loading
|
||||
|
||||
def prepare_teeth(self):
|
||||
ordered_list = self._load_list_lowvram_only()
|
||||
prev_i = None
|
||||
next_i = None
|
||||
# first, create teeth on modules in list
|
||||
for l in ordered_list:
|
||||
m: comfy.ops.CastWeightBiasOp = l[2]
|
||||
m.init_tooth(self.load_device, self.offload_device, l[1])
|
||||
# create teeth linked list
|
||||
for i in range(len(ordered_list)):
|
||||
if i+1 == len(ordered_list):
|
||||
next_i = None
|
||||
else:
|
||||
next_i = i+1
|
||||
m: comfy.ops.CastWeightBiasOp = ordered_list[i][2]
|
||||
if prev_i is not None:
|
||||
m.zipper_tooth.prev_tooth = ordered_list[prev_i][2].zipper_tooth
|
||||
else:
|
||||
m.zipper_tooth.start = True
|
||||
if next_i is not None:
|
||||
m.zipper_tooth.next_tooth = ordered_list[next_i][2].zipper_tooth
|
||||
prev_i = i
|
||||
|
||||
def clean_teeth(self):
|
||||
ordered_list = self._load_list_lowvram_only()
|
||||
for l in ordered_list:
|
||||
m: comfy.ops.CastWeightBiasOp = l[2]
|
||||
m.clean_tooth()
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
@ -591,6 +632,8 @@ class ModelPatcher:
|
||||
lowvram_counter = 0
|
||||
loading = self._load_list()
|
||||
|
||||
logging.info(f"total size of _load_list: {sum([x[0] for x in loading])}")
|
||||
|
||||
load_completely = []
|
||||
loading.sort(reverse=True)
|
||||
for x in loading:
|
||||
@ -672,6 +715,7 @@ class ModelPatcher:
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
self.model.zipper_initialized = False
|
||||
else:
|
||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
self.model.model_lowvram = False
|
||||
@ -684,6 +728,9 @@ class ModelPatcher:
|
||||
self.model.model_loaded_weight_memory = mem_counter
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
|
||||
if self.model.model_lowvram:
|
||||
self.prepare_teeth()
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
|
||||
|
||||
@ -715,6 +762,7 @@ class ModelPatcher:
|
||||
move_weight_functions(m, device_to)
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
self.clean_teeth()
|
||||
self.model.model_lowvram = False
|
||||
self.model.lowvram_patch_counter = 0
|
||||
|
||||
@ -804,8 +852,10 @@ class ModelPatcher:
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
self.model.model_lowvram = True
|
||||
self.model.zipper_initialized = False
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.model_loaded_weight_memory -= memory_freed
|
||||
self.prepare_teeth()
|
||||
return memory_freed
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
|
120
comfy/ops.py
120
comfy/ops.py
@ -16,6 +16,7 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import logging
|
||||
import comfy.model_management
|
||||
@ -56,6 +57,79 @@ class CastWeightBiasOp:
|
||||
comfy_cast_weights = False
|
||||
weight_function = []
|
||||
bias_function = []
|
||||
zipper_init: dict = None
|
||||
zipper_tooth: ZipperTooth = None
|
||||
_zipper_tooth: ZipperTooth = None
|
||||
|
||||
def init_tooth(self, load_device, offload_device, key: str=None):
|
||||
if self.zipper_tooth:
|
||||
self.clean_tooth()
|
||||
self.zipper_tooth = ZipperTooth(self, load_device, offload_device, key)
|
||||
|
||||
def clean_tooth(self):
|
||||
if self.zipper_tooth:
|
||||
del self.zipper_tooth
|
||||
self.zipper_tooth = None
|
||||
|
||||
def connect_teeth(self):
|
||||
if self.zipper_init is not None:
|
||||
|
||||
self.zipper_init[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
|
||||
self.zipper_dict["prev_zipper_key"] = self.zipper_key
|
||||
|
||||
# def zipper_connect(self):
|
||||
# if self.zipper_dict is not None:
|
||||
# self.zipper_dict[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
|
||||
# self.zipper_dict["prev_zipper_key"] = self.zipper_key
|
||||
|
||||
class ZipperTooth:
|
||||
def __init__(self, op: CastWeightBiasOp, load_device, offload_device, key: str=None):
|
||||
self.op = op
|
||||
self.key: str = key
|
||||
self.weight_preloaded: torch.Tensor = None
|
||||
self.bias_preloaded: torch.Tensor = None
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
self.start = False
|
||||
|
||||
self.prev_tooth: ZipperTooth = None
|
||||
self.next_tooth: ZipperTooth = None
|
||||
|
||||
def get_bias_weight(self, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
|
||||
try:
|
||||
if self.start:
|
||||
return cast_bias_weight(self.op, input, dtype, device, bias_dtype)
|
||||
return self.weight_preloaded, self.bias_preloaded
|
||||
finally:
|
||||
# if self.prev_tooth:
|
||||
# self.prev_tooth.offload_previous(0)
|
||||
self.next_tooth.preload_next(0, input, dtype, device, bias_dtype)
|
||||
|
||||
def preload_next(self, teeth_count=1, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
|
||||
# TODO: queue load of tensors
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(self.load_device)
|
||||
|
||||
if self.op.bias is not None:
|
||||
self.bias_preloaded = comfy.model_management.cast_to(self.op.bias, bias_dtype, device, non_blocking=non_blocking)
|
||||
|
||||
self.weight_preloaded = comfy.model_management.cast_to(self.op.weight, dtype, device, non_blocking=non_blocking)
|
||||
if self.next_tooth and teeth_count:
|
||||
self.next_tooth.preload_next(teeth_count-1)
|
||||
|
||||
def offload_previous(self, teeth_count=1):
|
||||
# TODO: queue offload of tensors
|
||||
self.weight_preloaded = None
|
||||
self.bias_preloaded = None
|
||||
if self.prev_tooth and teeth_count:
|
||||
self.prev_tooth.offload_previous(teeth_count-1)
|
||||
|
||||
class disable_weight_init:
|
||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||
@ -63,7 +137,11 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
#if self.zipper_init:
|
||||
if self.zipper_tooth:
|
||||
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||
else:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@ -77,7 +155,10 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
if self.zipper_tooth:
|
||||
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||
else:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@ -91,7 +172,10 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
if self.zipper_tooth:
|
||||
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||
else:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@ -105,7 +189,10 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
if self.zipper_tooth:
|
||||
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||
else:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@ -119,7 +206,10 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
if self.zipper_tooth:
|
||||
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||
else:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@ -134,7 +224,10 @@ class disable_weight_init:
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if self.weight is not None:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
if self.zipper_tooth:
|
||||
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||
else:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
else:
|
||||
weight = None
|
||||
bias = None
|
||||
@ -156,7 +249,10 @@ class disable_weight_init:
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
if self.zipper_tooth:
|
||||
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||
else:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose2d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
@ -177,7 +273,10 @@ class disable_weight_init:
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
if self.zipper_tooth:
|
||||
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||
else:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose1d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
@ -197,7 +296,10 @@ class disable_weight_init:
|
||||
output_dtype = out_dtype
|
||||
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
|
||||
out_dtype = None
|
||||
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
|
||||
if self.zipper_tooth:
|
||||
weight, bias = self.zipper_tooth.get_bias_weight(device=input.device, dtype=out_dtype)
|
||||
else:
|
||||
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
|
||||
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
|
@ -6,6 +6,7 @@ if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
from comfy.ops import CastWeightBiasOp
|
||||
import torch
|
||||
from functools import partial
|
||||
import collections
|
||||
@ -18,6 +19,7 @@ import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import scipy.stats
|
||||
import numpy
|
||||
import comfy.ops
|
||||
|
||||
|
||||
def add_area_dims(area, num_dims):
|
||||
@ -360,15 +362,38 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns denoised
|
||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def sampling_function(model: BaseModel, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
|
||||
uncond_ = None
|
||||
else:
|
||||
uncond_ = uncond
|
||||
|
||||
do_cleanup = False
|
||||
if "weight_zipper" not in model_options:
|
||||
do_cleanup = True
|
||||
#zipper_dict = {}
|
||||
model_options["weight_zipper"] = True
|
||||
loaded_modules = model.current_patcher._load_list_lowvram_only()
|
||||
low_m = [x for x in loaded_modules if hasattr(x[2], "prev_comfy_cast_weights")]
|
||||
sum_m = sum([x[0] for x in low_m])
|
||||
for l in loaded_modules:
|
||||
m: CastWeightBiasOp = l[2]
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.zipper_tooth = comfy.ops.ZipperTooth
|
||||
#m.zipper_dict = zipper_dict
|
||||
m.zipper_key = l[1]
|
||||
|
||||
conds = [cond, uncond_]
|
||||
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
||||
|
||||
if do_cleanup:
|
||||
zzz = 20
|
||||
for l in loaded_modules:
|
||||
m: CastWeightBiasOp = l[2]
|
||||
if hasattr(l[2], "comfy_cast_weights"):
|
||||
#m.zipper_dict = None
|
||||
m.zipper_key = None
|
||||
|
||||
for fn in model_options.get("sampler_pre_cfg_function", []):
|
||||
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
|
||||
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
|
||||
|
Loading…
Reference in New Issue
Block a user