Fix potential issues with patching models when saving checkpoints.

This commit is contained in:
comfyanonymous 2023-08-17 10:58:59 -04:00
parent 1498f1a342
commit c28db1f315

View File

@ -574,7 +574,7 @@ class CLIP:
else: else:
self.cond_stage_model.reset_clip_layer() self.cond_stage_model.reset_clip_layer()
model_management.load_model_gpu(self.patcher) self.load_model()
cond, pooled = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
if return_pooled: if return_pooled:
return cond, pooled return cond, pooled
@ -590,11 +590,9 @@ class CLIP:
def get_sd(self): def get_sd(self):
return self.cond_stage_model.state_dict() return self.cond_stage_model.state_dict()
def patch_model(self): def load_model(self):
self.patcher.patch_model() model_management.load_model_gpu(self.patcher)
return self.patcher
def unpatch_model(self):
self.patcher.unpatch_model()
def get_key_patches(self): def get_key_patches(self):
return self.patcher.get_key_patches() return self.patcher.get_key_patches()
@ -922,8 +920,8 @@ def load_controlnet(ckpt_path, model=None):
if pth: if pth:
if 'difference' in controlnet_data: if 'difference' in controlnet_data:
if model is not None: if model is not None:
m = model.patch_model() model_management.load_models_gpu([model])
model_sd = m.state_dict() model_sd = model.model_state_dict()
for x in controlnet_data: for x in controlnet_data:
c_m = "control_model." c_m = "control_model."
if x.startswith(c_m): if x.startswith(c_m):
@ -931,7 +929,6 @@ def load_controlnet(ckpt_path, model=None):
if sd_key in model_sd: if sd_key in model_sd:
cd = controlnet_data[x] cd = controlnet_data[x]
cd += model_sd[sd_key].type(cd.dtype).to(cd.device) cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
model.unpatch_model()
else: else:
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
@ -1279,14 +1276,6 @@ def load_unet(unet_path): #load unet in diffusers format
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
def save_checkpoint(output_path, model, clip, vae, metadata=None): def save_checkpoint(output_path, model, clip, vae, metadata=None):
try: model_management.load_models_gpu([model, clip.load_model()])
model.patch_model() sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
clip.patch_model() utils.save_torch_file(sd, output_path, metadata=metadata)
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
utils.save_torch_file(sd, output_path, metadata=metadata)
model.unpatch_model()
clip.unpatch_model()
except Exception as e:
model.unpatch_model()
clip.unpatch_model()
raise e