Support SD3 diffusers lora.

This commit is contained in:
comfyanonymous 2024-06-13 18:26:01 -04:00
parent 37a08a41b3
commit ac151ac169

View File

@ -29,6 +29,7 @@ def load_lora(lora, to_load):
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
diffusers2_lora = "{}.lora_B.weight".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
@ -40,6 +41,10 @@ def load_lora(lora, to_load):
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
elif diffusers2_lora in lora.keys():
A_name = diffusers2_lora
B_name = "{}.lora_A.weight".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x)
@ -164,6 +169,7 @@ def load_lora(lora, to_load):
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
return patch_dict
def model_lora_keys_clip(model, key_map={}):
@ -217,7 +223,8 @@ def model_lora_keys_clip(model, key_map={}):
return key_map
def model_lora_keys_unet(model, key_map={}):
sdk = model.state_dict().keys()
sd = model.state_dict()
sdk = sd.keys()
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
@ -238,4 +245,17 @@ def model_lora_keys_unet(model, key_map={}):
if diffusers_lora_key.endswith(".to_out.0"):
diffusers_lora_key = diffusers_lora_key[:-2]
key_map[diffusers_lora_key] = unet_key
if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
for i in range(model.model_config.unet_config.get("depth", 0)):
k = "transformer.transformer_blocks.{}.attn.".format(i)
qkv = "diffusion_model.joint_blocks.{}.x_block.attn.qkv.weight".format(i)
proj = "diffusion_model.joint_blocks.{}.x_block.attn.proj.weight".format(i)
if qkv in sd:
offset = sd[qkv].shape[0] // 3
key_map["{}to_q".format(k)] = (qkv, (0, 0, offset))
key_map["{}to_k".format(k)] = (qkv, (0, offset, offset))
key_map["{}to_v".format(k)] = (qkv, (0, offset * 2, offset))
key_map["{}to_out.0".format(k)] = proj
return key_map