mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Try to fix memory issue with lora.
This commit is contained in:
parent
67be7eb81d
commit
22f29d66ca
@ -281,19 +281,23 @@ def load_model_gpu(model):
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
|
||||
real_model = model.model
|
||||
patch_model_to = None
|
||||
if vram_set_state == VRAMState.DISABLED:
|
||||
pass
|
||||
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
|
||||
model_accelerated = False
|
||||
real_model.to(torch_dev)
|
||||
patch_model_to = torch_dev
|
||||
|
||||
try:
|
||||
real_model = model.patch_model()
|
||||
real_model = model.patch_model(device_to=patch_model_to)
|
||||
except Exception as e:
|
||||
model.unpatch_model()
|
||||
unload_model()
|
||||
raise e
|
||||
|
||||
if patch_model_to is not None:
|
||||
real_model.to(torch_dev)
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
|
||||
|
@ -338,7 +338,7 @@ class ModelPatcher:
|
||||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_model(self):
|
||||
def patch_model(self, device_to=None):
|
||||
model_sd = self.model_state_dict()
|
||||
for key in self.patches:
|
||||
if key not in model_sd:
|
||||
@ -350,10 +350,13 @@ class ModelPatcher:
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(self.offload_device)
|
||||
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
if device_to is not None:
|
||||
temp_weight = weight.float().to(device_to, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
set_attr(self.model, key, out_weight)
|
||||
del weight
|
||||
del temp_weight
|
||||
return self.model
|
||||
|
||||
def calculate_weight(self, patches, weight, key):
|
||||
|
Loading…
Reference in New Issue
Block a user