From db8b59ecff7be40377d17ea69487f442b469c536 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 13 Mar 2024 19:04:41 -0400
Subject: [PATCH] Lower memory usage for loras in lowvram mode at the cost of
 perf.

---
 comfy/model_management.py | 36 +++------------------
 comfy/model_patcher.py    | 91 ++++++++++++++++++++++++++++++++-------
 comfy/ops.py              | 22 ++++++++++
 3 files changed, 101 insertions(+), 48 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 2f0a0a62..66fa918b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -272,7 +272,6 @@ def module_size(module):
 class LoadedModel:
     def __init__(self, model):
         self.model = model
-        self.model_accelerated = False
         self.device = model.load_device
 
     def model_memory(self):
@@ -285,52 +284,27 @@ class LoadedModel:
         return self.model_memory()
 
     def model_load(self, lowvram_model_memory=0):
-        patch_model_to = None
-        if lowvram_model_memory == 0:
-            patch_model_to = self.device
+        patch_model_to = self.device
 
         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())
 
         try:
-            self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
+            if lowvram_model_memory > 0:
+                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
+            else:
+                self.real_model = self.model.patch_model(device_to=patch_model_to)
         except Exception as e:
             self.model.unpatch_model(self.model.offload_device)
             self.model_unload()
             raise e
 
-        if lowvram_model_memory > 0:
-            logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
-            mem_counter = 0
-            for m in self.real_model.modules():
-                if hasattr(m, "comfy_cast_weights"):
-                    m.prev_comfy_cast_weights = m.comfy_cast_weights
-                    m.comfy_cast_weights = True
-                    module_mem = module_size(m)
-                    if mem_counter + module_mem < lowvram_model_memory:
-                        m.to(self.device)
-                        mem_counter += module_mem
-                elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
-                    m.to(self.device)
-                    mem_counter += module_size(m)
-                    logging.warning("lowvram: loaded module regularly {}".format(m))
-
-            self.model_accelerated = True
-
         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
 
         return self.real_model
 
     def model_unload(self):
-        if self.model_accelerated:
-            for m in self.real_model.modules():
-                if hasattr(m, "prev_comfy_cast_weights"):
-                    m.comfy_cast_weights = m.prev_comfy_cast_weights
-                    del m.prev_comfy_cast_weights
-
-            self.model_accelerated = False
-
         self.model.unpatch_model(self.model.offload_device)
         self.model.model_patches_to(self.model.offload_device)

diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 5e578dff..475fa812 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -24,6 +24,7 @@ class ModelPatcher:
         self.current_device = current_device
         self.weight_inplace_update = weight_inplace_update
+        self.model_lowvram = False
 
     def model_size(self):
         if self.size > 0:
@@ -178,6 +179,27 @@ class ModelPatcher:
             sd.pop(k)
         return sd
 
+    def patch_weight_to_device(self, key, device_to=None):
+        if key not in self.patches:
+            return
+
+        weight = comfy.utils.get_attr(self.model, key)
+
+        inplace_update = self.weight_inplace_update
+
+        if key not in self.backup:
+            self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+
+        if device_to is not None:
+            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+        else:
+            temp_weight = weight.to(torch.float32, copy=True)
+        out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+        if inplace_update:
+            comfy.utils.copy_to_param(self.model, key, out_weight)
+        else:
+            comfy.utils.set_attr_param(self.model, key, out_weight)
+
     def patch_model(self, device_to=None, patch_weights=True):
         for k in self.object_patches:
             old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
@@ -191,23 +213,7 @@ class ModelPatcher:
                 logging.warning("could not patch. key doesn't exist in model: {}".format(key))
                 continue
 
-            weight = model_sd[key]
-
-            inplace_update = self.weight_inplace_update
-
-            if key not in self.backup:
-                self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
-
-            if device_to is not None:
-                temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
-            else:
-                temp_weight = weight.to(torch.float32, copy=True)
-            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
-            if inplace_update:
-                comfy.utils.copy_to_param(self.model, key, out_weight)
-            else:
-                comfy.utils.set_attr_param(self.model, key, out_weight)
-            del temp_weight
+            self.patch_weight_to_device(key, device_to)
 
         if device_to is not None:
             self.model.to(device_to)
@@ -215,6 +221,47 @@ class ModelPatcher:
 
         return self.model
 
+    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
+        self.patch_model(device_to, patch_weights=False)
+
+        logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
+        class LowVramPatch:
+            def __init__(self, key, model_patcher):
+                self.key = key
+                self.model_patcher = model_patcher
+            def __call__(self, weight):
+                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+
+        mem_counter = 0
+        for n, m in self.model.named_modules():
+            lowvram_weight = False
+            if hasattr(m, "comfy_cast_weights"):
+                module_mem = comfy.model_management.module_size(m)
+                if mem_counter + module_mem >= lowvram_model_memory:
+                    lowvram_weight = True
+
+            weight_key = "{}.weight".format(n)
+            bias_key = "{}.bias".format(n)
+
+            if lowvram_weight:
+                if weight_key in self.patches:
+                    m.weight_function = LowVramPatch(weight_key, self)
+                if bias_key in self.patches:
+                    m.bias_function = LowVramPatch(bias_key, self)
+
+                m.prev_comfy_cast_weights = m.comfy_cast_weights
+                m.comfy_cast_weights = True
+            else:
+                if hasattr(m, "weight"):
+                    self.patch_weight_to_device(weight_key, device_to)
+                    self.patch_weight_to_device(bias_key, device_to)
+                    m.to(device_to)
+                    mem_counter += comfy.model_management.module_size(m)
+                    logging.debug("lowvram: loaded module regularly {}".format(m))
+
+        self.model_lowvram = True
+        return self.model
+
     def calculate_weight(self, patches, weight, key):
         for p in patches:
             alpha = p[0]
@@ -341,6 +388,16 @@ class ModelPatcher:
         return weight
 
     def unpatch_model(self, device_to=None):
+        if self.model_lowvram:
+            for m in self.model.modules():
+                if hasattr(m, "prev_comfy_cast_weights"):
+                    m.comfy_cast_weights = m.prev_comfy_cast_weights
+                    del m.prev_comfy_cast_weights
+                m.weight_function = None
+                m.bias_function = None
+
+            self.model_lowvram = False
+
         keys = list(self.backup.keys())
 
         if self.weight_inplace_update:

diff --git a/comfy/ops.py b/comfy/ops.py
index 517688e8..cfdec355 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -24,13 +24,20 @@ def cast_bias_weight(s, input):
     non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
     if s.bias is not None:
         bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+        if s.bias_function is not None:
+            bias = s.bias_function(bias)
     weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    if s.weight_function is not None:
+        weight = s.weight_function(weight)
     return weight, bias
 
 
 class disable_weight_init:
     class Linear(torch.nn.Linear):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None
+
         def reset_parameters(self):
             return None
 
@@ -46,6 +53,9 @@ class disable_weight_init:
 
     class Conv2d(torch.nn.Conv2d):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None
+
         def reset_parameters(self):
             return None
 
@@ -61,6 +71,9 @@ class disable_weight_init:
 
     class Conv3d(torch.nn.Conv3d):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None
+
         def reset_parameters(self):
             return None
 
@@ -76,6 +89,9 @@ class disable_weight_init:
 
     class GroupNorm(torch.nn.GroupNorm):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None
+
         def reset_parameters(self):
             return None
 
@@ -92,6 +108,9 @@ class disable_weight_init:
 
     class LayerNorm(torch.nn.LayerNorm):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None
+
         def reset_parameters(self):
             return None
 
@@ -111,6 +130,9 @@ class disable_weight_init:
 
     class ConvTranspose2d(torch.nn.ConvTranspose2d):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None
+
         def reset_parameters(self):
             return None
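
Reviewer note, not part of the patch: a minimal, self-contained sketch of the mechanism the comfy/ops.py hunks introduce. The names PatchedLinear and lora_patch are invented for illustration; only the weight_function/bias_function hook pattern comes from the diff. Instead of baking the LoRA delta into the stored parameter up front, the hook is applied to the temporary cast copy of the weight on every forward pass, so no patched copy of the weights persists between passes; that is the memory-for-perf trade named in the subject line.

    import torch

    class PatchedLinear(torch.nn.Linear):
        # Class-level hook defaults, mirroring what the diff adds to
        # comfy.ops.disable_weight_init.Linear and the other op classes.
        weight_function = None
        bias_function = None

        def forward(self, input):
            # Cast to the input's device/dtype, then patch the temporary
            # copy, as cast_bias_weight() does after this patch. The stored
            # parameter itself is never modified.
            weight = self.weight.to(device=input.device, dtype=input.dtype)
            if self.weight_function is not None:
                weight = self.weight_function(weight)
            bias = None
            if self.bias is not None:
                bias = self.bias.to(device=input.device, dtype=input.dtype)
                if self.bias_function is not None:
                    bias = self.bias_function(bias)
            return torch.nn.functional.linear(input, weight, bias)

    def lora_patch(up, down, alpha):
        # Illustrative stand-in for ModelPatcher.calculate_weight(): add the
        # low-rank update alpha * (up @ down) to the weight it receives.
        def apply(weight):
            return weight + alpha * (up @ down).to(weight.dtype)
        return apply

    lin = PatchedLinear(16, 16)
    lin.weight_function = lora_patch(torch.randn(16, 4), torch.randn(4, 16), 0.5)
    out = lin(torch.randn(2, 16))  # LoRA recomputed on the fly each call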
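Under the same caveats, a sketch of the budget walk in the new patch_model_lowvram(): modules are visited in named_modules() order and loaded fully (weights patched eagerly via patch_weight_to_device and moved to the device) until the byte budget runs out; every module past the cutoff stays offloaded and gets a LowVramPatch hook instead. The module_size below is a rough stand-in for comfy.model_management.module_size, and the real loop only defers modules that support comfy_cast_weights.

    import torch

    def module_size(m):
        # Approximate parameter bytes held directly by one module.
        return sum(p.numel() * p.element_size() for p in m.parameters(recurse=False))

    def split_by_budget(model, lowvram_model_memory):
        mem_counter = 0
        eager, lazy = [], []
        for name, m in model.named_modules():
            size = module_size(m)
            if mem_counter + size >= lowvram_model_memory:
                lazy.append(name)    # stays offloaded; patched per forward pass
            else:
                eager.append(name)   # patched up front and kept on the device
                mem_counter += size
        return eager, lazy

    model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 256))
    eager, lazy = split_by_budget(model, lowvram_model_memory=300 * 1024)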