From c8037ab6679a3d1c3d6953981f95fc5d7633ee0d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 24 Mar 2025 03:34:42 -0500 Subject: [PATCH] Initial exploration of weight zipper --- comfy/model_base.py | 14 ++++- comfy/model_patcher.py | 54 ++++++++++++++++++- comfy/ops.py | 120 +++++++++++++++++++++++++++++++++++++---- comfy/samplers.py | 27 +++++++++- 4 files changed, 202 insertions(+), 13 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 976702b6..ee2fa60e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -16,6 +16,7 @@ along with this program. If not, see . """ +from __future__ import annotations import torch import logging from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep @@ -104,7 +105,7 @@ class BaseModel(torch.nn.Module): self.model_config = model_config self.manual_cast_dtype = model_config.manual_cast_dtype self.device = device - self.current_patcher: 'ModelPatcher' = None + self.current_patcher: ModelPatcher = None if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: @@ -128,6 +129,7 @@ class BaseModel(torch.nn.Module): logging.info("model_type {}".format(model_type.name)) logging.debug("adm {}".format(self.adm_channels)) self.memory_usage_factor = model_config.memory_usage_factor + self.zipper_initialized = False def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -137,6 +139,16 @@ class BaseModel(torch.nn.Module): ).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): + # handle lowvram zipper initialization, if required + if self.model_lowvram and not self.zipper_initialized: + if self.current_patcher: + self.zipper_initialized = True + with self.current_patcher.use_ejected(): + loading = self.current_patcher._load_list_lowvram_only() + + return self._apply_model_inner(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) + + def _apply_model_inner(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t xc = self.model_sampling.calculate_input(sigma, x) if c_concat is not None: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b7cb12df..d609665a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -17,7 +17,7 @@ """ from __future__ import annotations -from typing import Optional, Callable +from typing import Optional, Callable, TYPE_CHECKING import torch import copy import inspect @@ -26,6 +26,7 @@ import uuid import collections import math +import comfy.ops import comfy.utils import comfy.float import comfy.model_management @@ -34,6 +35,9 @@ import comfy.hooks import comfy.patcher_extension from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection from comfy.comfy_types import UnetWrapperFunction +if TYPE_CHECKING: + from comfy.model_base import BaseModel + def string_to_seed(data): crc = 0xFFFFFFFF @@ -201,7 +205,7 @@ class MemoryCounter: class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size - self.model = model + self.model: BaseModel = model if not hasattr(self.model, 'device'): logging.debug("Model doesn't have a device attribute.") self.model.device = offload_device @@ -568,6 +572,14 @@ class 
ModelPatcher: else: set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + def _zipper_dict_lowvram_only(self): + loading = self._load_list_lowvram_only() + + + def _load_list_lowvram_only(self): + loading = self._load_list() + return [x for x in loading if hasattr(x[2], "prev_comfy_cast_weights")] + def _load_list(self): loading = [] for n, m in self.model.named_modules(): @@ -583,6 +595,35 @@ class ModelPatcher: loading.append((comfy.model_management.module_size(m), n, m, params)) return loading + def prepare_teeth(self): + ordered_list = self._load_list_lowvram_only() + prev_i = None + next_i = None + # first, create teeth on modules in list + for l in ordered_list: + m: comfy.ops.CastWeightBiasOp = l[2] + m.init_tooth(self.load_device, self.offload_device, l[1]) + # create teeth linked list + for i in range(len(ordered_list)): + if i+1 == len(ordered_list): + next_i = None + else: + next_i = i+1 + m: comfy.ops.CastWeightBiasOp = ordered_list[i][2] + if prev_i is not None: + m.zipper_tooth.prev_tooth = ordered_list[prev_i][2].zipper_tooth + else: + m.zipper_tooth.start = True + if next_i is not None: + m.zipper_tooth.next_tooth = ordered_list[next_i][2].zipper_tooth + prev_i = i + + def clean_teeth(self): + ordered_list = self._load_list_lowvram_only() + for l in ordered_list: + m: comfy.ops.CastWeightBiasOp = l[2] + m.clean_tooth() + def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): with self.use_ejected(): self.unpatch_hooks() @@ -591,6 +632,8 @@ class ModelPatcher: lowvram_counter = 0 loading = self._load_list() + logging.info(f"total size of _load_list: {sum([x[0] for x in loading])}") + load_completely = [] loading.sort(reverse=True) for x in loading: @@ -672,6 +715,7 @@ class ModelPatcher: if lowvram_counter > 0: logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) self.model.model_lowvram = True + self.model.zipper_initialized = False else: logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) self.model.model_lowvram = False @@ -684,6 +728,9 @@ class ModelPatcher: self.model.model_loaded_weight_memory = mem_counter self.model.current_weight_patches_uuid = self.patches_uuid + if self.model.model_lowvram: + self.prepare_teeth() + for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) @@ -715,6 +762,7 @@ class ModelPatcher: move_weight_functions(m, device_to) wipe_lowvram_weight(m) + self.clean_teeth() self.model.model_lowvram = False self.model.lowvram_patch_counter = 0 @@ -804,8 +852,10 @@ class ModelPatcher: logging.debug("freed {}".format(n)) self.model.model_lowvram = True + self.model.zipper_initialized = False self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed + self.prepare_teeth() return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): diff --git a/comfy/ops.py b/comfy/ops.py index ced46101..cf2ae864 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -16,6 +16,7 @@ along with this program. If not, see . 
""" +from __future__ import annotations import torch import logging import comfy.model_management @@ -56,6 +57,79 @@ class CastWeightBiasOp: comfy_cast_weights = False weight_function = [] bias_function = [] + zipper_init: dict = None + zipper_tooth: ZipperTooth = None + _zipper_tooth: ZipperTooth = None + + def init_tooth(self, load_device, offload_device, key: str=None): + if self.zipper_tooth: + self.clean_tooth() + self.zipper_tooth = ZipperTooth(self, load_device, offload_device, key) + + def clean_tooth(self): + if self.zipper_tooth: + del self.zipper_tooth + self.zipper_tooth = None + + def connect_teeth(self): + if self.zipper_init is not None: + + self.zipper_init[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None)) + self.zipper_dict["prev_zipper_key"] = self.zipper_key + + # def zipper_connect(self): + # if self.zipper_dict is not None: + # self.zipper_dict[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None)) + # self.zipper_dict["prev_zipper_key"] = self.zipper_key + +class ZipperTooth: + def __init__(self, op: CastWeightBiasOp, load_device, offload_device, key: str=None): + self.op = op + self.key: str = key + self.weight_preloaded: torch.Tensor = None + self.bias_preloaded: torch.Tensor = None + self.load_device = load_device + self.offload_device = offload_device + self.start = False + + self.prev_tooth: ZipperTooth = None + self.next_tooth: ZipperTooth = None + + def get_bias_weight(self, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None): + try: + if self.start: + return cast_bias_weight(self.op, input, dtype, device, bias_dtype) + return self.weight_preloaded, self.bias_preloaded + finally: + # if self.prev_tooth: + # self.prev_tooth.offload_previous(0) + self.next_tooth.preload_next(0, input, dtype, device, bias_dtype) + + def preload_next(self, teeth_count=1, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None): + # TODO: queue load of tensors + if input is not None: + if dtype is None: + dtype = input.dtype + if bias_dtype is None: + bias_dtype = dtype + if device is None: + device = input.device + + non_blocking = comfy.model_management.device_supports_non_blocking(self.load_device) + + if self.op.bias is not None: + self.bias_preloaded = comfy.model_management.cast_to(self.op.bias, bias_dtype, device, non_blocking=non_blocking) + + self.weight_preloaded = comfy.model_management.cast_to(self.op.weight, dtype, device, non_blocking=non_blocking) + if self.next_tooth and teeth_count: + self.next_tooth.preload_next(teeth_count-1) + + def offload_previous(self, teeth_count=1): + # TODO: queue offload of tensors + self.weight_preloaded = None + self.bias_preloaded = None + if self.prev_tooth and teeth_count: + self.prev_tooth.offload_previous(teeth_count-1) class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): @@ -63,7 +137,11 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) + #if self.zipper_init: + if self.zipper_tooth: + weight, bias = self.zipper_tooth.get_bias_weight(input) + else: + weight, bias = cast_bias_weight(self, input) return torch.nn.functional.linear(input, weight, bias) def forward(self, *args, **kwargs): @@ -77,7 +155,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) + if self.zipper_tooth: + weight, bias = 
self.zipper_tooth.get_bias_weight(input) + else: + weight, bias = cast_bias_weight(self, input) return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): @@ -91,7 +172,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) + if self.zipper_tooth: + weight, bias = self.zipper_tooth.get_bias_weight(input) + else: + weight, bias = cast_bias_weight(self, input) return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): @@ -105,7 +189,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) + if self.zipper_tooth: + weight, bias = self.zipper_tooth.get_bias_weight(input) + else: + weight, bias = cast_bias_weight(self, input) return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): @@ -119,7 +206,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) + if self.zipper_tooth: + weight, bias = self.zipper_tooth.get_bias_weight(input) + else: + weight, bias = cast_bias_weight(self, input) return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) def forward(self, *args, **kwargs): @@ -134,7 +224,10 @@ class disable_weight_init: def forward_comfy_cast_weights(self, input): if self.weight is not None: - weight, bias = cast_bias_weight(self, input) + if self.zipper_tooth: + weight, bias = self.zipper_tooth.get_bias_weight(input) + else: + weight, bias = cast_bias_weight(self, input) else: weight = None bias = None @@ -156,7 +249,10 @@ class disable_weight_init: input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) - weight, bias = cast_bias_weight(self, input) + if self.zipper_tooth: + weight, bias = self.zipper_tooth.get_bias_weight(input) + else: + weight, bias = cast_bias_weight(self, input) return torch.nn.functional.conv_transpose2d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) @@ -177,7 +273,10 @@ class disable_weight_init: input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) - weight, bias = cast_bias_weight(self, input) + if self.zipper_tooth: + weight, bias = self.zipper_tooth.get_bias_weight(input) + else: + weight, bias = cast_bias_weight(self, input) return torch.nn.functional.conv_transpose1d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) @@ -197,7 +296,10 @@ class disable_weight_init: output_dtype = out_dtype if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16: out_dtype = None - weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype) + if self.zipper_tooth: + weight, bias = self.zipper_tooth.get_bias_weight(device=input.device, dtype=out_dtype) + else: + weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype) return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) def forward(self, *args, **kwargs): diff --git a/comfy/samplers.py b/comfy/samplers.py index 10728bd1..afca345b 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher from comfy.model_base import BaseModel from comfy.controlnet import ControlBase + from comfy.ops 
import CastWeightBiasOp
 import torch
 from functools import partial
 import collections
@@ -18,6 +19,7 @@ import comfy.patcher_extension
 import comfy.hooks
 import scipy.stats
 import numpy
+import comfy.ops
 
 
 def add_area_dims(area, num_dims):
@@ -360,15 +362,38 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
 
 #The main sampling function shared by all the samplers
 #Returns denoised
-def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
+def sampling_function(model: BaseModel, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
     if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
         uncond_ = None
     else:
         uncond_ = uncond
 
+    do_cleanup = False
+    if "weight_zipper" not in model_options:
+        do_cleanup = True
+        #zipper_dict = {}
+        model_options["weight_zipper"] = True
+        loaded_modules = model.current_patcher._load_list_lowvram_only()
+        low_m = [x for x in loaded_modules if hasattr(x[2], "prev_comfy_cast_weights")]
+        sum_m = sum([x[0] for x in low_m])
+        for l in loaded_modules:
+            m: CastWeightBiasOp = l[2]
+            if hasattr(m, "comfy_cast_weights"):
+                # teeth instances are created by ModelPatcher.prepare_teeth() at load time; only record the key here
+                #m.zipper_dict = zipper_dict
+                m.zipper_key = l[1]
+
     conds = [cond, uncond_]
     out = calc_cond_batch(model, conds, x, timestep, model_options)
 
+    if do_cleanup:
+        # tear down the per-module zipper keys set up above
+        for l in loaded_modules:
+            m: CastWeightBiasOp = l[2]
+            if hasattr(l[2], "comfy_cast_weights"):
+                #m.zipper_dict = None
+                m.zipper_key = None
+
     for fn in model_options.get("sampler_pre_cfg_function", []):
         args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
                 "input": x, "sigma": timestep, "model": model, "model_options": model_options}
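
The control flow this patch builds toward can be shown in isolation. Each lowvram module gets a ZipperTooth; ModelPatcher.prepare_teeth() links the teeth into a doubly linked list, and when a module's forward asks its tooth for weight and bias, the tooth hands back copies that were pre-cast to the load device, kicks off the cast for the next tooth so the transfer of module N+1's weights overlaps with module N's compute, and lets earlier teeth drop their copies (offload_previous). The sketch below is a minimal stand-alone illustration of that pattern, not ComfyUI code: the names Tooth and build_chain and the use of plain torch.nn.Linear modules are hypothetical, and a real implementation would use non-blocking casts as in comfy.model_management.cast_to.

# Minimal stand-alone sketch of the weight-zipper idea (illustrative names, not ComfyUI API).
import torch


class Tooth:
    """One link of the zipper: holds pre-cast copies of one module's weight/bias
    and triggers the preload of the next link when its own module runs."""
    def __init__(self, module, load_device, offload_device):
        self.module = module
        self.load_device = load_device
        self.offload_device = offload_device
        self.weight_preloaded = None
        self.bias_preloaded = None
        self.prev_tooth = None
        self.next_tooth = None

    def preload(self):
        # Copy/cast the parameters to the load device ahead of time.
        # A real implementation would do this with non-blocking copies.
        self.weight_preloaded = self.module.weight.to(self.load_device)
        if self.module.bias is not None:
            self.bias_preloaded = self.module.bias.to(self.load_device)

    def offload(self):
        # Drop the pre-cast copies once they are no longer needed.
        self.weight_preloaded = None
        self.bias_preloaded = None

    def get_bias_weight(self):
        # The first use casts on demand; afterwards the previous tooth has already preloaded us.
        if self.weight_preloaded is None:
            self.preload()
        try:
            return self.weight_preloaded, self.bias_preloaded
        finally:
            # Overlap the next module's weight transfer with this module's compute,
            # and release the previous module's copies. Guard both ends of the chain.
            if self.next_tooth is not None:
                self.next_tooth.preload()
            if self.prev_tooth is not None:
                self.prev_tooth.offload()


def build_chain(modules, load_device, offload_device):
    # Counterpart of ModelPatcher.prepare_teeth: one tooth per module, doubly linked in order.
    teeth = [Tooth(m, load_device, offload_device) for m in modules]
    for i, tooth in enumerate(teeth):
        if i > 0:
            tooth.prev_tooth = teeth[i - 1]
        if i + 1 < len(teeth):
            tooth.next_tooth = teeth[i + 1]
    return teeth


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    layers = [torch.nn.Linear(64, 64) for _ in range(4)]  # parameters stay on the offload device
    teeth = build_chain(layers, load_device=device, offload_device="cpu")
    x = torch.randn(1, 64, device=device)
    for tooth in teeth:
        weight, bias = tooth.get_bias_weight()
        x = torch.nn.functional.linear(x, weight, bias)
    print(x.shape)

Note that the sketch checks next_tooth and prev_tooth for None at the ends of the chain before preloading or offloading, which is the condition ZipperTooth.get_bias_weight also needs at the final tooth, where next_tooth is None.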