Add add_weight_wrapper function to model patcher.

Functions can now easily be added to wrap/modify model weights.
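
As a rough usage sketch (not part of the commit; the weight key and the wrapper body below are made-up examples), a wrapper is registered on a ModelPatcher under a "{module}.weight" or "{module}.bias" key, and the patched cast_bias_weight in comfy/ops.py then applies every registered wrapper to the casted tensor in order:

def scale_weight(w):
    # A wrapper receives the casted weight tensor and must return a tensor;
    # this one simply scales the weight at cast time.
    return w * 0.5

patcher = model_patcher.clone()  # assumes an existing ModelPatcher named model_patcher
patcher.add_weight_wrapper("diffusion_model.input_blocks.0.0.weight", scale_weight)
# cast_bias_weight() copies the casted weight when wrappers are present and runs
# each one in registration order, so the stored weight is never modified in place.
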
comfyanonymous 2025-02-12 05:49:00 -05:00
parent d9f0fcdb0c
commit ab888e1e0b
2 changed files with 67 additions and 27 deletions

comfy/model_patcher.py

@@ -96,8 +96,28 @@ def wipe_lowvram_weight(m):
     if hasattr(m, "prev_comfy_cast_weights"):
         m.comfy_cast_weights = m.prev_comfy_cast_weights
         del m.prev_comfy_cast_weights
-    m.weight_function = None
-    m.bias_function = None
+
+    if hasattr(m, "weight_function"):
+        m.weight_function = []
+
+    if hasattr(m, "bias_function"):
+        m.bias_function = []
+
+def move_weight_functions(m, device):
+    if device is None:
+        return 0
+
+    memory = 0
+    if hasattr(m, "weight_function"):
+        for f in m.weight_function:
+            if hasattr(f, "move_to"):
+                memory += f.move_to(device=device)
+
+    if hasattr(m, "bias_function"):
+        for f in m.bias_function:
+            if hasattr(f, "move_to"):
+                memory += f.move_to(device=device)
+    return memory
 
 class LowVramPatch:
     def __init__(self, key, patches):
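
The new move_weight_functions helper above duck-types on a move_to(device=...) method, so stateful wrapper objects can be moved along with their module and report how much memory they keep on the device. A minimal sketch of such a wrapper, assuming a made-up class name and behaviour:

import torch

class AddDeltaWrapper:
    # Hypothetical stateful wrapper (not part of the commit): it holds a tensor,
    # so it implements move_to() and returns the bytes it occupies, which
    # move_weight_functions() adds to the memory counter.
    def __init__(self, delta):
        self.delta = delta

    def __call__(self, weight):
        # Called from cast_bias_weight() with the casted weight tensor.
        return weight + self.delta.to(device=weight.device, dtype=weight.dtype)

    def move_to(self, device=None):
        self.delta = self.delta.to(device=device)
        return self.delta.numel() * self.delta.element_size()
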
@@ -192,6 +212,7 @@ class ModelPatcher:
         self.backup = {}
         self.object_patches = {}
         self.object_patches_backup = {}
+        self.weight_wrapper_patches = {}
         self.model_options = {"transformer_options":{}}
         self.model_size()
         self.load_device = load_device
@@ -250,6 +271,7 @@ class ModelPatcher:
         n.patches_uuid = self.patches_uuid
 
         n.object_patches = self.object_patches.copy()
+        n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         n.backup = self.backup
         n.object_patches_backup = self.object_patches_backup
@@ -402,6 +424,10 @@ class ModelPatcher:
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj
 
+    def add_weight_wrapper(self, name, function):
+        self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
+        self.patches_uuid = uuid.uuid4()
+
     def get_model_object(self, name: str) -> torch.nn.Module:
         """Retrieves a nested attribute from an object using dot notation considering
         object patches.
@@ -566,6 +592,9 @@ class ModelPatcher:
                 lowvram_weight = False
 
+                weight_key = "{}.weight".format(n)
+                bias_key = "{}.bias".format(n)
+
                 if not full_load and hasattr(m, "comfy_cast_weights"):
                     if mem_counter + module_mem >= lowvram_model_memory:
                         lowvram_weight = True
@@ -573,34 +602,42 @@
                         if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                             continue
 
-                weight_key = "{}.weight".format(n)
-                bias_key = "{}.bias".format(n)
-
                 if lowvram_weight:
+                    if hasattr(m, "comfy_cast_weights"):
+                        m.weight_function = []
+                        m.bias_function = []
+
                     if weight_key in self.patches:
                         if force_patch_weights:
                             self.patch_weight_to_device(weight_key)
                         else:
-                            m.weight_function = LowVramPatch(weight_key, self.patches)
+                            m.weight_function = [LowVramPatch(weight_key, self.patches)]
                             patch_counter += 1
                     if bias_key in self.patches:
                         if force_patch_weights:
                             self.patch_weight_to_device(bias_key)
                         else:
-                            m.bias_function = LowVramPatch(bias_key, self.patches)
+                            m.bias_function = [LowVramPatch(bias_key, self.patches)]
                             patch_counter += 1
 
                     m.prev_comfy_cast_weights = m.comfy_cast_weights
                     m.comfy_cast_weights = True
                 else:
                     if hasattr(m, "comfy_cast_weights"):
                         if m.comfy_cast_weights:
                             wipe_lowvram_weight(m)
 
                     if full_load or mem_counter + module_mem < lowvram_model_memory:
                         mem_counter += module_mem
                         load_completely.append((module_mem, n, m, params))
 
+                if weight_key in self.weight_wrapper_patches:
+                    m.weight_function.extend(self.weight_wrapper_patches[weight_key])
+
+                if bias_key in self.weight_wrapper_patches:
+                    m.bias_function.extend(self.weight_wrapper_patches[bias_key])
+
+                mem_counter += move_weight_functions(m, device_to)
+
             load_completely.sort(reverse=True)
             for x in load_completely:
                 n = x[1]
@@ -662,6 +699,7 @@ class ModelPatcher:
             self.unpatch_hooks()
             if self.model.model_lowvram:
                 for m in self.model.modules():
+                    move_weight_functions(m, device_to)
                     wipe_lowvram_weight(m)
 
                 self.model.model_lowvram = False
@@ -729,12 +767,13 @@ class ModelPatcher:
                 bias_key = "{}.bias".format(n)
                 if move_weight:
                     m.to(device_to)
+                    module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
                         if weight_key in self.patches:
-                            m.weight_function = LowVramPatch(weight_key, self.patches)
+                            m.weight_function.append(LowVramPatch(weight_key, self.patches))
                             patch_counter += 1
                         if bias_key in self.patches:
-                            m.bias_function = LowVramPatch(bias_key, self.patches)
+                            m.bias_function.append(LowVramPatch(bias_key, self.patches))
                             patch_counter += 1
 
                     m.prev_comfy_cast_weights = m.comfy_cast_weights

comfy/ops.py

@@ -38,21 +38,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
-        has_function = s.bias_function is not None
+        has_function = len(s.bias_function) > 0
         bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
         if has_function:
-            bias = s.bias_function(bias)
+            for f in s.bias_function:
+                bias = f(bias)
 
-    has_function = s.weight_function is not None
+    has_function = len(s.weight_function) > 0
     weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
     if has_function:
-        weight = s.weight_function(weight)
+        for f in s.weight_function:
+            weight = f(weight)
     return weight, bias
 
 class CastWeightBiasOp:
     comfy_cast_weights = False
-    weight_function = None
-    bias_function = None
+    weight_function = []
+    bias_function = []
 
 class disable_weight_init:
     class Linear(torch.nn.Linear, CastWeightBiasOp):
@@ -64,7 +66,7 @@ class disable_weight_init:
             return torch.nn.functional.linear(input, weight, bias)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -78,7 +80,7 @@ class disable_weight_init:
             return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -92,7 +94,7 @@ class disable_weight_init:
            return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -106,7 +108,7 @@ class disable_weight_init:
             return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -120,12 +122,11 @@ class disable_weight_init:
             return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
 
-
     class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
         def reset_parameters(self):
             return None
@@ -139,7 +140,7 @@ class disable_weight_init:
             return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -160,7 +161,7 @@ class disable_weight_init:
                 output_padding, self.groups, self.dilation)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -181,7 +182,7 @@ class disable_weight_init:
                 output_padding, self.groups, self.dilation)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -199,7 +200,7 @@ class disable_weight_init:
             return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 if "out_dtype" in kwargs: