Small refactor.

This commit is contained in:
comfyanonymous 2023-06-06 03:25:49 -04:00
parent a3a713b6c5
commit 0e425603fb
2 changed files with 17 additions and 16 deletions

View File

@ -31,17 +31,6 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
if ids.dtype == torch.float32:
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
keys_to_replace = {
"cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
"cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
"cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
"cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
}
for x in keys_to_replace:
if x in sd:
sd[keys_to_replace[x]] = sd.pop(x)
sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)
for x in load_state_dict_to:
@ -1073,13 +1062,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
"legacy": False
}
if len(sd['model.diffusion_model.input_blocks.1.1.proj_in.weight'].shape) == 2:
if len(sd['model.diffusion_model.input_blocks.4.1.proj_in.weight'].shape) == 2:
unet_config['use_linear_in_transformer'] = True
unet_config["use_fp16"] = fp16
unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0]
unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
@ -1097,10 +1086,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
else:
sd_config["conditioning_key"] = "crossattn"
if unet_config["context_dim"] == 1024:
unet_config["num_head_channels"] = 64 #SD2.x
else:
if unet_config["context_dim"] == 768:
unet_config["num_heads"] = 8 #SD1.x
else:
unet_config["num_head_channels"] = 64 #SD2.x
unclip = 'model.diffusion_model.label_emb.0.0.weight'
if unclip in sd_keys:

View File

@ -24,6 +24,18 @@ def load_torch_file(ckpt, safe_load=False):
return sd
def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}.positional_embedding": "{}.embeddings.position_embedding.weight",
"{}.token_embedding.weight": "{}.embeddings.token_embedding.weight",
"{}.ln_final.weight": "{}.final_layer_norm.weight",
"{}.ln_final.bias": "{}.final_layer_norm.bias",
}
for k in keys_to_replace:
x = k.format(prefix_from)
if x in sd:
sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)
resblock_to_replace = {
"ln_1": "layer_norm1",
"ln_2": "layer_norm2",