Properly support SDXL diffusers unet with UNETLoader node.

This commit is contained in:
comfyanonymous 2023-07-21 14:38:56 -04:00
parent 0115018695
commit 58b2364f58
2 changed files with 23 additions and 16 deletions

View File

@ -1139,12 +1139,14 @@ def load_unet(unet_path): #load unet in diffusers format
fp16 = model_management.should_use_fp16(model_params=parameters)
match = {}
match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
match["context_dim"] = sd["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
match["model_channels"] = sd["conv_in.weight"].shape[0]
match["in_channels"] = sd["conv_in.weight"].shape[1]
match["adm_in_channels"] = None
if "class_embedding.linear_1.weight" in sd:
match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
elif "add_embedding.linear_1.weight" in sd:
match["adm_in_channels"] = sd["add_embedding.linear_1.weight"].shape[1]
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
@ -1198,6 +1200,7 @@ def load_unet(unet_path): #load unet in diffusers format
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
print("ERROR UNSUPPORTED UNET", unet_path)
def save_checkpoint(output_path, model, clip, vae, metadata=None):
try:

View File

@ -120,20 +120,24 @@ UNET_MAP_RESNET = {
}
UNET_MAP_BASIC = {
"label_emb.0.0.weight": "class_embedding.linear_1.weight",
"label_emb.0.0.bias": "class_embedding.linear_1.bias",
"label_emb.0.2.weight": "class_embedding.linear_2.weight",
"label_emb.0.2.bias": "class_embedding.linear_2.bias",
"input_blocks.0.0.weight": "conv_in.weight",
"input_blocks.0.0.bias": "conv_in.bias",
"out.0.weight": "conv_norm_out.weight",
"out.0.bias": "conv_norm_out.bias",
"out.2.weight": "conv_out.weight",
"out.2.bias": "conv_out.bias",
"time_embed.0.weight": "time_embedding.linear_1.weight",
"time_embed.0.bias": "time_embedding.linear_1.bias",
"time_embed.2.weight": "time_embedding.linear_2.weight",
"time_embed.2.bias": "time_embedding.linear_2.bias"
("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"),
("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"),
("out.2.bias", "conv_out.bias"),
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias")
}
def unet_to_diffusers(unet_config):
@ -208,7 +212,7 @@ def unet_to_diffusers(unet_config):
n += 1
for k in UNET_MAP_BASIC:
diffusers_unet_map[UNET_MAP_BASIC[k]] = k
diffusers_unet_map[k[1]] = k[0]
return diffusers_unet_map