From 1e0fcc9a658dac305660c982a6bc0ea9b5657cf7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 27 Feb 2024 02:07:40 -0500 Subject: [PATCH] Make XL checkpoints save in a more standard format. --- comfy/supported_models.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index dbc3cf26e..5d57a31a1 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -190,12 +190,16 @@ class SDXL(supported_models_base.BASE): replace_prefix = {} keys_to_replace = {} state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") - if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g: - state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") for k in state_dict: if k.startswith("clip_l"): state_dict_g[k] = state_dict[k] + state_dict_g["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1)) + pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"] + for p in pop_keys: + if p in state_dict_g: + state_dict_g.pop(p) + replace_prefix["clip_g"] = "conditioner.embedders.1.model" replace_prefix["clip_l"] = "conditioner.embedders.0" state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)