From 09fdb2b26926ca52d7ffb4038d4655fc4b62db39 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 30 Oct 2024 04:24:00 -0400 Subject: [PATCH] Support SD3.5 medium diffusers format weights and loras. --- comfy/model_detection.py | 10 +++++----- comfy/utils.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8435de3e..5e534480 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -540,7 +540,11 @@ def model_config_from_diffusers_unet(state_dict): def convert_diffusers_mmdit(state_dict, output_prefix=""): out_sd = {} - if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux + if 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow + num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.') + num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.') + sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix) + elif 'single_transformer_blocks.0.attn.norm_q.weight' in state_dict: #Flux depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') hidden_size = state_dict["x_embedder.bias"].shape[0] @@ -549,10 +553,6 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') depth = state_dict["pos_embed.proj.weight"].shape[0] // 64 sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix) - elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow - num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.') - num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.') - sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix) else: return None diff --git a/comfy/utils.py b/comfy/utils.py index 06c81107..cc92e111 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -316,10 +316,18 @@ MMDIT_MAP_BLOCK = { ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"), ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"), ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"), + ("context_block.attn.ln_q.weight", "attn.norm_added_q.weight"), + ("context_block.attn.ln_k.weight", "attn.norm_added_k.weight"), ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"), ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"), ("x_block.attn.proj.bias", "attn.to_out.0.bias"), ("x_block.attn.proj.weight", "attn.to_out.0.weight"), + ("x_block.attn.ln_q.weight", "attn.norm_q.weight"), + ("x_block.attn.ln_k.weight", "attn.norm_k.weight"), + ("x_block.attn2.proj.bias", "attn2.to_out.0.bias"), + ("x_block.attn2.proj.weight", "attn2.to_out.0.weight"), + ("x_block.attn2.ln_q.weight", "attn2.norm_q.weight"), + ("x_block.attn2.ln_k.weight", "attn2.norm_k.weight"), ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"), ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"), ("x_block.mlp.fc2.bias", "ff.net.2.bias"), @@ -349,6 +357,12 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""): key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset)) key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset)) + k = "{}.attn2.".format(block_from) + qkv = "{}.x_block.attn2.qkv.{}".format(block_to, end) + key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset)) + key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset)) + key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset)) + for k in MMDIT_MAP_BLOCK: key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])