From f81b192944687e223d46fb2fb29ea609a0963e6c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Jul 2023 23:01:28 -0400 Subject: [PATCH] Add logit scale parameter so it's present when saving the checkpoint. --- comfy/sdxl_clip.py | 1 + comfy/supported_models.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index c768b9f9..f676d8c8 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -8,6 +8,7 @@ class SDXLClipG(sd1_clip.SD1ClipModel): super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path) self.empty_tokens = [[49406] + [49407] + [0] * 75] self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280)) + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) self.layer_norm_hidden_state = False if layer == "last": pass diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6b17b089..38a53ca7 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -118,6 +118,7 @@ class SDXLRefiner(supported_models_base.BASE): state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection" + keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace) return state_dict @@ -153,6 +154,7 @@ class SDXL(supported_models_base.BASE): replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model" state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection" + keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix) state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)