mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Lower memory usage for loras in lowvram mode at the cost of perf.
This commit is contained in:
parent
eda8704386
commit
db8b59ecff
@ -272,7 +272,6 @@ def module_size(module):
|
|||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.model_accelerated = False
|
|
||||||
self.device = model.load_device
|
self.device = model.load_device
|
||||||
|
|
||||||
def model_memory(self):
|
def model_memory(self):
|
||||||
@ -285,52 +284,27 @@ class LoadedModel:
|
|||||||
return self.model_memory()
|
return self.model_memory()
|
||||||
|
|
||||||
def model_load(self, lowvram_model_memory=0):
|
def model_load(self, lowvram_model_memory=0):
|
||||||
patch_model_to = None
|
patch_model_to = self.device
|
||||||
if lowvram_model_memory == 0:
|
|
||||||
patch_model_to = self.device
|
|
||||||
|
|
||||||
self.model.model_patches_to(self.device)
|
self.model.model_patches_to(self.device)
|
||||||
self.model.model_patches_to(self.model.model_dtype())
|
self.model.model_patches_to(self.model.model_dtype())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
|
if lowvram_model_memory > 0:
|
||||||
|
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
|
||||||
|
else:
|
||||||
|
self.real_model = self.model.patch_model(device_to=patch_model_to)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.model.unpatch_model(self.model.offload_device)
|
self.model.unpatch_model(self.model.offload_device)
|
||||||
self.model_unload()
|
self.model_unload()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if lowvram_model_memory > 0:
|
|
||||||
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
|
||||||
mem_counter = 0
|
|
||||||
for m in self.real_model.modules():
|
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
|
||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
|
||||||
m.comfy_cast_weights = True
|
|
||||||
module_mem = module_size(m)
|
|
||||||
if mem_counter + module_mem < lowvram_model_memory:
|
|
||||||
m.to(self.device)
|
|
||||||
mem_counter += module_mem
|
|
||||||
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
|
|
||||||
m.to(self.device)
|
|
||||||
mem_counter += module_size(m)
|
|
||||||
logging.warning("lowvram: loaded module regularly {}".format(m))
|
|
||||||
|
|
||||||
self.model_accelerated = True
|
|
||||||
|
|
||||||
if is_intel_xpu() and not args.disable_ipex_optimize:
|
if is_intel_xpu() and not args.disable_ipex_optimize:
|
||||||
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
||||||
|
|
||||||
return self.real_model
|
return self.real_model
|
||||||
|
|
||||||
def model_unload(self):
|
def model_unload(self):
|
||||||
if self.model_accelerated:
|
|
||||||
for m in self.real_model.modules():
|
|
||||||
if hasattr(m, "prev_comfy_cast_weights"):
|
|
||||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
|
||||||
del m.prev_comfy_cast_weights
|
|
||||||
|
|
||||||
self.model_accelerated = False
|
|
||||||
|
|
||||||
self.model.unpatch_model(self.model.offload_device)
|
self.model.unpatch_model(self.model.offload_device)
|
||||||
self.model.model_patches_to(self.model.offload_device)
|
self.model.model_patches_to(self.model.offload_device)
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ class ModelPatcher:
|
|||||||
self.current_device = current_device
|
self.current_device = current_device
|
||||||
|
|
||||||
self.weight_inplace_update = weight_inplace_update
|
self.weight_inplace_update = weight_inplace_update
|
||||||
|
self.model_lowvram = False
|
||||||
|
|
||||||
def model_size(self):
|
def model_size(self):
|
||||||
if self.size > 0:
|
if self.size > 0:
|
||||||
@ -178,6 +179,27 @@ class ModelPatcher:
|
|||||||
sd.pop(k)
|
sd.pop(k)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
def patch_weight_to_device(self, key, device_to=None):
|
||||||
|
if key not in self.patches:
|
||||||
|
return
|
||||||
|
|
||||||
|
weight = comfy.utils.get_attr(self.model, key)
|
||||||
|
|
||||||
|
inplace_update = self.weight_inplace_update
|
||||||
|
|
||||||
|
if key not in self.backup:
|
||||||
|
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
||||||
|
|
||||||
|
if device_to is not None:
|
||||||
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||||
|
else:
|
||||||
|
temp_weight = weight.to(torch.float32, copy=True)
|
||||||
|
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||||
|
if inplace_update:
|
||||||
|
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||||
|
else:
|
||||||
|
comfy.utils.set_attr_param(self.model, key, out_weight)
|
||||||
|
|
||||||
def patch_model(self, device_to=None, patch_weights=True):
|
def patch_model(self, device_to=None, patch_weights=True):
|
||||||
for k in self.object_patches:
|
for k in self.object_patches:
|
||||||
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
|
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
|
||||||
@ -191,23 +213,7 @@ class ModelPatcher:
|
|||||||
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
|
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
weight = model_sd[key]
|
self.patch_weight_to_device(key, device_to)
|
||||||
|
|
||||||
inplace_update = self.weight_inplace_update
|
|
||||||
|
|
||||||
if key not in self.backup:
|
|
||||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
|
||||||
|
|
||||||
if device_to is not None:
|
|
||||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
|
||||||
else:
|
|
||||||
temp_weight = weight.to(torch.float32, copy=True)
|
|
||||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
|
||||||
if inplace_update:
|
|
||||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
|
||||||
else:
|
|
||||||
comfy.utils.set_attr_param(self.model, key, out_weight)
|
|
||||||
del temp_weight
|
|
||||||
|
|
||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
self.model.to(device_to)
|
self.model.to(device_to)
|
||||||
@ -215,6 +221,47 @@ class ModelPatcher:
|
|||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
|
||||||
|
self.patch_model(device_to, patch_weights=False)
|
||||||
|
|
||||||
|
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
||||||
|
class LowVramPatch:
|
||||||
|
def __init__(self, key, model_patcher):
|
||||||
|
self.key = key
|
||||||
|
self.model_patcher = model_patcher
|
||||||
|
def __call__(self, weight):
|
||||||
|
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
||||||
|
|
||||||
|
mem_counter = 0
|
||||||
|
for n, m in self.model.named_modules():
|
||||||
|
lowvram_weight = False
|
||||||
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
|
module_mem = comfy.model_management.module_size(m)
|
||||||
|
if mem_counter + module_mem >= lowvram_model_memory:
|
||||||
|
lowvram_weight = True
|
||||||
|
|
||||||
|
weight_key = "{}.weight".format(n)
|
||||||
|
bias_key = "{}.bias".format(n)
|
||||||
|
|
||||||
|
if lowvram_weight:
|
||||||
|
if weight_key in self.patches:
|
||||||
|
m.weight_function = LowVramPatch(weight_key, self)
|
||||||
|
if bias_key in self.patches:
|
||||||
|
m.bias_function = LowVramPatch(weight_key, self)
|
||||||
|
|
||||||
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
|
m.comfy_cast_weights = True
|
||||||
|
else:
|
||||||
|
if hasattr(m, "weight"):
|
||||||
|
self.patch_weight_to_device(weight_key, device_to)
|
||||||
|
self.patch_weight_to_device(bias_key, device_to)
|
||||||
|
m.to(device_to)
|
||||||
|
mem_counter += comfy.model_management.module_size(m)
|
||||||
|
logging.debug("lowvram: loaded module regularly {}".format(m))
|
||||||
|
|
||||||
|
self.model_lowvram = True
|
||||||
|
return self.model
|
||||||
|
|
||||||
def calculate_weight(self, patches, weight, key):
|
def calculate_weight(self, patches, weight, key):
|
||||||
for p in patches:
|
for p in patches:
|
||||||
alpha = p[0]
|
alpha = p[0]
|
||||||
@ -341,6 +388,16 @@ class ModelPatcher:
|
|||||||
return weight
|
return weight
|
||||||
|
|
||||||
def unpatch_model(self, device_to=None):
|
def unpatch_model(self, device_to=None):
|
||||||
|
if self.model_lowvram:
|
||||||
|
for m in self.model.modules():
|
||||||
|
if hasattr(m, "prev_comfy_cast_weights"):
|
||||||
|
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
||||||
|
del m.prev_comfy_cast_weights
|
||||||
|
m.weight_function = None
|
||||||
|
m.bias_function = None
|
||||||
|
|
||||||
|
self.model_lowvram = False
|
||||||
|
|
||||||
keys = list(self.backup.keys())
|
keys = list(self.backup.keys())
|
||||||
|
|
||||||
if self.weight_inplace_update:
|
if self.weight_inplace_update:
|
||||||
|
22
comfy/ops.py
22
comfy/ops.py
@ -24,13 +24,20 @@ def cast_bias_weight(s, input):
|
|||||||
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||||
|
if s.bias_function is not None:
|
||||||
|
bias = s.bias_function(bias)
|
||||||
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||||
|
if s.weight_function is not None:
|
||||||
|
weight = s.weight_function(weight)
|
||||||
return weight, bias
|
return weight, bias
|
||||||
|
|
||||||
|
|
||||||
class disable_weight_init:
|
class disable_weight_init:
|
||||||
class Linear(torch.nn.Linear):
|
class Linear(torch.nn.Linear):
|
||||||
comfy_cast_weights = False
|
comfy_cast_weights = False
|
||||||
|
weight_function = None
|
||||||
|
bias_function = None
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -46,6 +53,9 @@ class disable_weight_init:
|
|||||||
|
|
||||||
class Conv2d(torch.nn.Conv2d):
|
class Conv2d(torch.nn.Conv2d):
|
||||||
comfy_cast_weights = False
|
comfy_cast_weights = False
|
||||||
|
weight_function = None
|
||||||
|
bias_function = None
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -61,6 +71,9 @@ class disable_weight_init:
|
|||||||
|
|
||||||
class Conv3d(torch.nn.Conv3d):
|
class Conv3d(torch.nn.Conv3d):
|
||||||
comfy_cast_weights = False
|
comfy_cast_weights = False
|
||||||
|
weight_function = None
|
||||||
|
bias_function = None
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -76,6 +89,9 @@ class disable_weight_init:
|
|||||||
|
|
||||||
class GroupNorm(torch.nn.GroupNorm):
|
class GroupNorm(torch.nn.GroupNorm):
|
||||||
comfy_cast_weights = False
|
comfy_cast_weights = False
|
||||||
|
weight_function = None
|
||||||
|
bias_function = None
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -92,6 +108,9 @@ class disable_weight_init:
|
|||||||
|
|
||||||
class LayerNorm(torch.nn.LayerNorm):
|
class LayerNorm(torch.nn.LayerNorm):
|
||||||
comfy_cast_weights = False
|
comfy_cast_weights = False
|
||||||
|
weight_function = None
|
||||||
|
bias_function = None
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -111,6 +130,9 @@ class disable_weight_init:
|
|||||||
|
|
||||||
class ConvTranspose2d(torch.nn.ConvTranspose2d):
|
class ConvTranspose2d(torch.nn.ConvTranspose2d):
|
||||||
comfy_cast_weights = False
|
comfy_cast_weights = False
|
||||||
|
weight_function = None
|
||||||
|
bias_function = None
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user