mirror of https://github.com/comfyanonymous/ComfyUI.git
Add add_weight_wrapper function to model patcher.
Functions can now easily be added to wrap/modify model weights.
parent d9f0fcdb0c
commit ab888e1e0b
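As a usage sketch (not part of the commit): a weight wrapper is any callable that takes the cast weight tensor and returns the tensor the layer should actually use, and it is registered under the same dotted key that load() builds for each module ("<module path>.weight" or "<module path>.bias"). The variable model_patcher, the concrete key, and the scaling factor below are illustrative assumptions.

import torch

# Hypothetical example: dampen one layer's weight at cast time.
def scale_weight(weight: torch.Tensor) -> torch.Tensor:
    # cast_bias_weight() in comfy/ops.py calls each registered function on the
    # freshly cast tensor; whatever is returned is what the layer computes with.
    return weight * 0.5

# model_patcher is assumed to be an existing comfy.model_patcher.ModelPatcher
# instance (e.g. the one wrapping a loaded diffusion model).
model_patcher.add_weight_wrapper("diffusion_model.input_blocks.0.0.weight", scale_weight)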
comfy/model_patcher.py

@@ -96,8 +96,28 @@ def wipe_lowvram_weight(m):
     if hasattr(m, "prev_comfy_cast_weights"):
         m.comfy_cast_weights = m.prev_comfy_cast_weights
         del m.prev_comfy_cast_weights
-    m.weight_function = None
-    m.bias_function = None
+
+    if hasattr(m, "weight_function"):
+        m.weight_function = []
+
+    if hasattr(m, "bias_function"):
+        m.bias_function = []
+
+def move_weight_functions(m, device):
+    if device is None:
+        return 0
+
+    memory = 0
+    if hasattr(m, "weight_function"):
+        for f in m.weight_function:
+            if hasattr(f, "move_to"):
+                memory += f.move_to(device=device)
+
+    if hasattr(m, "bias_function"):
+        for f in m.bias_function:
+            if hasattr(f, "move_to"):
+                memory += f.move_to(device=device)
+    return memory
 
 class LowVramPatch:
     def __init__(self, key, patches):
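Because move_weight_functions() checks hasattr(f, "move_to"), a wrapper may also be a stateful object that owns tensors of its own; those get relocated together with the module, and the return value of move_to is added to the patcher's memory counter (which is tracked in bytes elsewhere in load(), so returning a byte count is the natural convention). A rough sketch of such a wrapper; the class, its attributes and the byte-count return are assumptions, not part of this commit:

import torch

class AddNoiseWrapper:
    # Hypothetical stateful wrapper: keeps a noise tensor that should live on
    # the same device as the weight it modifies.
    def __init__(self, noise: torch.Tensor, strength: float = 0.01):
        self.noise = noise
        self.strength = strength

    def __call__(self, weight: torch.Tensor) -> torch.Tensor:
        # Applied by cast_bias_weight() after the weight has been cast.
        return weight + self.strength * self.noise.to(device=weight.device, dtype=weight.dtype)

    def move_to(self, device=None):
        # Mirrors the call in move_weight_functions(): relocate owned tensors
        # and report how much memory they now occupy on the target device.
        self.noise = self.noise.to(device=device)
        return self.noise.numel() * self.noise.element_size()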
@@ -192,6 +212,7 @@ class ModelPatcher:
         self.backup = {}
         self.object_patches = {}
         self.object_patches_backup = {}
+        self.weight_wrapper_patches = {}
         self.model_options = {"transformer_options":{}}
         self.model_size()
         self.load_device = load_device
@@ -250,6 +271,7 @@ class ModelPatcher:
         n.patches_uuid = self.patches_uuid
 
         n.object_patches = self.object_patches.copy()
+        n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         n.backup = self.backup
         n.object_patches_backup = self.object_patches_backup
@@ -402,6 +424,10 @@ class ModelPatcher:
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj
 
+    def add_weight_wrapper(self, name, function):
+        self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
+        self.patches_uuid = uuid.uuid4()
+
     def get_model_object(self, name: str) -> torch.nn.Module:
         """Retrieves a nested attribute from an object using dot notation considering
         object patches.
@@ -566,6 +592,9 @@ class ModelPatcher:
 
                 lowvram_weight = False
 
+                weight_key = "{}.weight".format(n)
+                bias_key = "{}.bias".format(n)
+
                 if not full_load and hasattr(m, "comfy_cast_weights"):
                     if mem_counter + module_mem >= lowvram_model_memory:
                         lowvram_weight = True
@@ -573,34 +602,42 @@ class ModelPatcher:
                         if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                             continue
 
-                weight_key = "{}.weight".format(n)
-                bias_key = "{}.bias".format(n)
-
                 if lowvram_weight:
+                    if hasattr(m, "comfy_cast_weights"):
+                        m.weight_function = []
+                        m.bias_function = []
+
                     if weight_key in self.patches:
                         if force_patch_weights:
                             self.patch_weight_to_device(weight_key)
                         else:
-                            m.weight_function = LowVramPatch(weight_key, self.patches)
+                            m.weight_function = [LowVramPatch(weight_key, self.patches)]
                             patch_counter += 1
                     if bias_key in self.patches:
                         if force_patch_weights:
                             self.patch_weight_to_device(bias_key)
                         else:
-                            m.bias_function = LowVramPatch(bias_key, self.patches)
+                            m.bias_function = [LowVramPatch(bias_key, self.patches)]
                             patch_counter += 1
 
                     m.prev_comfy_cast_weights = m.comfy_cast_weights
                     m.comfy_cast_weights = True
                 else:
                     if hasattr(m, "comfy_cast_weights"):
-                        if m.comfy_cast_weights:
-                            wipe_lowvram_weight(m)
+                        wipe_lowvram_weight(m)
 
                     if full_load or mem_counter + module_mem < lowvram_model_memory:
                         mem_counter += module_mem
                         load_completely.append((module_mem, n, m, params))
 
+                if weight_key in self.weight_wrapper_patches:
+                    m.weight_function.extend(self.weight_wrapper_patches[weight_key])
+
+                if bias_key in self.weight_wrapper_patches:
+                    m.bias_function.extend(self.weight_wrapper_patches[bias_key])
+
+                mem_counter += move_weight_functions(m, device_to)
+
             load_completely.sort(reverse=True)
             for x in load_completely:
                 n = x[1]
@@ -662,6 +699,7 @@ class ModelPatcher:
             self.unpatch_hooks()
             if self.model.model_lowvram:
                 for m in self.model.modules():
+                    move_weight_functions(m, device_to)
                     wipe_lowvram_weight(m)
 
                 self.model.model_lowvram = False
@@ -729,12 +767,13 @@ class ModelPatcher:
                     bias_key = "{}.bias".format(n)
                     if move_weight:
                         m.to(device_to)
+                        module_mem += move_weight_functions(m, device_to)
                         if lowvram_possible:
                             if weight_key in self.patches:
-                                m.weight_function = LowVramPatch(weight_key, self.patches)
+                                m.weight_function.append(LowVramPatch(weight_key, self.patches))
                                 patch_counter += 1
                             if bias_key in self.patches:
-                                m.bias_function = LowVramPatch(bias_key, self.patches)
+                                m.bias_function.append(LowVramPatch(bias_key, self.patches))
                                 patch_counter += 1
 
                         m.prev_comfy_cast_weights = m.comfy_cast_weights
comfy/ops.py

@@ -38,21 +38,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
-        has_function = s.bias_function is not None
+        has_function = len(s.bias_function) > 0
        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
         if has_function:
-            bias = s.bias_function(bias)
+            for f in s.bias_function:
+                bias = f(bias)
 
-    has_function = s.weight_function is not None
+    has_function = len(s.weight_function) > 0
     weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
     if has_function:
-        weight = s.weight_function(weight)
+        for f in s.weight_function:
+            weight = f(weight)
     return weight, bias
 
 class CastWeightBiasOp:
     comfy_cast_weights = False
-    weight_function = None
-    bias_function = None
+    weight_function = []
+    bias_function = []
 
 class disable_weight_init:
     class Linear(torch.nn.Linear, CastWeightBiasOp):
@@ -64,7 +66,7 @@ class disable_weight_init:
             return torch.nn.functional.linear(input, weight, bias)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -78,7 +80,7 @@ class disable_weight_init:
             return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -92,7 +94,7 @@ class disable_weight_init:
             return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -106,7 +108,7 @@ class disable_weight_init:
             return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -120,12 +122,11 @@ class disable_weight_init:
             return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
 
-
     class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
         def reset_parameters(self):
             return None
@@ -139,7 +140,7 @@ class disable_weight_init:
             return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -160,7 +161,7 @@ class disable_weight_init:
                 output_padding, self.groups, self.dilation)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -181,7 +182,7 @@ class disable_weight_init:
                 output_padding, self.groups, self.dilation)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -199,7 +200,7 @@ class disable_weight_init:
             return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
 
         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 if "out_dtype" in kwargs:
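For reference, the net effect of the ops.py change is that cast_bias_weight() now folds the freshly cast tensor through weight_function / bias_function in list order, so an entry placed at the head of the list (such as the LowVramPatch installed by load() above) runs before wrapper functions extended onto it afterwards. A standalone sketch of that reduction; the function name is illustrative, not ComfyUI API:

import torch

def apply_weight_functions(weight: torch.Tensor, weight_function: list) -> torch.Tensor:
    # Same reduction cast_bias_weight() performs: each entry receives the
    # output of the previous one, so list order is application order.
    for f in weight_function:
        weight = f(weight)
    return weight

# With weight_function == [a, b], the layer ends up using b(a(weight)).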