diff --git a/comfy/model_base.py b/comfy/model_base.py
index 976702b6..ee2fa60e 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -16,6 +16,7 @@
along with this program. If not, see .
"""
+from __future__ import annotations
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -104,7 +105,7 @@ class BaseModel(torch.nn.Module):
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
self.device = device
- self.current_patcher: 'ModelPatcher' = None
+ self.current_patcher: ModelPatcher = None
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
@@ -128,6 +129,7 @@ class BaseModel(torch.nn.Module):
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
+ self.zipper_initialized = False
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -137,6 +139,16 @@ class BaseModel(torch.nn.Module):
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
+ # handle lowvram zipper initialization, if required
+ if self.model_lowvram and not self.zipper_initialized:
+ if self.current_patcher:
+ self.zipper_initialized = True
+ with self.current_patcher.use_ejected():
+ loading = self.current_patcher._load_list_lowvram_only()
+
+ return self._apply_model_inner(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
+
+ def _apply_model_inner(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index b7cb12df..d609665a 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -17,7 +17,7 @@
"""
from __future__ import annotations
-from typing import Optional, Callable
+from typing import Optional, Callable, TYPE_CHECKING
import torch
import copy
import inspect
@@ -26,6 +26,7 @@ import uuid
import collections
import math
+import comfy.ops
import comfy.utils
import comfy.float
import comfy.model_management
@@ -34,6 +35,9 @@ import comfy.hooks
import comfy.patcher_extension
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
from comfy.comfy_types import UnetWrapperFunction
+if TYPE_CHECKING:
+ from comfy.model_base import BaseModel
+
def string_to_seed(data):
crc = 0xFFFFFFFF
@@ -201,7 +205,7 @@ class MemoryCounter:
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
- self.model = model
+ self.model: BaseModel = model
if not hasattr(self.model, 'device'):
logging.debug("Model doesn't have a device attribute.")
self.model.device = offload_device
@@ -568,6 +572,14 @@ class ModelPatcher:
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
+ def _zipper_dict_lowvram_only(self):
+ loading = self._load_list_lowvram_only()
+
+
+ def _load_list_lowvram_only(self):
+ loading = self._load_list()
+ return [x for x in loading if hasattr(x[2], "prev_comfy_cast_weights")]
+
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
@@ -583,6 +595,35 @@ class ModelPatcher:
loading.append((comfy.model_management.module_size(m), n, m, params))
return loading
+ def prepare_teeth(self):
+ ordered_list = self._load_list_lowvram_only()
+ prev_i = None
+ next_i = None
+ # first, create teeth on modules in list
+ for l in ordered_list:
+ m: comfy.ops.CastWeightBiasOp = l[2]
+ m.init_tooth(self.load_device, self.offload_device, l[1])
+ # create teeth linked list
+ for i in range(len(ordered_list)):
+ if i+1 == len(ordered_list):
+ next_i = None
+ else:
+ next_i = i+1
+ m: comfy.ops.CastWeightBiasOp = ordered_list[i][2]
+ if prev_i is not None:
+ m.zipper_tooth.prev_tooth = ordered_list[prev_i][2].zipper_tooth
+ else:
+ m.zipper_tooth.start = True
+ if next_i is not None:
+ m.zipper_tooth.next_tooth = ordered_list[next_i][2].zipper_tooth
+ prev_i = i
+
+ def clean_teeth(self):
+ ordered_list = self._load_list_lowvram_only()
+ for l in ordered_list:
+ m: comfy.ops.CastWeightBiasOp = l[2]
+ m.clean_tooth()
+
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
with self.use_ejected():
self.unpatch_hooks()
@@ -591,6 +632,8 @@ class ModelPatcher:
lowvram_counter = 0
loading = self._load_list()
+ logging.info(f"total size of _load_list: {sum([x[0] for x in loading])}")
+
load_completely = []
loading.sort(reverse=True)
for x in loading:
@@ -672,6 +715,7 @@ class ModelPatcher:
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
+ self.model.zipper_initialized = False
else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
@@ -684,6 +728,9 @@ class ModelPatcher:
self.model.model_loaded_weight_memory = mem_counter
self.model.current_weight_patches_uuid = self.patches_uuid
+ if self.model.model_lowvram:
+ self.prepare_teeth()
+
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
@@ -715,6 +762,7 @@ class ModelPatcher:
move_weight_functions(m, device_to)
wipe_lowvram_weight(m)
+ self.clean_teeth()
self.model.model_lowvram = False
self.model.lowvram_patch_counter = 0
@@ -804,8 +852,10 @@ class ModelPatcher:
logging.debug("freed {}".format(n))
self.model.model_lowvram = True
+ self.model.zipper_initialized = False
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
+ self.prepare_teeth()
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
diff --git a/comfy/ops.py b/comfy/ops.py
index ced46101..cf2ae864 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -16,6 +16,7 @@
along with this program. If not, see .
"""
+from __future__ import annotations
import torch
import logging
import comfy.model_management
@@ -56,6 +57,79 @@ class CastWeightBiasOp:
comfy_cast_weights = False
weight_function = []
bias_function = []
+ zipper_init: dict = None
+ zipper_tooth: ZipperTooth = None
+ _zipper_tooth: ZipperTooth = None
+
+ def init_tooth(self, load_device, offload_device, key: str=None):
+ if self.zipper_tooth:
+ self.clean_tooth()
+ self.zipper_tooth = ZipperTooth(self, load_device, offload_device, key)
+
+ def clean_tooth(self):
+ if self.zipper_tooth:
+ del self.zipper_tooth
+ self.zipper_tooth = None
+
+ def connect_teeth(self):
+ if self.zipper_init is not None:
+
+ self.zipper_init[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
+ self.zipper_dict["prev_zipper_key"] = self.zipper_key
+
+ # def zipper_connect(self):
+ # if self.zipper_dict is not None:
+ # self.zipper_dict[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
+ # self.zipper_dict["prev_zipper_key"] = self.zipper_key
+
+class ZipperTooth:
+ def __init__(self, op: CastWeightBiasOp, load_device, offload_device, key: str=None):
+ self.op = op
+ self.key: str = key
+ self.weight_preloaded: torch.Tensor = None
+ self.bias_preloaded: torch.Tensor = None
+ self.load_device = load_device
+ self.offload_device = offload_device
+ self.start = False
+
+ self.prev_tooth: ZipperTooth = None
+ self.next_tooth: ZipperTooth = None
+
+ def get_bias_weight(self, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
+ try:
+ if self.start:
+ return cast_bias_weight(self.op, input, dtype, device, bias_dtype)
+ return self.weight_preloaded, self.bias_preloaded
+ finally:
+ # if self.prev_tooth:
+ # self.prev_tooth.offload_previous(0)
+ self.next_tooth.preload_next(0, input, dtype, device, bias_dtype)
+
+ def preload_next(self, teeth_count=1, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
+ # TODO: queue load of tensors
+ if input is not None:
+ if dtype is None:
+ dtype = input.dtype
+ if bias_dtype is None:
+ bias_dtype = dtype
+ if device is None:
+ device = input.device
+
+ non_blocking = comfy.model_management.device_supports_non_blocking(self.load_device)
+
+ if self.op.bias is not None:
+ self.bias_preloaded = comfy.model_management.cast_to(self.op.bias, bias_dtype, device, non_blocking=non_blocking)
+
+ self.weight_preloaded = comfy.model_management.cast_to(self.op.weight, dtype, device, non_blocking=non_blocking)
+ if self.next_tooth and teeth_count:
+ self.next_tooth.preload_next(teeth_count-1)
+
+ def offload_previous(self, teeth_count=1):
+ # TODO: queue offload of tensors
+ self.weight_preloaded = None
+ self.bias_preloaded = None
+ if self.prev_tooth and teeth_count:
+ self.prev_tooth.offload_previous(teeth_count-1)
class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
@@ -63,7 +137,11 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
- weight, bias = cast_bias_weight(self, input)
+ #if self.zipper_init:
+ if self.zipper_tooth:
+ weight, bias = self.zipper_tooth.get_bias_weight(input)
+ else:
+ weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
@@ -77,7 +155,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
- weight, bias = cast_bias_weight(self, input)
+ if self.zipper_tooth:
+ weight, bias = self.zipper_tooth.get_bias_weight(input)
+ else:
+ weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
@@ -91,7 +172,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
- weight, bias = cast_bias_weight(self, input)
+ if self.zipper_tooth:
+ weight, bias = self.zipper_tooth.get_bias_weight(input)
+ else:
+ weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
@@ -105,7 +189,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
- weight, bias = cast_bias_weight(self, input)
+ if self.zipper_tooth:
+ weight, bias = self.zipper_tooth.get_bias_weight(input)
+ else:
+ weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
@@ -119,7 +206,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
- weight, bias = cast_bias_weight(self, input)
+ if self.zipper_tooth:
+ weight, bias = self.zipper_tooth.get_bias_weight(input)
+ else:
+ weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
@@ -134,7 +224,10 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
- weight, bias = cast_bias_weight(self, input)
+ if self.zipper_tooth:
+ weight, bias = self.zipper_tooth.get_bias_weight(input)
+ else:
+ weight, bias = cast_bias_weight(self, input)
else:
weight = None
bias = None
@@ -156,7 +249,10 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
- weight, bias = cast_bias_weight(self, input)
+ if self.zipper_tooth:
+ weight, bias = self.zipper_tooth.get_bias_weight(input)
+ else:
+ weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
@@ -177,7 +273,10 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
- weight, bias = cast_bias_weight(self, input)
+ if self.zipper_tooth:
+ weight, bias = self.zipper_tooth.get_bias_weight(input)
+ else:
+ weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
@@ -197,7 +296,10 @@ class disable_weight_init:
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
- weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
+ if self.zipper_tooth:
+ weight, bias = self.zipper_tooth.get_bias_weight(device=input.device, dtype=out_dtype)
+ else:
+ weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs):
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 10728bd1..afca345b 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -6,6 +6,7 @@ if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
+ from comfy.ops import CastWeightBiasOp
import torch
from functools import partial
import collections
@@ -18,6 +19,7 @@ import comfy.patcher_extension
import comfy.hooks
import scipy.stats
import numpy
+import comfy.ops
def add_area_dims(area, num_dims):
@@ -360,15 +362,38 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
#The main sampling function shared by all the samplers
#Returns denoised
-def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
+def sampling_function(model: BaseModel, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
uncond_ = uncond
+ do_cleanup = False
+ if "weight_zipper" not in model_options:
+ do_cleanup = True
+ #zipper_dict = {}
+ model_options["weight_zipper"] = True
+ loaded_modules = model.current_patcher._load_list_lowvram_only()
+ low_m = [x for x in loaded_modules if hasattr(x[2], "prev_comfy_cast_weights")]
+ sum_m = sum([x[0] for x in low_m])
+ for l in loaded_modules:
+ m: CastWeightBiasOp = l[2]
+ if hasattr(m, "comfy_cast_weights"):
+ m.zipper_tooth = comfy.ops.ZipperTooth
+ #m.zipper_dict = zipper_dict
+ m.zipper_key = l[1]
+
conds = [cond, uncond_]
out = calc_cond_batch(model, conds, x, timestep, model_options)
+ if do_cleanup:
+ zzz = 20
+ for l in loaded_modules:
+ m: CastWeightBiasOp = l[2]
+ if hasattr(l[2], "comfy_cast_weights"):
+ #m.zipper_dict = None
+ m.zipper_key = None
+
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
"input": x, "sigma": timestep, "model": model, "model_options": model_options}