diff --git a/comfy/sd.py b/comfy/sd.py index 8d8c8ee3f..461c234db 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -574,7 +574,7 @@ class CLIP: else: self.cond_stage_model.reset_clip_layer() - model_management.load_model_gpu(self.patcher) + self.load_model() cond, pooled = self.cond_stage_model.encode_token_weights(tokens) if return_pooled: return cond, pooled @@ -590,11 +590,9 @@ class CLIP: def get_sd(self): return self.cond_stage_model.state_dict() - def patch_model(self): - self.patcher.patch_model() - - def unpatch_model(self): - self.patcher.unpatch_model() + def load_model(self): + model_management.load_model_gpu(self.patcher) + return self.patcher def get_key_patches(self): return self.patcher.get_key_patches() @@ -922,8 +920,8 @@ def load_controlnet(ckpt_path, model=None): if pth: if 'difference' in controlnet_data: if model is not None: - m = model.patch_model() - model_sd = m.state_dict() + model_management.load_models_gpu([model]) + model_sd = model.model_state_dict() for x in controlnet_data: c_m = "control_model." if x.startswith(c_m): @@ -931,7 +929,6 @@ def load_controlnet(ckpt_path, model=None): if sd_key in model_sd: cd = controlnet_data[x] cd += model_sd[sd_key].type(cd.dtype).to(cd.device) - model.unpatch_model() else: print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") @@ -1279,14 +1276,6 @@ def load_unet(unet_path): #load unet in diffusers format return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) def save_checkpoint(output_path, model, clip, vae, metadata=None): - try: - model.patch_model() - clip.patch_model() - sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) - utils.save_torch_file(sd, output_path, metadata=metadata) - model.unpatch_model() - clip.unpatch_model() - except Exception as e: - model.unpatch_model() - clip.unpatch_model() - raise e + model_management.load_models_gpu([model, clip.load_model()]) + sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) + utils.save_torch_file(sd, output_path, metadata=metadata)