mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 05:57:20 +00:00
Support SD3 diffusers lora.
This commit is contained in:
parent
37a08a41b3
commit
ac151ac169
@ -29,6 +29,7 @@ def load_lora(lora, to_load):
|
|||||||
|
|
||||||
regular_lora = "{}.lora_up.weight".format(x)
|
regular_lora = "{}.lora_up.weight".format(x)
|
||||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||||
|
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||||
A_name = None
|
A_name = None
|
||||||
|
|
||||||
@ -40,6 +41,10 @@ def load_lora(lora, to_load):
|
|||||||
A_name = diffusers_lora
|
A_name = diffusers_lora
|
||||||
B_name = "{}_lora.down.weight".format(x)
|
B_name = "{}_lora.down.weight".format(x)
|
||||||
mid_name = None
|
mid_name = None
|
||||||
|
elif diffusers2_lora in lora.keys():
|
||||||
|
A_name = diffusers2_lora
|
||||||
|
B_name = "{}.lora_A.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
elif transformers_lora in lora.keys():
|
elif transformers_lora in lora.keys():
|
||||||
A_name = transformers_lora
|
A_name = transformers_lora
|
||||||
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
||||||
@ -164,6 +169,7 @@ def load_lora(lora, to_load):
|
|||||||
for x in lora.keys():
|
for x in lora.keys():
|
||||||
if x not in loaded_keys:
|
if x not in loaded_keys:
|
||||||
logging.warning("lora key not loaded: {}".format(x))
|
logging.warning("lora key not loaded: {}".format(x))
|
||||||
|
|
||||||
return patch_dict
|
return patch_dict
|
||||||
|
|
||||||
def model_lora_keys_clip(model, key_map={}):
|
def model_lora_keys_clip(model, key_map={}):
|
||||||
@ -217,7 +223,8 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
def model_lora_keys_unet(model, key_map={}):
|
def model_lora_keys_unet(model, key_map={}):
|
||||||
sdk = model.state_dict().keys()
|
sd = model.state_dict()
|
||||||
|
sdk = sd.keys()
|
||||||
|
|
||||||
for k in sdk:
|
for k in sdk:
|
||||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||||
@ -238,4 +245,17 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
if diffusers_lora_key.endswith(".to_out.0"):
|
if diffusers_lora_key.endswith(".to_out.0"):
|
||||||
diffusers_lora_key = diffusers_lora_key[:-2]
|
diffusers_lora_key = diffusers_lora_key[:-2]
|
||||||
key_map[diffusers_lora_key] = unet_key
|
key_map[diffusers_lora_key] = unet_key
|
||||||
|
|
||||||
|
if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
|
||||||
|
for i in range(model.model_config.unet_config.get("depth", 0)):
|
||||||
|
k = "transformer.transformer_blocks.{}.attn.".format(i)
|
||||||
|
qkv = "diffusion_model.joint_blocks.{}.x_block.attn.qkv.weight".format(i)
|
||||||
|
proj = "diffusion_model.joint_blocks.{}.x_block.attn.proj.weight".format(i)
|
||||||
|
if qkv in sd:
|
||||||
|
offset = sd[qkv].shape[0] // 3
|
||||||
|
key_map["{}to_q".format(k)] = (qkv, (0, 0, offset))
|
||||||
|
key_map["{}to_k".format(k)] = (qkv, (0, offset, offset))
|
||||||
|
key_map["{}to_v".format(k)] = (qkv, (0, offset * 2, offset))
|
||||||
|
key_map["{}to_out.0".format(k)] = proj
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
Loading…
Reference in New Issue
Block a user