Lower memory usage for loras in lowvram mode at the cost of perf.

This commit is contained in:
comfyanonymous 2024-03-13 19:04:41 -04:00
parent eda8704386
commit db8b59ecff
3 changed files with 101 additions and 48 deletions

View File

@ -272,7 +272,6 @@ def module_size(module):
class LoadedModel: class LoadedModel:
def __init__(self, model): def __init__(self, model):
self.model = model self.model = model
self.model_accelerated = False
self.device = model.load_device self.device = model.load_device
def model_memory(self): def model_memory(self):
@ -285,52 +284,27 @@ class LoadedModel:
return self.model_memory() return self.model_memory()
def model_load(self, lowvram_model_memory=0): def model_load(self, lowvram_model_memory=0):
patch_model_to = None patch_model_to = self.device
if lowvram_model_memory == 0:
patch_model_to = self.device
self.model.model_patches_to(self.device) self.model.model_patches_to(self.device)
self.model.model_patches_to(self.model.model_dtype()) self.model.model_patches_to(self.model.model_dtype())
try: try:
self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU if lowvram_model_memory > 0:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to)
except Exception as e: except Exception as e:
self.model.unpatch_model(self.model.offload_device) self.model.unpatch_model(self.model.offload_device)
self.model_unload() self.model_unload()
raise e raise e
if lowvram_model_memory > 0:
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
mem_counter = 0
for m in self.real_model.modules():
if hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
module_mem = module_size(m)
if mem_counter + module_mem < lowvram_model_memory:
m.to(self.device)
mem_counter += module_mem
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
m.to(self.device)
mem_counter += module_size(m)
logging.warning("lowvram: loaded module regularly {}".format(m))
self.model_accelerated = True
if is_intel_xpu() and not args.disable_ipex_optimize: if is_intel_xpu() and not args.disable_ipex_optimize:
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
return self.real_model return self.real_model
def model_unload(self): def model_unload(self):
if self.model_accelerated:
for m in self.real_model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
self.model_accelerated = False
self.model.unpatch_model(self.model.offload_device) self.model.unpatch_model(self.model.offload_device)
self.model.model_patches_to(self.model.offload_device) self.model.model_patches_to(self.model.offload_device)

View File

@ -24,6 +24,7 @@ class ModelPatcher:
self.current_device = current_device self.current_device = current_device
self.weight_inplace_update = weight_inplace_update self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False
def model_size(self): def model_size(self):
if self.size > 0: if self.size > 0:
@ -178,6 +179,27 @@ class ModelPatcher:
sd.pop(k) sd.pop(k)
return sd return sd
def patch_weight_to_device(self, key, device_to=None):
if key not in self.patches:
return
weight = comfy.utils.get_attr(self.model, key)
inplace_update = self.weight_inplace_update
if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
def patch_model(self, device_to=None, patch_weights=True): def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches: for k in self.object_patches:
old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
@ -191,23 +213,7 @@ class ModelPatcher:
logging.warning("could not patch. key doesn't exist in model: {}".format(key)) logging.warning("could not patch. key doesn't exist in model: {}".format(key))
continue continue
weight = model_sd[key] self.patch_weight_to_device(key, device_to)
inplace_update = self.weight_inplace_update
if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
del temp_weight
if device_to is not None: if device_to is not None:
self.model.to(device_to) self.model.to(device_to)
@ -215,6 +221,47 @@ class ModelPatcher:
return self.model return self.model
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
self.patch_model(device_to, patch_weights=False)
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
class LowVramPatch:
def __init__(self, key, model_patcher):
self.key = key
self.model_patcher = model_patcher
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
mem_counter = 0
for n, m in self.model.named_modules():
lowvram_weight = False
if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if lowvram_weight:
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
if bias_key in self.patches:
m.bias_function = LowVramPatch(weight_key, self)
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
else:
if hasattr(m, "weight"):
self.patch_weight_to_device(weight_key, device_to)
self.patch_weight_to_device(bias_key, device_to)
m.to(device_to)
mem_counter += comfy.model_management.module_size(m)
logging.debug("lowvram: loaded module regularly {}".format(m))
self.model_lowvram = True
return self.model
def calculate_weight(self, patches, weight, key): def calculate_weight(self, patches, weight, key):
for p in patches: for p in patches:
alpha = p[0] alpha = p[0]
@ -341,6 +388,16 @@ class ModelPatcher:
return weight return weight
def unpatch_model(self, device_to=None): def unpatch_model(self, device_to=None):
if self.model_lowvram:
for m in self.model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
m.weight_function = None
m.bias_function = None
self.model_lowvram = False
keys = list(self.backup.keys()) keys = list(self.backup.keys())
if self.weight_inplace_update: if self.weight_inplace_update:

View File

@ -24,13 +24,20 @@ def cast_bias_weight(s, input):
non_blocking = comfy.model_management.device_supports_non_blocking(input.device) non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
if s.bias is not None: if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if s.bias_function is not None:
bias = s.bias_function(bias)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if s.weight_function is not None:
weight = s.weight_function(weight)
return weight, bias return weight, bias
class disable_weight_init: class disable_weight_init:
class Linear(torch.nn.Linear): class Linear(torch.nn.Linear):
comfy_cast_weights = False comfy_cast_weights = False
weight_function = None
bias_function = None
def reset_parameters(self): def reset_parameters(self):
return None return None
@ -46,6 +53,9 @@ class disable_weight_init:
class Conv2d(torch.nn.Conv2d): class Conv2d(torch.nn.Conv2d):
comfy_cast_weights = False comfy_cast_weights = False
weight_function = None
bias_function = None
def reset_parameters(self): def reset_parameters(self):
return None return None
@ -61,6 +71,9 @@ class disable_weight_init:
class Conv3d(torch.nn.Conv3d): class Conv3d(torch.nn.Conv3d):
comfy_cast_weights = False comfy_cast_weights = False
weight_function = None
bias_function = None
def reset_parameters(self): def reset_parameters(self):
return None return None
@ -76,6 +89,9 @@ class disable_weight_init:
class GroupNorm(torch.nn.GroupNorm): class GroupNorm(torch.nn.GroupNorm):
comfy_cast_weights = False comfy_cast_weights = False
weight_function = None
bias_function = None
def reset_parameters(self): def reset_parameters(self):
return None return None
@ -92,6 +108,9 @@ class disable_weight_init:
class LayerNorm(torch.nn.LayerNorm): class LayerNorm(torch.nn.LayerNorm):
comfy_cast_weights = False comfy_cast_weights = False
weight_function = None
bias_function = None
def reset_parameters(self): def reset_parameters(self):
return None return None
@ -111,6 +130,9 @@ class disable_weight_init:
class ConvTranspose2d(torch.nn.ConvTranspose2d): class ConvTranspose2d(torch.nn.ConvTranspose2d):
comfy_cast_weights = False comfy_cast_weights = False
weight_function = None
bias_function = None
def reset_parameters(self): def reset_parameters(self):
return None return None