diff --git a/comfy/sd.py b/comfy/sd.py
index ac13d8bc9..bf935a5f9 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -99,7 +99,7 @@ LORA_CLIP_MAP = {
     "self_attn.out_proj": "self_attn_out_proj",
 }
 
-LORA_UNET_MAP = {
+LORA_UNET_MAP_ATTENTIONS = {
     "proj_in": "proj_in",
     "proj_out": "proj_out",
     "transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q",
@@ -114,6 +114,12 @@ LORA_UNET_MAP = {
     "transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2",
 }
 
+LORA_UNET_MAP_RESNET = {
+    "in_layers.2": "resnets_{}_conv1",
+    "emb_layers.1": "resnets_{}_time_emb_proj",
+    "out_layers.3": "resnets_{}_conv2",
+    "skip_connection": "resnets_{}_conv_shortcut"
+}
 
 def load_lora(path, to_load):
     lora = load_torch_file(path)
@@ -143,27 +149,27 @@ def model_lora_keys(model, key_map={}):
     for b in range(12):
         tk = "model.diffusion_model.input_blocks.{}.1".format(b)
         up_counter = 0
-        for c in LORA_UNET_MAP:
+        for c in LORA_UNET_MAP_ATTENTIONS:
             k = "{}.{}.weight".format(tk, c)
             if k in sdk:
-                lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c])
+                lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP_ATTENTIONS[c])
                 key_map[lora_key] = k
                 up_counter += 1
         if up_counter >= 4:
             counter += 1
-    for c in LORA_UNET_MAP:
+    for c in LORA_UNET_MAP_ATTENTIONS:
         k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
         if k in sdk:
-            lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c])
+            lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
             key_map[lora_key] = k
     counter = 3
     for b in range(12):
         tk = "model.diffusion_model.output_blocks.{}.1".format(b)
         up_counter = 0
-        for c in LORA_UNET_MAP:
+        for c in LORA_UNET_MAP_ATTENTIONS:
             k = "{}.{}.weight".format(tk, c)
             if k in sdk:
-                lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c])
+                lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP_ATTENTIONS[c])
                 key_map[lora_key] = k
                 up_counter += 1
         if up_counter >= 4:
@@ -177,6 +183,61 @@ def model_lora_keys(model, key_map={}):
             lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
             key_map[lora_key] = k
 
+
+    #Locon stuff
+    ds_counter = 0
+    counter = 0
+    for b in range(12):
+        tk = "model.diffusion_model.input_blocks.{}.0".format(b)
+        key_in = False
+        for c in LORA_UNET_MAP_RESNET:
+            k = "{}.{}.weight".format(tk, c)
+            if k in sdk:
+                lora_key = "lora_unet_down_blocks_{}_{}".format(counter // 2, LORA_UNET_MAP_RESNET[c].format(counter % 2))
+                key_map[lora_key] = k
+                key_in = True
+        for bb in range(3):
+            k = "{}.{}.op.weight".format(tk[:-2], bb)
+            if k in sdk:
+                lora_key = "lora_unet_down_blocks_{}_downsamplers_0_conv".format(ds_counter)
+                key_map[lora_key] = k
+                ds_counter += 1
+        if key_in:
+            counter += 1
+
+    counter = 0
+    for b in range(3):
+        tk = "model.diffusion_model.middle_block.{}".format(b)
+        key_in = False
+        for c in LORA_UNET_MAP_RESNET:
+            k = "{}.{}.weight".format(tk, c)
+            if k in sdk:
+                lora_key = "lora_unet_mid_block_{}".format(LORA_UNET_MAP_RESNET[c].format(counter))
+                key_map[lora_key] = k
+                key_in = True
+        if key_in:
+            counter += 1
+
+    counter = 0
+    us_counter = 0
+    for b in range(12):
+        tk = "model.diffusion_model.output_blocks.{}.0".format(b)
+        key_in = False
+        for c in LORA_UNET_MAP_RESNET:
+            k = "{}.{}.weight".format(tk, c)
+            if k in sdk:
+                lora_key = "lora_unet_up_blocks_{}_{}".format(counter // 3, LORA_UNET_MAP_RESNET[c].format(counter % 3))
+                key_map[lora_key] = k
+                key_in = True
+        for bb in range(3):
+            k = "{}.{}.conv.weight".format(tk[:-2], bb)
+            if k in sdk:
+                lora_key = "lora_unet_up_blocks_{}_upsamplers_0_conv".format(us_counter)
+                key_map[lora_key] = k
+                us_counter += 1
+        if key_in:
+            counter += 1
+
     return key_map
 
 class ModelPatcher:
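
For reference, a minimal sketch (not part of the patch) of how the new LORA_UNET_MAP_RESNET templates are used: each value is a format string whose "{}" placeholder receives the resnet index within a block, and the result is appended to a block prefix to form the LoCon-style key, mirroring the down-blocks loop in the diff above. The counter value below is a hypothetical example.

# Minimal sketch, assuming the mapping from the diff above.
# counter = 2 is a hypothetical example: the third resnet seen in the down path.
LORA_UNET_MAP_RESNET = {
    "in_layers.2": "resnets_{}_conv1",
    "emb_layers.1": "resnets_{}_time_emb_proj",
    "out_layers.3": "resnets_{}_conv2",
    "skip_connection": "resnets_{}_conv_shortcut"
}

counter = 2
suffix = LORA_UNET_MAP_RESNET["in_layers.2"].format(counter % 2)
lora_key = "lora_unet_down_blocks_{}_{}".format(counter // 2, suffix)
print(lora_key)  # lora_unet_down_blocks_1_resnets_0_conv1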