From 839ed3368efd0f61a2b986f57fe9e0698fd08e9f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 22 Nov 2024 20:59:15 -0500 Subject: [PATCH] Some improvements to the lowvram unloading. --- comfy/model_patcher.py | 61 +++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index fc232954..f53f1074 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -367,10 +367,7 @@ class ModelPatcher: else: set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) - def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): - mem_counter = 0 - patch_counter = 0 - lowvram_counter = 0 + def _load_list(self): loading = [] for n, m in self.model.named_modules(): params = [] @@ -383,6 +380,13 @@ class ModelPatcher: break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): loading.append((comfy.model_management.module_size(m), n, m, params)) + return loading + + def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): + mem_counter = 0 + patch_counter = 0 + lowvram_counter = 0 + loading = self._load_list() load_completely = [] loading.sort(reverse=True) @@ -514,14 +518,7 @@ class ModelPatcher: def partially_unload(self, device_to, memory_to_free=0): memory_freed = 0 patch_counter = 0 - unload_list = [] - - for n, m in self.model.named_modules(): - shift_lowvram = False - if hasattr(m, "comfy_cast_weights"): - module_mem = comfy.model_management.module_size(m) - unload_list.append((module_mem, n, m)) - + unload_list = self._load_list() unload_list.sort() for unload in unload_list: if memory_to_free < memory_freed: @@ -529,32 +526,42 @@ class ModelPatcher: module_mem = unload[0] n = unload[1] m = unload[2] - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) + params = unload[3] + lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: - for key in [weight_key, bias_key]: + move_weight = True + for param in params: + key = "{}.{}".format(n, param) bk = self.backup.get(key, None) if bk is not None: + if not lowvram_possible: + move_weight = False + break + if bk.inplace_update: comfy.utils.copy_to_param(self.model, key, bk.weight) else: comfy.utils.set_attr_param(self.model, key, bk.weight) self.backup.pop(key) - m.to(device_to) - if weight_key in self.patches: - m.weight_function = LowVramPatch(weight_key, self.patches) - patch_counter += 1 - if bias_key in self.patches: - m.bias_function = LowVramPatch(bias_key, self.patches) - patch_counter += 1 + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if move_weight: + m.to(device_to) + if lowvram_possible: + if weight_key in self.patches: + m.weight_function = LowVramPatch(weight_key, self.patches) + patch_counter += 1 + if bias_key in self.patches: + m.bias_function = LowVramPatch(bias_key, self.patches) + patch_counter += 1 - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True - m.comfy_patched_weights = False - memory_freed += module_mem - logging.debug("freed {}".format(n)) + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + m.comfy_patched_weights = False + memory_freed += module_mem + logging.debug("freed {}".format(n)) self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter