Initial exploration of weight zipper

Jedrzej Kosinski 2025-03-24 03:34:42 -05:00
parent 3b19fc76e3
commit c8037ab667
4 changed files with 202 additions and 13 deletions

View File

@@ -16,6 +16,7 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -104,7 +105,7 @@ class BaseModel(torch.nn.Module):
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
self.device = device
self.current_patcher: 'ModelPatcher' = None
self.current_patcher: ModelPatcher = None
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
@@ -128,6 +129,7 @@ class BaseModel(torch.nn.Module):
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
self.zipper_initialized = False
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -137,6 +139,16 @@ class BaseModel(torch.nn.Module):
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
# handle lowvram zipper initialization, if required
if self.model_lowvram and not self.zipper_initialized:
    if self.current_patcher:
        self.zipper_initialized = True
        with self.current_patcher.use_ejected():
            # exploratory: gather the lowvram-only load list; it is not yet
            # consumed by the zipper setup in this commit
            loading = self.current_patcher._load_list_lowvram_only()
return self._apply_model_inner(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def _apply_model_inner(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:

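For orientation, the `model_lowvram`/`zipper_initialized` gate added to `_apply_model` boils down to a one-time lazy setup on the first lowvram forward pass. A minimal self-contained sketch with illustrative names (`_Model`, `apply`, `apply_inner` are not ComfyUI API):

class _Model:
    def __init__(self):
        self.model_lowvram = True
        self.zipper_initialized = False

    def apply(self, x):
        if self.model_lowvram and not self.zipper_initialized:
            self.zipper_initialized = True  # one-time zipper setup happens here
        return self.apply_inner(x)

    def apply_inner(self, x):
        return x

m = _Model()
m.apply(1)  # first call performs the setup
m.apply(2)  # later calls skip straight to apply_inner
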
View File

@@ -17,7 +17,7 @@
"""
from __future__ import annotations
from typing import Optional, Callable
from typing import Optional, Callable, TYPE_CHECKING
import torch
import copy
import inspect
@@ -26,6 +26,7 @@ import uuid
import collections
import math
import comfy.ops
import comfy.utils
import comfy.float
import comfy.model_management
@@ -34,6 +35,9 @@ import comfy.hooks
import comfy.patcher_extension
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
from comfy.comfy_types import UnetWrapperFunction
if TYPE_CHECKING:
from comfy.model_base import BaseModel
def string_to_seed(data):
crc = 0xFFFFFFFF
@@ -201,7 +205,7 @@ class MemoryCounter:
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
self.model = model
self.model: BaseModel = model
if not hasattr(self.model, 'device'):
logging.debug("Model doesn't have a device attribute.")
self.model.device = offload_device
@@ -568,6 +572,14 @@ class ModelPatcher:
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def _zipper_dict_lowvram_only(self):
    # exploratory: intended to build the zipper dict; for now it just
    # returns the lowvram-only load list
    loading = self._load_list_lowvram_only()
    return loading
def _load_list_lowvram_only(self):
loading = self._load_list()
return [x for x in loading if hasattr(x[2], "prev_comfy_cast_weights")]
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
@@ -583,6 +595,35 @@
loading.append((comfy.model_management.module_size(m), n, m, params))
return loading
def prepare_teeth(self):
    ordered_list = self._load_list_lowvram_only()
    # first, create teeth on modules in list
    for l in ordered_list:
        m: comfy.ops.CastWeightBiasOp = l[2]
        m.init_tooth(self.load_device, self.offload_device, l[1])
    # then link the teeth into a doubly linked list
    for i in range(len(ordered_list)):
        tooth = ordered_list[i][2].zipper_tooth
        if i == 0:
            tooth.start = True
        else:
            tooth.prev_tooth = ordered_list[i-1][2].zipper_tooth
        if i+1 < len(ordered_list):
            tooth.next_tooth = ordered_list[i+1][2].zipper_tooth
def clean_teeth(self):
ordered_list = self._load_list_lowvram_only()
for l in ordered_list:
m: comfy.ops.CastWeightBiasOp = l[2]
m.clean_tooth()
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
with self.use_ejected():
self.unpatch_hooks()
@@ -591,6 +632,8 @@
lowvram_counter = 0
loading = self._load_list()
logging.info(f"total size of _load_list: {sum([x[0] for x in loading])}")
load_completely = []
loading.sort(reverse=True)
for x in loading:
@@ -672,6 +715,7 @@
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
self.model.zipper_initialized = False
else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
@@ -684,6 +728,9 @@
self.model.model_loaded_weight_memory = mem_counter
self.model.current_weight_patches_uuid = self.patches_uuid
if self.model.model_lowvram:
self.prepare_teeth()
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
@@ -715,6 +762,7 @@
move_weight_functions(m, device_to)
wipe_lowvram_weight(m)
self.clean_teeth()
self.model.model_lowvram = False
self.model.lowvram_patch_counter = 0
@@ -804,8 +852,10 @@
logging.debug("freed {}".format(n))
self.model.model_lowvram = True
self.model.zipper_initialized = False
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
self.prepare_teeth()
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):

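As a standalone illustration of the doubly linked list that `prepare_teeth` builds over the lowvram modules (hypothetical `Tooth` class standing in for `comfy.ops.ZipperTooth`; the module keys are made up):

class Tooth:
    def __init__(self, key):
        self.key = key
        self.start = False
        self.prev_tooth = None
        self.next_tooth = None

teeth = [Tooth(k) for k in ("input_blocks.0", "middle_block.1", "output_blocks.0")]
for i, t in enumerate(teeth):
    if i == 0:
        t.start = True          # first tooth casts its weights directly
    else:
        t.prev_tooth = teeth[i - 1]
    if i + 1 < len(teeth):
        t.next_tooth = teeth[i + 1]

assert teeth[0].start and teeth[-1].next_tooth is None
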
View File

@@ -16,6 +16,7 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import torch
import logging
import comfy.model_management
@@ -56,6 +57,79 @@ class CastWeightBiasOp:
comfy_cast_weights = False
weight_function = []
bias_function = []
zipper_init: dict = None
zipper_tooth: ZipperTooth = None
zipper_key: str = None

def init_tooth(self, load_device, offload_device, key: str=None):
    if self.zipper_tooth:
        self.clean_tooth()
    self.zipper_tooth = ZipperTooth(self, load_device, offload_device, key)

def clean_tooth(self):
    self.zipper_tooth = None

def connect_teeth(self):
    if self.zipper_init is not None:
        self.zipper_init[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_init.get("prev_zipper_key", None))
        self.zipper_init["prev_zipper_key"] = self.zipper_key
class ZipperTooth:
def __init__(self, op: CastWeightBiasOp, load_device, offload_device, key: str=None):
self.op = op
self.key: str = key
self.weight_preloaded: torch.Tensor = None
self.bias_preloaded: torch.Tensor = None
self.load_device = load_device
self.offload_device = offload_device
self.start = False
self.prev_tooth: ZipperTooth = None
self.next_tooth: ZipperTooth = None
def get_bias_weight(self, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
    try:
        if self.start:
            return cast_bias_weight(self.op, input, dtype, device, bias_dtype)
        return self.weight_preloaded, self.bias_preloaded
    finally:
        # TODO: offload teeth that fall behind the window, e.g.:
        # if self.prev_tooth:
        #     self.prev_tooth.offload_previous(0)
        if self.next_tooth is not None:
            self.next_tooth.preload_next(0, input, dtype, device, bias_dtype)
def preload_next(self, teeth_count=1, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
    # TODO: queue load of tensors instead of casting inline
    if input is not None:
        if dtype is None:
            dtype = input.dtype
        if bias_dtype is None:
            bias_dtype = dtype
        if device is None:
            device = input.device
    if device is None:
        # no input to infer from (recursive preload); target the load device
        device = self.load_device
    non_blocking = comfy.model_management.device_supports_non_blocking(self.load_device)
    if self.op.bias is not None:
        self.bias_preloaded = comfy.model_management.cast_to(self.op.bias, bias_dtype, device, non_blocking=non_blocking)
    self.weight_preloaded = comfy.model_management.cast_to(self.op.weight, dtype, device, non_blocking=non_blocking)
    if self.next_tooth is not None and teeth_count > 0:
        self.next_tooth.preload_next(teeth_count-1)
def offload_previous(self, teeth_count=1):
# TODO: queue offload of tensors
self.weight_preloaded = None
self.bias_preloaded = None
if self.prev_tooth and teeth_count:
self.prev_tooth.offload_previous(teeth_count-1)
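The relay behind `get_bias_weight`/`preload_next` is easier to see in isolation: serving tooth N's weights also kicks off the (ideally non-blocking) load of tooth N+1, so the transfer can overlap tooth N's compute. A toy sketch on CPU tensors (`ToyTooth` is hypothetical; the real code casts between offload and load devices):

import torch

class ToyTooth:
    def __init__(self, weight):
        self.weight = weight        # stands in for the offloaded parameter
        self.preloaded = None
        self.start = False
        self.next_tooth = None

    def get(self):
        try:
            if self.start:
                return self.weight.clone()  # start tooth loads directly
            return self.preloaded           # later teeth use the preload
        finally:
            if self.next_tooth is not None:
                self.next_tooth.preloaded = self.next_tooth.weight.clone()

a, b = ToyTooth(torch.ones(2)), ToyTooth(torch.zeros(2))
a.start = True
a.next_tooth = b
weight = a.get()            # returns a's weight and preloads b's
assert b.preloaded is not None
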
class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
@@ -63,7 +137,11 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
@@ -77,7 +155,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
@@ -91,7 +172,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
@@ -105,7 +189,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
@@ -119,7 +206,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
@@ -134,7 +224,10 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
else:
weight = None
bias = None
@@ -156,7 +249,10 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
@@ -177,7 +273,10 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
@@ -197,7 +296,10 @@ class disable_weight_init:
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(device=input.device, dtype=out_dtype)
else:
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs):

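Since the same `zipper_tooth` branch now appears in every `forward_comfy_cast_weights` above, it could be factored into a single module-level helper. A sketch (`zipper_cast_bias_weight` is a hypothetical name; `cast_bias_weight` is assumed to keep its existing `(s, input=None, dtype=None, device=None, bias_dtype=None)` signature):

def zipper_cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
    # route through the tooth when the zipper is active, else cast as before
    if s.zipper_tooth:
        return s.zipper_tooth.get_bias_weight(input, dtype, device, bias_dtype)
    return cast_bias_weight(s, input, dtype, device, bias_dtype)
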
View File

@@ -6,6 +6,7 @@ if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
from comfy.ops import CastWeightBiasOp
import torch
from functools import partial
import collections
@@ -18,6 +19,7 @@ import comfy.patcher_extension
import comfy.hooks
import scipy.stats
import numpy
import comfy.ops
def add_area_dims(area, num_dims):
@@ -360,15 +362,38 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def sampling_function(model: BaseModel, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
uncond_ = uncond
do_cleanup = False
if "weight_zipper" not in model_options:
do_cleanup = True
#zipper_dict = {}
model_options["weight_zipper"] = True
loaded_modules = model.current_patcher._load_list_lowvram_only()
low_m = [x for x in loaded_modules if hasattr(x[2], "prev_comfy_cast_weights")]
sum_m = sum([x[0] for x in low_m])
for l in loaded_modules:
m: CastWeightBiasOp = l[2]
if hasattr(m, "comfy_cast_weights"):
m.zipper_tooth = comfy.ops.ZipperTooth
#m.zipper_dict = zipper_dict
m.zipper_key = l[1]
conds = [cond, uncond_]
out = calc_cond_batch(model, conds, x, timestep, model_options)
if do_cleanup:
    for l in loaded_modules:
        m: CastWeightBiasOp = l[2]
        if hasattr(m, "comfy_cast_weights"):
            m.zipper_key = None
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
"input": x, "sigma": timestep, "model": model, "model_options": model_options}