diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 4843e6a4..e09dd381 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -1,7 +1,9 @@
 import comfy.supported_models
 import comfy.supported_models_base
+import comfy.utils
 import math
 import logging
+import torch
 
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
@@ -431,3 +433,38 @@ def model_config_from_diffusers_unet(state_dict):
     if unet_config is not None:
         return model_config_from_unet_config(unet_config)
     return None
+
+def convert_diffusers_mmdit(state_dict, output_prefix=""):
+    depth = count_blocks(state_dict, 'transformer_blocks.{}.')
+    if depth > 0:
+        out_sd = {}
+        sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth}, output_prefix=output_prefix)
+        for k in sd_map:
+            weight = state_dict.get(k, None)
+            if weight is not None:
+                t = sd_map[k]
+
+                if not isinstance(t, str):
+                    if len(t) > 2:
+                        fun = t[2]
+                    else:
+                        fun = lambda a: a
+                    offset = t[1]
+                    if offset is not None:
+                        old_weight = out_sd.get(t[0], None)
+                        if old_weight is None:
+                            old_weight = torch.empty_like(weight)
+                            old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
+
+                        w = old_weight.narrow(offset[0], offset[1], offset[2])
+                    else:
+                        old_weight = weight
+                        w = weight
+                    w[:] = fun(weight)
+                    t = t[0]
+                    out_sd[t] = old_weight
+                else:
+                    out_sd[t] = weight
+                state_dict.pop(k)
+
+        return out_sd
diff --git a/comfy/sd.py b/comfy/sd.py
index cfbf8fa4..178f52e8 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -568,7 +568,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
     unet_dtype = model_management.unet_dtype(model_params=parameters)
     load_device = model_management.get_torch_device()
 
-    if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
+    if 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
+        new_sd = model_detection.convert_diffusers_mmdit(sd, "")
+        if new_sd is None:
+            return None
+        model_config = model_detection.model_config_from_unet(new_sd, "")
+        if model_config is None:
+            return None
+    elif "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
         model_config = model_detection.model_config_from_unet(sd, "")
         if model_config is None:
             return None
diff --git a/comfy/utils.py b/comfy/utils.py
index 877f14b0..f0d67500 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -249,6 +249,11 @@ def unet_to_diffusers(unet_config):
 
     return diffusers_unet_map
 
+def swap_scale_shift(weight):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
 MMDIT_MAP_BASIC = {
     ("context_embedder.bias", "context_embedder.bias"),
     ("context_embedder.weight", "context_embedder.weight"),
@@ -263,8 +268,8 @@
     ("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
     ("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
     ("pos_embed", "pos_embed.pos_embed"),
-    ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias"),
-    ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight"),
+    ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
+    ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
     ("final_layer.linear.bias", "proj_out.bias"),
     ("final_layer.linear.weight", "proj_out.weight"),
 }
@@ -313,8 +318,15 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
         for k in MMDIT_MAP_BLOCK:
             key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
 
-    for k in MMDIT_MAP_BASIC:
-        key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+    map_basic = MMDIT_MAP_BASIC.copy()
+    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift))
+    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift))
+
+    for k in map_basic:
+        if len(k) > 2:
+            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
+        else:
+            key_map[k[1]] = "{}{}".format(output_prefix, k[0])
 
     return key_map
diff --git a/comfy_extras/nodes_model_merging_model_specific.py b/comfy_extras/nodes_model_merging_model_specific.py
index f2d008d8..74ab6814 100644
--- a/comfy_extras/nodes_model_merging_model_specific.py
+++ b/comfy_extras/nodes_model_merging_model_specific.py
@@ -52,9 +52,32 @@ class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
 
         return {"required": arg_dict}
 
+class ModelMergeSD3(comfy_extras.nodes_model_merging.ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = { "model1": ("MODEL",),
+                     "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["pos_embed."] = argument
+        arg_dict["x_embedder."] = argument
+        arg_dict["context_embedder."] = argument
+        arg_dict["y_embedder."] = argument
+        arg_dict["t_embedder."] = argument
+
+        for i in range(38):
+            arg_dict["joint_blocks.{}.".format(i)] = argument
+
+        arg_dict["final_layer."] = argument
+
+        return {"required": arg_dict}
 
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSD1": ModelMergeSD1,
     "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
     "ModelMergeSDXL": ModelMergeSDXL,
+    "ModelMergeSD3": ModelMergeSD3,
 }
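
With this change, a diffusers-format SD3 checkpoint is recognized in load_unet_state_dict by the presence of the 'transformer_blocks.0.attn.add_q_proj.weight' key, converted to the native MMDiT key layout, and then run through the usual model-config detection. A minimal sketch of driving the conversion directly (the checkpoint filename is hypothetical):

import comfy.utils
import comfy.model_detection

# Hypothetical path to a diffusers-format SD3 transformer checkpoint.
sd = comfy.utils.load_torch_file("sd3_diffusers_transformer.safetensors")

# convert_diffusers_mmdit returns None when no transformer_blocks.{}. keys
# are found, i.e. the input is not a diffusers MMDiT state dict.
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
if new_sd is None:
    raise RuntimeError("not a diffusers MMDiT state dict")

model_config = comfy.model_detection.model_config_from_unet(new_sd, "")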
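The value side of the key map built by mmdit_to_diffusers is either a plain native key (a string) or a tuple (native_key, offset, fun), where offset = (dim, start, length) is fed to narrow() and the repeat([3] + ...) in convert_diffusers_mmdit allocates a destination buffer three times the incoming tensor's size along dim 0. That convention appears to exist so the separate q/k/v projections of the diffusers layout can be packed into one fused qkv tensor; the per-block entries that supply those offsets live outside this diff (MMDIT_MAP_BLOCK is only visible here as context). A self-contained sketch of the packing, with hypothetical keys and a toy size:

import torch

hidden = 8
out_sd = {}
diffusers_sd = {
    "attn.to_q.weight": torch.randn(hidden, hidden),
    "attn.to_k.weight": torch.randn(hidden, hidden),
    "attn.to_v.weight": torch.randn(hidden, hidden),
}
# offset = (dim, start, length), mirroring the tuples consumed above.
offsets = {
    "attn.to_q.weight": (0, 0 * hidden, hidden),
    "attn.to_k.weight": (0, 1 * hidden, hidden),
    "attn.to_v.weight": (0, 2 * hidden, hidden),
}

for k, weight in diffusers_sd.items():
    old_weight = out_sd.get("attn.qkv.weight", None)
    if old_weight is None:
        # Allocate a buffer 3x the size along dim 0, like the repeat() above.
        old_weight = torch.empty_like(weight).repeat(3, 1)
    dim, start, length = offsets[k]
    # Write this projection into its slice of the fused tensor.
    old_weight.narrow(dim, start, length)[:] = weight
    out_sd["attn.qkv.weight"] = old_weight

# The middle third of the fused tensor now holds the k projection.
assert torch.equal(out_sd["attn.qkv.weight"][hidden:2 * hidden],
                   diffusers_sd["attn.to_k.weight"])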
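swap_scale_shift compensates for the two formats disagreeing on the order of the concatenated shift/scale halves of the adaLN modulation parameters; the map applies it to the final layer's modulation and, via the map_basic additions, to the last block's norm1_context, whose native counterpart lives under joint_blocks.{depth - 1}.context_block. A toy demonstration of the reordering:

import torch

def swap_scale_shift(weight):
    # Identical to the helper added in comfy/utils.py: trade the two
    # halves of the fused parameter along dim 0.
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

# Toy fused parameter: first half zeros, second half ones.
w = torch.cat([torch.zeros(4, 8), torch.ones(4, 8)], dim=0)
out = swap_scale_shift(w)
assert torch.equal(out[:4], torch.ones(4, 8))   # halves traded places
assert torch.equal(out[4:], torch.zeros(4, 8))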