Mirror of https://github.com/comfyanonymous/ComfyUI.git

commit c18a203a8a
parent 150a3e946f

Don't unload model weights for non weight patches.
comfy/model_management.py

@@ -273,6 +273,7 @@ class LoadedModel:
     def __init__(self, model):
         self.model = model
         self.device = model.load_device
+        self.weights_loaded = False

     def model_memory(self):
         return self.model.model_size()
@@ -289,11 +290,13 @@ class LoadedModel:
         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())

+        load_weights = not self.weights_loaded
+
         try:
-            if lowvram_model_memory > 0:
+            if lowvram_model_memory > 0 and load_weights:
                 self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
             else:
-                self.real_model = self.model.patch_model(device_to=patch_model_to)
+                self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
         except Exception as e:
             self.model.unpatch_model(self.model.offload_device)
             self.model_unload()
@@ -302,11 +305,13 @@ class LoadedModel:
         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)

+        self.weights_loaded = True
         return self.real_model

-    def model_unload(self):
-        self.model.unpatch_model(self.model.offload_device)
+    def model_unload(self, unpatch_weights=True):
+        self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
+        self.weights_loaded = self.weights_loaded and not unpatch_weights

     def __eq__(self, other):
         return self.model is other.model
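Note (not part of the commit): the effect of the new weights_loaded flag, sketched with stand-in names rather than ComfyUI code. A LoadedModel unloaded with unpatch_weights=False still reports its weights as loaded, so the next model_load() only re-applies non-weight patches instead of re-patching the weights.

    class ToyLoadedModel:
        def __init__(self):
            self.weights_loaded = False

        def model_load(self):
            load_weights = not self.weights_loaded
            self.weights_loaded = True
            return "patch weights + structure" if load_weights else "patch structure only"

        def model_unload(self, unpatch_weights=True):
            self.weights_loaded = self.weights_loaded and not unpatch_weights

    lm = ToyLoadedModel()
    print(lm.model_load())                   # "patch weights + structure"
    lm.model_unload(unpatch_weights=False)   # weights stay applied on the device
    print(lm.model_load())                   # "patch structure only"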
@@ -314,15 +319,35 @@ class LoadedModel:
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)

-def unload_model_clones(model):
+def unload_model_clones(loaded_model, unload_weights_only=True):
+    model = loaded_model.model
+
     to_unload = []
     for i in range(len(current_loaded_models)):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload

+    if len(to_unload) == 0:
+        return
+
+    same_weights = 0
     for i in to_unload:
-        logging.debug("unload clone {}".format(i))
-        current_loaded_models.pop(i).model_unload()
+        if model.clone_has_same_weights(current_loaded_models[i].model):
+            same_weights += 1
+
+    if same_weights == len(to_unload):
+        unload_weight = False
+    else:
+        unload_weight = True
+
+    if unload_weights_only and unload_weight == False:
+        return
+
+    for i in to_unload:
+        logging.debug("unload clone {} {}".format(i, unload_weight))
+        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
+
+    loaded_model.weights_loaded = not unload_weight

 def free_memory(memory_required, device, keep_loaded=[]):
     unloaded_model = False
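Note (not part of the commit): the decision rule added above, restated as a minimal stand-in sketch with hypothetical names. Weights are only unpatched when at least one loaded clone would patch them differently, and the first pass (unload_weights_only=True) leaves same-weight clones loaded.

    def unload_decision(clone_same_weights_flags, unload_weights_only):
        # one bool per loaded clone: True if it would patch the weights identically
        unload_weight = not all(clone_same_weights_flags)
        if unload_weights_only and not unload_weight:
            return None              # first pass: keep same-weight clones loaded for now
        return unload_weight         # unload the clones; drop the weights only if True

    assert unload_decision([True, True], unload_weights_only=True) is None
    assert unload_decision([True, False], unload_weights_only=True) is True
    assert unload_decision([True, True], unload_weights_only=False) is False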
@@ -377,13 +402,16 @@ def load_models_gpu(models, memory_required=0):

     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model)
+        unload_model_clones(loaded_model, unload_weights_only=True) #unload clones where the weights are different
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

     for device in total_memory_required:
         if device != torch.device("cpu"):
             free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)

+    for loaded_model in models_to_load:
+        unload_model_clones(loaded_model, unload_weights_only=False) #unload the rest of the clones where the weights can stay loaded
+
     for loaded_model in models_to_load:
         model = loaded_model.model
         torch_dev = model.load_device
comfy/model_patcher.py

@@ -2,6 +2,7 @@ import torch
 import copy
 import inspect
 import logging
+import uuid

 import comfy.utils
 import comfy.model_management
@@ -25,6 +26,7 @@ class ModelPatcher:

         self.weight_inplace_update = weight_inplace_update
         self.model_lowvram = False
+        self.patches_uuid = uuid.uuid4()

     def model_size(self):
         if self.size > 0:
@@ -39,10 +41,13 @@ class ModelPatcher:
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
+        n.patches_uuid = self.patches_uuid

         n.object_patches = self.object_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         n.model_keys = self.model_keys
+        n.backup = self.backup
+        n.object_patches_backup = self.object_patches_backup
         return n

     def is_clone(self, other):
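Note (not part of the commit): a stand-in illustration, using a hypothetical class, of why the clone now shares backup and object_patches_backup by reference. When weights can stay patched across clones, whichever clone eventually unpatches must see the original values that another clone saved.

    class PatcherSketch:
        def __init__(self, backup=None):
            self.backup = {} if backup is None else backup

        def clone(self):
            return PatcherSketch(backup=self.backup)   # same dict object, not a copy

    a = PatcherSketch()
    b = a.clone()
    a.backup["diffusion_model.layer.weight"] = "original tensor"
    assert b.backup is a.backup    # the clone can restore what `a` backed up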
@@ -50,6 +55,19 @@ class ModelPatcher:
             return True
         return False

+    def clone_has_same_weights(self, clone):
+        if not self.is_clone(clone):
+            return False
+
+        if len(self.patches) == 0 and len(clone.patches) == 0:
+            return True
+
+        if self.patches_uuid == clone.patches_uuid:
+            if len(self.patches) != len(clone.patches):
+                logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
+            else:
+                return True
+
     def memory_required(self, input_shape):
         return self.model.memory_required(input_shape=input_shape)

@@ -154,6 +172,7 @@ class ModelPatcher:
             current_patches.append((strength_patch, patches[k], strength_model))
             self.patches[k] = current_patches

+        self.patches_uuid = uuid.uuid4()
         return list(p)

     def get_key_patches(self, filter_prefix=None):
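Note (not part of the commit): how patches_uuid is meant to be used, sketched with a toy stand-in rather than the real ModelPatcher. Clones inherit the uuid, so equal uuids mean the weight patches are identical; any add_patches call mints a new uuid, so that clone stops counting as "same weights".

    import uuid

    class ToyPatcher:
        def __init__(self):
            self.patches = {}
            self.patches_uuid = uuid.uuid4()

        def clone(self):
            n = ToyPatcher()
            n.patches = {k: v[:] for k, v in self.patches.items()}
            n.patches_uuid = self.patches_uuid          # inherited, like in clone() above
            return n

        def add_patches(self, patches):
            self.patches.update(patches)
            self.patches_uuid = uuid.uuid4()            # refreshed, like in add_patches() above

        def has_same_weights(self, other):
            if len(self.patches) == 0 and len(other.patches) == 0:
                return True
            return self.patches_uuid == other.patches_uuid and len(self.patches) == len(other.patches)

    base = ToyPatcher()
    a, b = base.clone(), base.clone()
    assert a.has_same_weights(b)        # identical patch sets -> weights can be shared
    b.add_patches({"some.key": [("lora_up", "lora_down")]})
    assert not a.has_same_weights(b)    # b diverged -> its weights must be re-patched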
@@ -387,31 +406,32 @@ class ModelPatcher:

         return weight

-    def unpatch_model(self, device_to=None):
-        if self.model_lowvram:
-            for m in self.model.modules():
-                if hasattr(m, "prev_comfy_cast_weights"):
-                    m.comfy_cast_weights = m.prev_comfy_cast_weights
-                    del m.prev_comfy_cast_weights
-                m.weight_function = None
-                m.bias_function = None
+    def unpatch_model(self, device_to=None, unpatch_weights=True):
+        if unpatch_weights:
+            if self.model_lowvram:
+                for m in self.model.modules():
+                    if hasattr(m, "prev_comfy_cast_weights"):
+                        m.comfy_cast_weights = m.prev_comfy_cast_weights
+                        del m.prev_comfy_cast_weights
+                    m.weight_function = None
+                    m.bias_function = None

-            self.model_lowvram = False
+                self.model_lowvram = False

-        keys = list(self.backup.keys())
+            keys = list(self.backup.keys())

-        if self.weight_inplace_update:
-            for k in keys:
-                comfy.utils.copy_to_param(self.model, k, self.backup[k])
-        else:
-            for k in keys:
-                comfy.utils.set_attr_param(self.model, k, self.backup[k])
+            if self.weight_inplace_update:
+                for k in keys:
+                    comfy.utils.copy_to_param(self.model, k, self.backup[k])
+            else:
+                for k in keys:
+                    comfy.utils.set_attr_param(self.model, k, self.backup[k])

-        self.backup = {}
+            self.backup.clear()

-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
+            if device_to is not None:
+                self.model.to(device_to)
+                self.current_device = device_to

         keys = list(self.object_patches_backup.keys())
         for k in keys:
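Note (not part of the commit): a rough stand-in sketch of the split this hunk introduces. With unpatch_weights=False the weight backups are left in place (and the shared backup dict is now cleared with .clear() rather than rebound), while the object-patch restore below the guarded block still runs.

    def unpatch_sketch(weight_backup, object_backup, unpatch_weights=True):
        restored = []
        if unpatch_weights:
            restored += ["weight:" + k for k in weight_backup]
            weight_backup.clear()           # cleared in place so clones sharing the dict agree
        restored += ["object:" + k for k in object_backup]
        object_backup.clear()
        return restored

    print(unpatch_sketch({"w": 0}, {"o": 0}, unpatch_weights=False))  # ['object:o']
    print(unpatch_sketch({"w": 0}, {"o": 0}, unpatch_weights=True))   # ['weight:w', 'object:o']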