Try to fix memory issue with lora.

comfyanonymous 2023-07-22 21:26:45 -04:00
parent 67be7eb81d
commit 22f29d66ca
2 changed files with 12 additions and 5 deletions

@@ -281,19 +281,23 @@ def load_model_gpu(model):
         vram_set_state = VRAMState.LOW_VRAM
 
     real_model = model.model
+    patch_model_to = None
     if vram_set_state == VRAMState.DISABLED:
         pass
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
-        real_model.to(torch_dev)
+        patch_model_to = torch_dev
 
     try:
-        real_model = model.patch_model()
+        real_model = model.patch_model(device_to=patch_model_to)
     except Exception as e:
         model.unpatch_model()
         unload_model()
         raise e
 
+    if patch_model_to is not None:
+        real_model.to(torch_dev)
+
     if vram_set_state == VRAMState.NO_VRAM:
         device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
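
Taken together, this hunk reorders the load path: instead of calling real_model.to(torch_dev) before patching, the target device is passed into patch_model() and the bulk move happens only afterwards, so the fp32 temporaries created while patching are not allocated next to a model that is already fully resident in VRAM. A rough sketch of the two orderings, with patch() as a hypothetical stand-in for ModelPatcher.patch_model() rather than the actual ComfyUI API:

import torch

torch_dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def patch(model, device_to=None):
    # Toy stand-in for ModelPatcher.patch_model(): make an fp32 working
    # copy of each weight (optionally on device_to), "patch" it, and
    # write the result back in the original dtype.
    for p in model.parameters():
        if device_to is not None:
            temp = p.detach().float().to(device_to, copy=True)
        else:
            temp = p.detach().to(torch.float32, copy=True)
        p.data = (temp * 1.0).to(p.dtype)  # dummy patch math
        del temp
    return model

model = torch.nn.Linear(8, 8)

# Before this commit the order was: move everything, then patch, so the
# fp32 temporaries appeared while the whole model already sat on the GPU:
#   model.to(torch_dev); patch(model)
# After this commit: patch first (temporaries travel one weight at a
# time), then do the bulk move once.
real_model = patch(model, device_to=torch_dev)
real_model.to(torch_dev)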

@@ -338,7 +338,7 @@ class ModelPatcher:
                     sd.pop(k)
         return sd
 
-    def patch_model(self):
+    def patch_model(self, device_to=None):
         model_sd = self.model_state_dict()
         for key in self.patches:
             if key not in model_sd:
@@ -350,10 +350,13 @@ class ModelPatcher:
             if key not in self.backup:
                 self.backup[key] = weight.to(self.offload_device)
 
-            temp_weight = weight.to(torch.float32, copy=True)
+            if device_to is not None:
+                temp_weight = weight.float().to(device_to, copy=True)
+            else:
+                temp_weight = weight.to(torch.float32, copy=True)
             out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
             set_attr(self.model, key, out_weight)
-            del weight
+            del temp_weight
         return self.model
 
     def calculate_weight(self, patches, weight, key):
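
The second hunk is where the saving comes from: with device_to set, the fp32 working copy of each weight is created directly on the compute device, and del temp_weight (instead of the old del weight) releases the large temporary before the next iteration allocates another one. A minimal sketch of one iteration of that loop, with the patch arithmetic stubbed to identity since only the allocation pattern matters here:

import torch

def patch_weight(weight, device_to=None):
    # Sketch of one iteration of the ModelPatcher.patch_model() loop
    # after this commit; calculate_weight() is replaced by identity math.
    if device_to is not None:
        # new path: fp32 working copy goes straight to the compute device
        temp_weight = weight.float().to(device_to, copy=True)
    else:
        # old path: fp32 copy lands on whatever device holds the weight
        temp_weight = weight.to(torch.float32, copy=True)
    out_weight = (temp_weight * 1.0).to(weight.dtype)  # stand-in patch
    # Free the oversized fp32 buffer immediately. The old `del weight`
    # dropped a reference that (when the weight already sat on the
    # offload device) the backup dict still held, so it presumably freed
    # nothing, while temp_weight stayed alive into the next iteration.
    del temp_weight
    return out_weight

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
w = torch.randn(8, 8, dtype=torch.float16)
patched = patch_weight(w, device_to=dev)

Reading the two hunks together: out_weight keeps the original dtype but appears to stay on device_to, so the later real_model.to(torch_dev) in load_model_gpu mostly has to move only the tensors that were never patched.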