From 0e425603fb8ba12f1e7d09a1f58127347a94de98 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 6 Jun 2023 03:25:49 -0400
Subject: [PATCH] Small refactor.

---
 comfy/sd.py    | 21 +++++----------------
 comfy/utils.py | 12 ++++++++++++
 2 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index 336fee4a..04eaaa9f 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -31,17 +31,6 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
         if ids.dtype == torch.float32:
             sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
 
-    keys_to_replace = {
-        "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
-        "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
-        "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
-        "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
-    }
-
-    for x in keys_to_replace:
-        if x in sd:
-            sd[keys_to_replace[x]] = sd.pop(x)
-
     sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)
 
     for x in load_state_dict_to:
@@ -1073,13 +1062,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
             "legacy": False
         }
 
-        if len(sd['model.diffusion_model.input_blocks.1.1.proj_in.weight'].shape) == 2:
+        if len(sd['model.diffusion_model.input_blocks.4.1.proj_in.weight'].shape) == 2:
             unet_config['use_linear_in_transformer'] = True
 
         unet_config["use_fp16"] = fp16
         unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0]
         unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
-        unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
+        unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
 
         sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
         model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
@@ -1097,10 +1086,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         else:
             sd_config["conditioning_key"] = "crossattn"
 
-        if unet_config["context_dim"] == 1024:
-            unet_config["num_head_channels"] = 64 #SD2.x
-        else:
+        if unet_config["context_dim"] == 768:
             unet_config["num_heads"] = 8 #SD1.x
+        else:
+            unet_config["num_head_channels"] = 64 #SD2.x
 
         unclip = 'model.diffusion_model.label_emb.0.0.weight'
         if unclip in sd_keys:
diff --git a/comfy/utils.py b/comfy/utils.py
index 291c62e4..585ebda5 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -24,6 +24,18 @@ def load_torch_file(ckpt, safe_load=False):
     return sd
 
 def transformers_convert(sd, prefix_from, prefix_to, number):
+    keys_to_replace = {
+        "{}.positional_embedding": "{}.embeddings.position_embedding.weight",
+        "{}.token_embedding.weight": "{}.embeddings.token_embedding.weight",
+        "{}.ln_final.weight": "{}.final_layer_norm.weight",
+        "{}.ln_final.bias": "{}.final_layer_norm.bias",
+    }
+
+    for k in keys_to_replace:
+        x = k.format(prefix_from)
+        if x in sd:
+            sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)
+
     resblock_to_replace = {
         "ln_1": "layer_norm1",
         "ln_2": "layer_norm2",