From ab888e1e0b8a7558081713241172d0a38f837e16 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 12 Feb 2025 05:49:00 -0500 Subject: [PATCH] Add add_weight_wrapper function to model patcher. Functions can now easily be added to wrap/modify model weights. --- comfy/model_patcher.py | 61 ++++++++++++++++++++++++++++++++++-------- comfy/ops.py | 33 ++++++++++++----------- 2 files changed, 67 insertions(+), 27 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0501f7b38..aee0164c5 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -96,8 +96,28 @@ def wipe_lowvram_weight(m): if hasattr(m, "prev_comfy_cast_weights"): m.comfy_cast_weights = m.prev_comfy_cast_weights del m.prev_comfy_cast_weights - m.weight_function = None - m.bias_function = None + + if hasattr(m, "weight_function"): + m.weight_function = [] + + if hasattr(m, "bias_function"): + m.bias_function = [] + +def move_weight_functions(m, device): + if device is None: + return 0 + + memory = 0 + if hasattr(m, "weight_function"): + for f in m.weight_function: + if hasattr(f, "move_to"): + memory += f.move_to(device=device) + + if hasattr(m, "bias_function"): + for f in m.bias_function: + if hasattr(f, "move_to"): + memory += f.move_to(device=device) + return memory class LowVramPatch: def __init__(self, key, patches): @@ -192,6 +212,7 @@ class ModelPatcher: self.backup = {} self.object_patches = {} self.object_patches_backup = {} + self.weight_wrapper_patches = {} self.model_options = {"transformer_options":{}} self.model_size() self.load_device = load_device @@ -250,6 +271,7 @@ class ModelPatcher: n.patches_uuid = self.patches_uuid n.object_patches = self.object_patches.copy() + n.weight_wrapper_patches = self.weight_wrapper_patches.copy() n.model_options = copy.deepcopy(self.model_options) n.backup = self.backup n.object_patches_backup = self.object_patches_backup @@ -402,6 +424,10 @@ class ModelPatcher: def add_object_patch(self, name, obj): self.object_patches[name] = obj + def add_weight_wrapper(self, name, function): + self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function] + self.patches_uuid = uuid.uuid4() + def get_model_object(self, name: str) -> torch.nn.Module: """Retrieves a nested attribute from an object using dot notation considering object patches. @@ -566,6 +592,9 @@ class ModelPatcher: lowvram_weight = False + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if not full_load and hasattr(m, "comfy_cast_weights"): if mem_counter + module_mem >= lowvram_model_memory: lowvram_weight = True @@ -573,34 +602,42 @@ class ModelPatcher: if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed continue - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) - if lowvram_weight: + if hasattr(m, "comfy_cast_weights"): + m.weight_function = [] + m.bias_function = [] + if weight_key in self.patches: if force_patch_weights: self.patch_weight_to_device(weight_key) else: - m.weight_function = LowVramPatch(weight_key, self.patches) + m.weight_function = [LowVramPatch(weight_key, self.patches)] patch_counter += 1 if bias_key in self.patches: if force_patch_weights: self.patch_weight_to_device(bias_key) else: - m.bias_function = LowVramPatch(bias_key, self.patches) + m.bias_function = [LowVramPatch(bias_key, self.patches)] patch_counter += 1 m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True else: if hasattr(m, "comfy_cast_weights"): - if m.comfy_cast_weights: - wipe_lowvram_weight(m) + wipe_lowvram_weight(m) if full_load or mem_counter + module_mem < lowvram_model_memory: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) + if weight_key in self.weight_wrapper_patches: + m.weight_function.extend(self.weight_wrapper_patches[weight_key]) + + if bias_key in self.weight_wrapper_patches: + m.bias_function.extend(self.weight_wrapper_patches[bias_key]) + + mem_counter += move_weight_functions(m, device_to) + load_completely.sort(reverse=True) for x in load_completely: n = x[1] @@ -662,6 +699,7 @@ class ModelPatcher: self.unpatch_hooks() if self.model.model_lowvram: for m in self.model.modules(): + move_weight_functions(m, device_to) wipe_lowvram_weight(m) self.model.model_lowvram = False @@ -729,12 +767,13 @@ class ModelPatcher: bias_key = "{}.bias".format(n) if move_weight: m.to(device_to) + module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: - m.weight_function = LowVramPatch(weight_key, self.patches) + m.weight_function.append(LowVramPatch(weight_key, self.patches)) patch_counter += 1 if bias_key in self.patches: - m.bias_function = LowVramPatch(bias_key, self.patches) + m.bias_function.append(LowVramPatch(bias_key, self.patches)) patch_counter += 1 m.prev_comfy_cast_weights = m.comfy_cast_weights diff --git a/comfy/ops.py b/comfy/ops.py index 06be6b48b..30014477e 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -38,21 +38,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): bias = None non_blocking = comfy.model_management.device_supports_non_blocking(device) if s.bias is not None: - has_function = s.bias_function is not None + has_function = len(s.bias_function) > 0 bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function) if has_function: - bias = s.bias_function(bias) + for f in s.bias_function: + bias = f(bias) - has_function = s.weight_function is not None + has_function = len(s.weight_function) > 0 weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function) if has_function: - weight = s.weight_function(weight) + for f in s.weight_function: + weight = f(weight) return weight, bias class CastWeightBiasOp: comfy_cast_weights = False - weight_function = None - bias_function = None + weight_function = [] + bias_function = [] class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): @@ -64,7 +66,7 @@ class disable_weight_init: return torch.nn.functional.linear(input, weight, bias) def forward(self, *args, **kwargs): - if self.comfy_cast_weights: + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) @@ -78,7 +80,7 @@ class disable_weight_init: return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): - if self.comfy_cast_weights: + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) @@ -92,7 +94,7 @@ class disable_weight_init: return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): - if self.comfy_cast_weights: + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) @@ -106,7 +108,7 @@ class disable_weight_init: return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): - if self.comfy_cast_weights: + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) @@ -120,12 +122,11 @@ class disable_weight_init: return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) def forward(self, *args, **kwargs): - if self.comfy_cast_weights: + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) - class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp): def reset_parameters(self): return None @@ -139,7 +140,7 @@ class disable_weight_init: return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) def forward(self, *args, **kwargs): - if self.comfy_cast_weights: + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) @@ -160,7 +161,7 @@ class disable_weight_init: output_padding, self.groups, self.dilation) def forward(self, *args, **kwargs): - if self.comfy_cast_weights: + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) @@ -181,7 +182,7 @@ class disable_weight_init: output_padding, self.groups, self.dilation) def forward(self, *args, **kwargs): - if self.comfy_cast_weights: + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) @@ -199,7 +200,7 @@ class disable_weight_init: return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) def forward(self, *args, **kwargs): - if self.comfy_cast_weights: + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: if "out_dtype" in kwargs: