diff --git a/comfy/sd.py b/comfy/sd.py index 25454f836..d0832ad3a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -28,36 +28,6 @@ import comfy.t2i_adapter.adapter import comfy.supported_models_base import comfy.taesd.taesd -def load_model_weights(model, sd): - m, u = model.load_state_dict(sd, strict=False) - m = set(m) - unexpected_keys = set(u) - - k = list(sd.keys()) - for x in k: - if x not in unexpected_keys: - w = sd.pop(x) - del w - if len(m) > 0: - logging.warning("missing {}".format(m)) - return model - -def load_clip_weights(model, sd): - k = list(sd.keys()) - for x in k: - if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."): - y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.") - sd[y] = sd.pop(x) - - if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd: - ids = sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] - if ids.dtype == torch.float32: - sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() - - sd = comfy.utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.") - return load_model_weights(model, sd) - - def load_lora_for_models(model, clip, lora, strength_model, strength_clip): key_map = {} if model is not None: