Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-01-11 02:15:17 +00:00

Commit ea77750759:
This is a format with keys like: text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.lora_up.weight

Instead of waiting for me to add support for specific lora formats, you can convert your text encoder loras to this format instead. If you want to see an example, save a text encoder lora with the SaveLora node using the commit right after this one.
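As an illustration of the conversion the commit message suggests (not part of the file below), here is a minimal sketch that renames a PEFT/diffusers-style CLIP-L text encoder lora ("text_encoder.*.lora_A/lora_B.weight" keys) into the generic "text_encoders." format. The source prefix and the lora_A/lora_B naming are assumptions about the input file; adjust them to whatever your lora actually uses.

# Hypothetical conversion sketch, not part of comfy/lora.py.
from safetensors.torch import load_file, save_file

def convert_clip_l_te_lora(src_path, dst_path):
    src = load_file(src_path)
    out = {}
    for k, v in src.items():
        new_k = k
        # Assumed source prefix: diffusers/PEFT "text_encoder." for CLIP-L.
        if new_k.startswith("text_encoder."):
            new_k = "text_encoders.clip_l.transformer." + new_k[len("text_encoder."):]
        # lora_B is treated as the "up" matrix and lora_A as the "down" matrix in load_lora() below.
        new_k = new_k.replace(".lora_B.weight", ".lora_up.weight")
        new_k = new_k.replace(".lora_A.weight", ".lora_down.weight")
        out[new_k] = v
    save_file(out, dst_path)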
"""
|
|
This file is part of ComfyUI.
|
|
Copyright (C) 2024 Comfy
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
import comfy.utils
|
|
import comfy.model_management
|
|
import comfy.model_base
|
|
import logging
|
|
import torch
|
|
|
|
LORA_CLIP_MAP = {
|
|
"mlp.fc1": "mlp_fc1",
|
|
"mlp.fc2": "mlp_fc2",
|
|
"self_attn.k_proj": "self_attn_k_proj",
|
|
"self_attn.q_proj": "self_attn_q_proj",
|
|
"self_attn.v_proj": "self_attn_v_proj",
|
|
"self_attn.out_proj": "self_attn_out_proj",
|
|
}
|
|
|
|
|
|
def load_lora(lora, to_load):
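    """Build a patch dict from a lora state dict.

    `lora` maps lora key names to tensors and `to_load` maps each lora key prefix to
    the model weight key it patches. Supported layouts: lora/locon (including the
    diffusers, PEFT and transformers namings), loha, lokr, glora and plain "diff"
    weights/biases, each stored as a (patch_type, tensors) tuple. Any key in `lora`
    that matches nothing is logged as a warning at the end.
    """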
    patch_dict = {}
    loaded_keys = set()
    for x in to_load:
        alpha_name = "{}.alpha".format(x)
        alpha = None
        if alpha_name in lora.keys():
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        dora_scale_name = "{}.dora_scale".format(x)
        dora_scale = None
        if dora_scale_name in lora.keys():
            dora_scale = lora[dora_scale_name]
            loaded_keys.add(dora_scale_name)

        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        diffusers2_lora = "{}.lora_B.weight".format(x)
        diffusers3_lora = "{}.lora.up.weight".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif diffusers2_lora in lora.keys():
            A_name = diffusers2_lora
            B_name = "{}.lora_A.weight".format(x)
            mid_name = None
        elif diffusers3_lora in lora.keys():
            A_name = diffusers3_lora
            B_name = "{}.lora.down.weight".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)


        ######## loha
        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)
        if hada_w1_a_name in lora.keys():
            hada_t1 = None
            hada_t2 = None
            if hada_t1_name in lora.keys():
                hada_t1 = lora[hada_t1_name]
                hada_t2 = lora[hada_t2_name]
                loaded_keys.add(hada_t1_name)
                loaded_keys.add(hada_t2_name)

            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
            loaded_keys.add(hada_w1_a_name)
            loaded_keys.add(hada_w1_b_name)
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)


        ######## lokr
        lokr_w1_name = "{}.lokr_w1".format(x)
        lokr_w2_name = "{}.lokr_w2".format(x)
        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
        lokr_t2_name = "{}.lokr_t2".format(x)
        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
        lokr_w2_b_name = "{}.lokr_w2_b".format(x)

        lokr_w1 = None
        if lokr_w1_name in lora.keys():
            lokr_w1 = lora[lokr_w1_name]
            loaded_keys.add(lokr_w1_name)

        lokr_w2 = None
        if lokr_w2_name in lora.keys():
            lokr_w2 = lora[lokr_w2_name]
            loaded_keys.add(lokr_w2_name)

        lokr_w1_a = None
        if lokr_w1_a_name in lora.keys():
            lokr_w1_a = lora[lokr_w1_a_name]
            loaded_keys.add(lokr_w1_a_name)

        lokr_w1_b = None
        if lokr_w1_b_name in lora.keys():
            lokr_w1_b = lora[lokr_w1_b_name]
            loaded_keys.add(lokr_w1_b_name)

        lokr_w2_a = None
        if lokr_w2_a_name in lora.keys():
            lokr_w2_a = lora[lokr_w2_a_name]
            loaded_keys.add(lokr_w2_a_name)

        lokr_w2_b = None
        if lokr_w2_b_name in lora.keys():
            lokr_w2_b = lora[lokr_w2_b_name]
            loaded_keys.add(lokr_w2_b_name)

        lokr_t2 = None
        if lokr_t2_name in lora.keys():
            lokr_t2 = lora[lokr_t2_name]
            loaded_keys.add(lokr_t2_name)

        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))

        #glora
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)

        w_norm_name = "{}.w_norm".format(x)
        b_norm_name = "{}.b_norm".format(x)
        w_norm = lora.get(w_norm_name, None)
        b_norm = lora.get(b_norm_name, None)

        if w_norm is not None:
            loaded_keys.add(w_norm_name)
            patch_dict[to_load[x]] = ("diff", (w_norm,))
            if b_norm is not None:
                loaded_keys.add(b_norm_name)
                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))

        diff_name = "{}.diff".format(x)
        diff_weight = lora.get(diff_name, None)
        if diff_weight is not None:
            patch_dict[to_load[x]] = ("diff", (diff_weight,))
            loaded_keys.add(diff_name)

        diff_bias_name = "{}.diff_b".format(x)
        diff_bias = lora.get(diff_bias_name, None)
        if diff_bias is not None:
            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
            loaded_keys.add(diff_bias_name)

    for x in lora.keys():
        if x not in loaded_keys:
            logging.warning("lora key not loaded: {}".format(x))

    return patch_dict

def model_lora_keys_clip(model, key_map={}):
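    """Map text encoder lora key names to the keys of this model's CLIP state dict.

    For example, "lora_te1_text_model_encoder_layers_0_self_attn_q_proj" (kohya style)
    and "text_encoder.text_model.encoder.layers.0.self_attn.q_proj" (diffusers style)
    both map to "clip_l.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight"
    when a clip_l text model is present. Results are added to `key_map` and returned.
    """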
    sdk = model.state_dict().keys()
    for k in sdk:
        if k.endswith(".weight"):
            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
    for b in range(32): #TODO: clean up
        for c in LORA_CLIP_MAP:
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
                key_map[lora_key] = k

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
                key_map[lora_key] = k
                clip_l_present = True
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
                key_map[lora_key] = k

            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                if clip_l_present:
                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
                    key_map[lora_key] = k
                    lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
                    key_map[lora_key] = k
                else:
                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
                    key_map[lora_key] = k
                    lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
                    key_map[lora_key] = k
                    lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
                    key_map[lora_key] = k

    for k in sdk:
        if k.endswith(".weight"):
            if k.startswith("t5xxl.transformer."): #OneTrainer SD3 lora
                l_key = k[len("t5xxl.transformer."):-len(".weight")]
                lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
                key_map[lora_key] = k
            elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
                l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
                lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
                key_map[lora_key] = k


    k = "clip_g.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_prior_te_text_projection"] = k #cascade lora?
        # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
        key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora

    k = "clip_l.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning

    return key_map

def model_lora_keys_unet(model, key_map={}):
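    """Map diffusion model (unet/DiT) lora key names to this model's state dict keys.

    Covers the kohya "lora_unet_*" and cascade "lora_prior_unet_*" namings, the
    diffusers attention-processor naming, and model specific formats for SD3,
    AuraFlow, HunyuanDiT and Flux. Results are added to `key_map` and returned.
    """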
    sd = model.state_dict()
    sdk = sd.keys()

    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
            key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
            key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names

    diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key

            diffusers_lora_prefix = ["", "unet."]
            for p in diffusers_lora_prefix:
                diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
                if diffusers_lora_key.endswith(".to_out.0"):
                    diffusers_lora_key = diffusers_lora_key[:-2]
                key_map[diffusers_lora_key] = unet_key

    if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
        diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
                key_map[key_lora] = to

                key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
                key_map[key_lora] = to

                key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
                key_map[key_lora] = to

    if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
        diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
                key_map[key_lora] = to

    if isinstance(model, comfy.model_base.HunyuanDiT):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format

    if isinstance(model, comfy.model_base.Flux): #Diffusers lora Flux
        diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
                key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
                key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer

    return key_map


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
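    """Apply a DoRA patch: merge lora_diff (scaled by alpha) into weight, then rescale
    the merged weight so the norm over each input channel follows dora_scale.
    With strength != 1.0 the result is blended with the original weight in place."""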
    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
    lora_diff *= alpha
    weight_calc = weight + lora_diff.type(weight.dtype)
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )

    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    if strength != 1.0:
        weight_calc -= weight
        weight += strength * (weight_calc)
    else:
        weight[:] = weight_calc
    return weight

def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Pad a tensor to a new shape with zeros.

    Args:
        tensor (torch.Tensor): The original tensor to be padded.
        new_shape (List[int]): The desired shape of the padded tensor.

    Returns:
        torch.Tensor: A new tensor padded with zeros to the specified shape.

    Note:
        If the new shape is smaller than the original tensor in any dimension,
        a ValueError is raised.
    """
    if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
        raise ValueError("The new shape must be larger than the original tensor in all dimensions")

    if len(new_shape) != len(tensor.shape):
        raise ValueError("The new shape must have the same number of dimensions as the original tensor")

    # Create a new tensor filled with zeros
    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)

    # Create slicing tuples for both tensors
    orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
    new_slices = tuple(slice(0, dim) for dim in tensor.shape)

    # Copy the original tensor into the new tensor
    padded_tensor[new_slices] = tensor[orig_slices]

    return padded_tensor

def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
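    """Apply a list of patches to `weight` and return it.

    Each patch is a (strength, value, strength_model, offset, function) tuple; `value`
    carries the patch type ("diff", "lora", "lokr", "loha", "glora") and its tensors,
    or a nested list of patches that is first resolved by a recursive call. Math is
    done in `intermediate_dtype`; unrecognized patch types only log a warning.
    """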
    for p in patches:
        strength = p[0]
        v = p[1]
        strength_model = p[2]
        offset = p[3]
        function = p[4]
        if function is None:
            function = lambda a: a

        old_weight = None
        if offset is not None:
            old_weight = weight
            weight = weight.narrow(offset[0], offset[1], offset[2])

        if strength_model != 1.0:
            weight *= strength_model

        if isinstance(v, list):
            v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )

        if len(v) == 1:
            patch_type = "diff"
        elif len(v) == 2:
            patch_type = v[0]
            v = v[1]

        if patch_type == "diff":
            diff: torch.Tensor = v[0]
            # An extra flag to pad the weight if the diff's shape is larger than the weight
            do_pad_weight = len(v) > 1 and v[1]['pad_weight']
            if do_pad_weight and diff.shape != weight.shape:
                logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
                weight = pad_tensor_to_shape(weight, diff.shape)

            if strength != 0.0:
                if diff.shape != weight.shape:
                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
                else:
                    weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
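        # "lora"/locon: delta = mat1 @ mat2 (optionally composed with a mid tensor
        # for conv locon), scaled by strength * (alpha / rank).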
        elif patch_type == "lora": #lora/locon
            mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
            mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
            dora_scale = v[4]
            if v[2] is not None:
                alpha = v[2] / mat2.shape[0]
            else:
                alpha = 1.0

            if v[3] is not None:
                #locon mid weights, hopefully the math is fine because I didn't properly test it
                mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
            try:
                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
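        # lokr: delta = kron(w1, w2); either factor may itself be stored as a
        # low rank product (w1_a @ w1_b, w2_a @ w2_b) or a tucker-style einsum with t2.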
        elif patch_type == "lokr":
            w1 = v[0]
            w2 = v[1]
            w1_a = v[3]
            w1_b = v[4]
            w2_a = v[5]
            w2_b = v[6]
            t2 = v[7]
            dora_scale = v[8]
            dim = None

            if w1 is None:
                dim = w1_b.shape[0]
                w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
            else:
                w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)

            if w2 is None:
                dim = w2_b.shape[0]
                if t2 is None:
                    w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
                else:
                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                                      comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
                                      comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
            else:
                w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)

            if len(w2.shape) == 4:
                w1 = w1.unsqueeze(2).unsqueeze(2)
            if v[2] is not None and dim is not None:
                alpha = v[2] / dim
            else:
                alpha = 1.0

            try:
                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
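        # loha: delta = (w1a @ w1b) * (w2a @ w2b), an element-wise (Hadamard) product
        # of two low rank reconstructions, with optional CP decomposition via t1/t2.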
        elif patch_type == "loha":
            w1a = v[0]
            w1b = v[1]
            if v[2] is not None:
                alpha = v[2] / w1b.shape[0]
            else:
                alpha = 1.0

            w2a = v[3]
            w2b = v[4]
            dora_scale = v[7]
            if v[5] is not None: #cp decomposition
                t1 = v[5]
                t2 = v[6]
                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))

                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
            else:
                m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
                m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))

            try:
                lora_diff = (m1 * m2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
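        # glora: delta = a B-term plus weight @ A-term; two tensor layouts (old and
        # new lycoris glora) are distinguished by the shape checks below.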
        elif patch_type == "glora":
            dora_scale = v[5]

            old_glora = False
            if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
                rank = v[0].shape[0]
                old_glora = True

            if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
                if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
                    pass
                else:
                    old_glora = False
                    rank = v[1].shape[0]

            a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
            a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
            b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
            b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)

            if v[4] is not None:
                alpha = v[4] / rank
            else:
                alpha = 1.0

            try:
                if old_glora:
                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
                else:
                    if weight.dim() > 2:
                        lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                    else:
                        lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                    lora_diff += torch.mm(b1, b2).reshape(weight.shape)

                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        else:
            logging.warning("patch type not recognized {} {}".format(patch_type, key))

        if old_weight is not None:
            weight = old_weight

    return weight