mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Small refactor.
This commit is contained in:
parent
a3a713b6c5
commit
0e425603fb
21
comfy/sd.py
21
comfy/sd.py
@ -31,17 +31,6 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
|
||||
if ids.dtype == torch.float32:
|
||||
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
||||
|
||||
keys_to_replace = {
|
||||
"cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
|
||||
"cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
|
||||
"cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
|
||||
"cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
|
||||
}
|
||||
|
||||
for x in keys_to_replace:
|
||||
if x in sd:
|
||||
sd[keys_to_replace[x]] = sd.pop(x)
|
||||
|
||||
sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)
|
||||
|
||||
for x in load_state_dict_to:
|
||||
@ -1073,13 +1062,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
"legacy": False
|
||||
}
|
||||
|
||||
if len(sd['model.diffusion_model.input_blocks.1.1.proj_in.weight'].shape) == 2:
|
||||
if len(sd['model.diffusion_model.input_blocks.4.1.proj_in.weight'].shape) == 2:
|
||||
unet_config['use_linear_in_transformer'] = True
|
||||
|
||||
unet_config["use_fp16"] = fp16
|
||||
unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0]
|
||||
unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
|
||||
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
|
||||
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
|
||||
|
||||
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
|
||||
model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
|
||||
@ -1097,10 +1086,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
else:
|
||||
sd_config["conditioning_key"] = "crossattn"
|
||||
|
||||
if unet_config["context_dim"] == 1024:
|
||||
unet_config["num_head_channels"] = 64 #SD2.x
|
||||
else:
|
||||
if unet_config["context_dim"] == 768:
|
||||
unet_config["num_heads"] = 8 #SD1.x
|
||||
else:
|
||||
unet_config["num_head_channels"] = 64 #SD2.x
|
||||
|
||||
unclip = 'model.diffusion_model.label_emb.0.0.weight'
|
||||
if unclip in sd_keys:
|
||||
|
@ -24,6 +24,18 @@ def load_torch_file(ckpt, safe_load=False):
|
||||
return sd
|
||||
|
||||
def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||
keys_to_replace = {
|
||||
"{}.positional_embedding": "{}.embeddings.position_embedding.weight",
|
||||
"{}.token_embedding.weight": "{}.embeddings.token_embedding.weight",
|
||||
"{}.ln_final.weight": "{}.final_layer_norm.weight",
|
||||
"{}.ln_final.bias": "{}.final_layer_norm.bias",
|
||||
}
|
||||
|
||||
for k in keys_to_replace:
|
||||
x = k.format(prefix_from)
|
||||
if x in sd:
|
||||
sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)
|
||||
|
||||
resblock_to_replace = {
|
||||
"ln_1": "layer_norm1",
|
||||
"ln_2": "layer_norm2",
|
||||
|
Loading…
Reference in New Issue
Block a user