mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 14:09:36 +00:00
Fix potential issues with patching models when saving checkpoints.
This commit is contained in:
parent
1498f1a342
commit
c28db1f315
25
comfy/sd.py
25
comfy/sd.py
@ -574,7 +574,7 @@ class CLIP:
|
|||||||
else:
|
else:
|
||||||
self.cond_stage_model.reset_clip_layer()
|
self.cond_stage_model.reset_clip_layer()
|
||||||
|
|
||||||
model_management.load_model_gpu(self.patcher)
|
self.load_model()
|
||||||
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
|
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
|
||||||
if return_pooled:
|
if return_pooled:
|
||||||
return cond, pooled
|
return cond, pooled
|
||||||
@ -590,11 +590,9 @@ class CLIP:
|
|||||||
def get_sd(self):
|
def get_sd(self):
|
||||||
return self.cond_stage_model.state_dict()
|
return self.cond_stage_model.state_dict()
|
||||||
|
|
||||||
def patch_model(self):
|
def load_model(self):
|
||||||
self.patcher.patch_model()
|
model_management.load_model_gpu(self.patcher)
|
||||||
|
return self.patcher
|
||||||
def unpatch_model(self):
|
|
||||||
self.patcher.unpatch_model()
|
|
||||||
|
|
||||||
def get_key_patches(self):
|
def get_key_patches(self):
|
||||||
return self.patcher.get_key_patches()
|
return self.patcher.get_key_patches()
|
||||||
@ -922,8 +920,8 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if pth:
|
if pth:
|
||||||
if 'difference' in controlnet_data:
|
if 'difference' in controlnet_data:
|
||||||
if model is not None:
|
if model is not None:
|
||||||
m = model.patch_model()
|
model_management.load_models_gpu([model])
|
||||||
model_sd = m.state_dict()
|
model_sd = model.model_state_dict()
|
||||||
for x in controlnet_data:
|
for x in controlnet_data:
|
||||||
c_m = "control_model."
|
c_m = "control_model."
|
||||||
if x.startswith(c_m):
|
if x.startswith(c_m):
|
||||||
@ -931,7 +929,6 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if sd_key in model_sd:
|
if sd_key in model_sd:
|
||||||
cd = controlnet_data[x]
|
cd = controlnet_data[x]
|
||||||
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
||||||
model.unpatch_model()
|
|
||||||
else:
|
else:
|
||||||
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
||||||
|
|
||||||
@ -1279,14 +1276,6 @@ def load_unet(unet_path): #load unet in diffusers format
|
|||||||
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
||||||
|
|
||||||
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
||||||
try:
|
model_management.load_models_gpu([model, clip.load_model()])
|
||||||
model.patch_model()
|
|
||||||
clip.patch_model()
|
|
||||||
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
|
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
|
||||||
utils.save_torch_file(sd, output_path, metadata=metadata)
|
utils.save_torch_file(sd, output_path, metadata=metadata)
|
||||||
model.unpatch_model()
|
|
||||||
clip.unpatch_model()
|
|
||||||
except Exception as e:
|
|
||||||
model.unpatch_model()
|
|
||||||
clip.unpatch_model()
|
|
||||||
raise e
|
|
||||||
|
Loading…
Reference in New Issue
Block a user