diff --git a/comfy/lora.py b/comfy/lora.py index 096285bba..37254b03f 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -29,6 +29,7 @@ def load_lora(lora, to_load): regular_lora = "{}.lora_up.weight".format(x) diffusers_lora = "{}_lora.up.weight".format(x) + diffusers2_lora = "{}.lora_B.weight".format(x) transformers_lora = "{}.lora_linear_layer.up.weight".format(x) A_name = None @@ -40,6 +41,10 @@ def load_lora(lora, to_load): A_name = diffusers_lora B_name = "{}_lora.down.weight".format(x) mid_name = None + elif diffusers2_lora in lora.keys(): + A_name = diffusers2_lora + B_name = "{}.lora_A.weight".format(x) + mid_name = None elif transformers_lora in lora.keys(): A_name = transformers_lora B_name ="{}.lora_linear_layer.down.weight".format(x) @@ -164,6 +169,7 @@ def load_lora(lora, to_load): for x in lora.keys(): if x not in loaded_keys: logging.warning("lora key not loaded: {}".format(x)) + return patch_dict def model_lora_keys_clip(model, key_map={}): @@ -217,7 +223,8 @@ def model_lora_keys_clip(model, key_map={}): return key_map def model_lora_keys_unet(model, key_map={}): - sdk = model.state_dict().keys() + sd = model.state_dict() + sdk = sd.keys() for k in sdk: if k.startswith("diffusion_model.") and k.endswith(".weight"): @@ -238,4 +245,17 @@ def model_lora_keys_unet(model, key_map={}): if diffusers_lora_key.endswith(".to_out.0"): diffusers_lora_key = diffusers_lora_key[:-2] key_map[diffusers_lora_key] = unet_key + + if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3 + for i in range(model.model_config.unet_config.get("depth", 0)): + k = "transformer.transformer_blocks.{}.attn.".format(i) + qkv = "diffusion_model.joint_blocks.{}.x_block.attn.qkv.weight".format(i) + proj = "diffusion_model.joint_blocks.{}.x_block.attn.proj.weight".format(i) + if qkv in sd: + offset = sd[qkv].shape[0] // 3 + key_map["{}to_q".format(k)] = (qkv, (0, 0, offset)) + key_map["{}to_k".format(k)] = (qkv, (0, offset, offset)) + key_map["{}to_v".format(k)] = (qkv, (0, offset * 2, offset)) + key_map["{}to_out.0".format(k)] = proj + return key_map