diff --git a/comfy/sd.py b/comfy/sd.py index e5baddcac..0b55b8b14 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -26,12 +26,7 @@ def load_torch_file(ckpt): sd = pl_sd return sd -def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]): - print(f"Loading model from {ckpt}") - - sd = load_torch_file(ckpt) - model = instantiate_from_config(config.model) - +def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): m, u = model.load_state_dict(sd, strict=False) k = list(sd.keys()) @@ -654,5 +649,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e w.cond_stage_model = clip.cond_stage_model load_state_dict_to = [w] - model = load_model_from_config(config, ckpt_path, verbose=False, load_state_dict_to=load_state_dict_to) + model = instantiate_from_config(config.model) + sd = load_torch_file(ckpt_path) + model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) return (ModelPatcher(model), clip, vae)