From d1cdf51e1b6929686e391d9245c9c040714739d9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 1 Oct 2024 07:08:41 -0400 Subject: [PATCH] Refactor some of the TE detection code. --- comfy/sd.py | 47 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 99859d24..c6166124 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -406,6 +406,32 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) + +class TEModel(Enum): + CLIP_L = 1 + CLIP_H = 2 + CLIP_G = 3 + T5_XXL = 4 + T5_XL = 5 + T5_BASE = 6 + +def detect_te_model(sd): + if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: + return TEModel.CLIP_G + if "text_model.encoder.layers.22.mlp.fc1.weight" in sd: + return TEModel.CLIP_H + if "text_model.encoder.layers.0.mlp.fc1.weight" in sd: + return TEModel.CLIP_L + if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: + weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] + if weight.shape[-1] == 4096: + return TEModel.T5_XXL + elif weight.shape[-1] == 2048: + return TEModel.T5_XL + if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd: + return TEModel.T5_BASE + return None + def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): clip_data = state_dicts class EmptyClass: @@ -421,30 +447,29 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target = EmptyClass() clip_target.params = {} if len(clip_data) == 1: - if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]: + te_model = detect_te_model(clip_data[0]) + if te_model == TEModel.CLIP_G: if clip_type == CLIPType.STABLE_CASCADE: clip_target.clip = sdxl_clip.StableCascadeClipModel clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer else: clip_target.clip = sdxl_clip.SDXLRefinerClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer - elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: + elif te_model == TEModel.CLIP_H: clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer - elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]: + elif te_model == TEModel.T5_XXL: weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] dtype_t5 = weight.dtype - if weight.shape[-1] == 4096: - clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5) - clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer - elif weight.shape[-1] == 2048: - clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model - clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer - elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]: + clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5) + clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif te_model == TEModel.T5_XL: + clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model + clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer + elif te_model == TEModel.T5_BASE: clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer else: - w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None) clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer elif len(clip_data) == 2: