mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-13 15:03:33 +00:00
Merge e8f3bc5ab7
into 22ad513c72
This commit is contained in:
commit
b8bac6558a
321
comfy/lora.py
321
comfy/lora.py
@ -20,6 +20,7 @@ from __future__ import annotations
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_base
|
import comfy.model_base
|
||||||
|
import comfy.weight_adapter as weight_adapter
|
||||||
import logging
|
import logging
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -49,139 +50,12 @@ def load_lora(lora, to_load, log_missing=True):
|
|||||||
dora_scale = lora[dora_scale_name]
|
dora_scale = lora[dora_scale_name]
|
||||||
loaded_keys.add(dora_scale_name)
|
loaded_keys.add(dora_scale_name)
|
||||||
|
|
||||||
reshape_name = "{}.reshape_weight".format(x)
|
for adapter_cls in weight_adapter.adapters:
|
||||||
reshape = None
|
adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys)
|
||||||
if reshape_name in lora.keys():
|
if adapter is not None:
|
||||||
try:
|
patch_dict[to_load[x]] = adapter
|
||||||
reshape = lora[reshape_name].tolist()
|
loaded_keys.update(adapter.loaded_keys)
|
||||||
loaded_keys.add(reshape_name)
|
continue
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
regular_lora = "{}.lora_up.weight".format(x)
|
|
||||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
|
||||||
diffusers2_lora = "{}.lora_B.weight".format(x)
|
|
||||||
diffusers3_lora = "{}.lora.up.weight".format(x)
|
|
||||||
mochi_lora = "{}.lora_B".format(x)
|
|
||||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
|
||||||
A_name = None
|
|
||||||
|
|
||||||
if regular_lora in lora.keys():
|
|
||||||
A_name = regular_lora
|
|
||||||
B_name = "{}.lora_down.weight".format(x)
|
|
||||||
mid_name = "{}.lora_mid.weight".format(x)
|
|
||||||
elif diffusers_lora in lora.keys():
|
|
||||||
A_name = diffusers_lora
|
|
||||||
B_name = "{}_lora.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif diffusers2_lora in lora.keys():
|
|
||||||
A_name = diffusers2_lora
|
|
||||||
B_name = "{}.lora_A.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif diffusers3_lora in lora.keys():
|
|
||||||
A_name = diffusers3_lora
|
|
||||||
B_name = "{}.lora.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif mochi_lora in lora.keys():
|
|
||||||
A_name = mochi_lora
|
|
||||||
B_name = "{}.lora_A".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif transformers_lora in lora.keys():
|
|
||||||
A_name = transformers_lora
|
|
||||||
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
|
|
||||||
if A_name is not None:
|
|
||||||
mid = None
|
|
||||||
if mid_name is not None and mid_name in lora.keys():
|
|
||||||
mid = lora[mid_name]
|
|
||||||
loaded_keys.add(mid_name)
|
|
||||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
|
|
||||||
loaded_keys.add(A_name)
|
|
||||||
loaded_keys.add(B_name)
|
|
||||||
|
|
||||||
|
|
||||||
######## loha
|
|
||||||
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
|
||||||
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
|
||||||
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
|
||||||
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
|
||||||
hada_t1_name = "{}.hada_t1".format(x)
|
|
||||||
hada_t2_name = "{}.hada_t2".format(x)
|
|
||||||
if hada_w1_a_name in lora.keys():
|
|
||||||
hada_t1 = None
|
|
||||||
hada_t2 = None
|
|
||||||
if hada_t1_name in lora.keys():
|
|
||||||
hada_t1 = lora[hada_t1_name]
|
|
||||||
hada_t2 = lora[hada_t2_name]
|
|
||||||
loaded_keys.add(hada_t1_name)
|
|
||||||
loaded_keys.add(hada_t2_name)
|
|
||||||
|
|
||||||
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
|
|
||||||
loaded_keys.add(hada_w1_a_name)
|
|
||||||
loaded_keys.add(hada_w1_b_name)
|
|
||||||
loaded_keys.add(hada_w2_a_name)
|
|
||||||
loaded_keys.add(hada_w2_b_name)
|
|
||||||
|
|
||||||
|
|
||||||
######## lokr
|
|
||||||
lokr_w1_name = "{}.lokr_w1".format(x)
|
|
||||||
lokr_w2_name = "{}.lokr_w2".format(x)
|
|
||||||
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
|
||||||
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
|
||||||
lokr_t2_name = "{}.lokr_t2".format(x)
|
|
||||||
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
|
||||||
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
|
||||||
|
|
||||||
lokr_w1 = None
|
|
||||||
if lokr_w1_name in lora.keys():
|
|
||||||
lokr_w1 = lora[lokr_w1_name]
|
|
||||||
loaded_keys.add(lokr_w1_name)
|
|
||||||
|
|
||||||
lokr_w2 = None
|
|
||||||
if lokr_w2_name in lora.keys():
|
|
||||||
lokr_w2 = lora[lokr_w2_name]
|
|
||||||
loaded_keys.add(lokr_w2_name)
|
|
||||||
|
|
||||||
lokr_w1_a = None
|
|
||||||
if lokr_w1_a_name in lora.keys():
|
|
||||||
lokr_w1_a = lora[lokr_w1_a_name]
|
|
||||||
loaded_keys.add(lokr_w1_a_name)
|
|
||||||
|
|
||||||
lokr_w1_b = None
|
|
||||||
if lokr_w1_b_name in lora.keys():
|
|
||||||
lokr_w1_b = lora[lokr_w1_b_name]
|
|
||||||
loaded_keys.add(lokr_w1_b_name)
|
|
||||||
|
|
||||||
lokr_w2_a = None
|
|
||||||
if lokr_w2_a_name in lora.keys():
|
|
||||||
lokr_w2_a = lora[lokr_w2_a_name]
|
|
||||||
loaded_keys.add(lokr_w2_a_name)
|
|
||||||
|
|
||||||
lokr_w2_b = None
|
|
||||||
if lokr_w2_b_name in lora.keys():
|
|
||||||
lokr_w2_b = lora[lokr_w2_b_name]
|
|
||||||
loaded_keys.add(lokr_w2_b_name)
|
|
||||||
|
|
||||||
lokr_t2 = None
|
|
||||||
if lokr_t2_name in lora.keys():
|
|
||||||
lokr_t2 = lora[lokr_t2_name]
|
|
||||||
loaded_keys.add(lokr_t2_name)
|
|
||||||
|
|
||||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
|
||||||
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
|
|
||||||
|
|
||||||
#glora
|
|
||||||
a1_name = "{}.a1.weight".format(x)
|
|
||||||
a2_name = "{}.a2.weight".format(x)
|
|
||||||
b1_name = "{}.b1.weight".format(x)
|
|
||||||
b2_name = "{}.b2.weight".format(x)
|
|
||||||
if a1_name in lora:
|
|
||||||
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
|
|
||||||
loaded_keys.add(a1_name)
|
|
||||||
loaded_keys.add(a2_name)
|
|
||||||
loaded_keys.add(b1_name)
|
|
||||||
loaded_keys.add(b2_name)
|
|
||||||
|
|
||||||
w_norm_name = "{}.w_norm".format(x)
|
w_norm_name = "{}.w_norm".format(x)
|
||||||
b_norm_name = "{}.b_norm".format(x)
|
b_norm_name = "{}.b_norm".format(x)
|
||||||
@ -408,26 +282,6 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|
||||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
|
||||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
|
||||||
lora_diff *= alpha
|
|
||||||
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
|
||||||
weight_norm = (
|
|
||||||
weight_calc.transpose(0, 1)
|
|
||||||
.reshape(weight_calc.shape[1], -1)
|
|
||||||
.norm(dim=1, keepdim=True)
|
|
||||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
|
||||||
.transpose(0, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
|
||||||
if strength != 1.0:
|
|
||||||
weight_calc -= weight
|
|
||||||
weight += strength * (weight_calc)
|
|
||||||
else:
|
|
||||||
weight[:] = weight_calc
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Pad a tensor to a new shape with zeros.
|
Pad a tensor to a new shape with zeros.
|
||||||
@ -482,6 +336,16 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
|||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
|
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
|
||||||
|
|
||||||
|
if isinstance(v, weight_adapter.WeightAdapterBase):
|
||||||
|
output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
|
||||||
|
if output is None:
|
||||||
|
logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
|
||||||
|
else:
|
||||||
|
weight = output
|
||||||
|
if old_weight is not None:
|
||||||
|
weight = old_weight
|
||||||
|
continue
|
||||||
|
|
||||||
if len(v) == 1:
|
if len(v) == 1:
|
||||||
patch_type = "diff"
|
patch_type = "diff"
|
||||||
elif len(v) == 2:
|
elif len(v) == 2:
|
||||||
@ -508,157 +372,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
|||||||
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
|
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
|
||||||
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
|
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
|
||||||
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
|
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
|
||||||
elif patch_type == "lora": #lora/locon
|
|
||||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
|
||||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
|
||||||
dora_scale = v[4]
|
|
||||||
reshape = v[5]
|
|
||||||
|
|
||||||
if reshape is not None:
|
|
||||||
weight = pad_tensor_to_shape(weight, reshape)
|
|
||||||
|
|
||||||
if v[2] is not None:
|
|
||||||
alpha = v[2] / mat2.shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
if v[3] is not None:
|
|
||||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
|
||||||
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
|
|
||||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
|
||||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
|
||||||
try:
|
|
||||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "lokr":
|
|
||||||
w1 = v[0]
|
|
||||||
w2 = v[1]
|
|
||||||
w1_a = v[3]
|
|
||||||
w1_b = v[4]
|
|
||||||
w2_a = v[5]
|
|
||||||
w2_b = v[6]
|
|
||||||
t2 = v[7]
|
|
||||||
dora_scale = v[8]
|
|
||||||
dim = None
|
|
||||||
|
|
||||||
if w1 is None:
|
|
||||||
dim = w1_b.shape[0]
|
|
||||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
if w2 is None:
|
|
||||||
dim = w2_b.shape[0]
|
|
||||||
if t2 is None:
|
|
||||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
|
||||||
if v[2] is not None and dim is not None:
|
|
||||||
alpha = v[2] / dim
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "loha":
|
|
||||||
w1a = v[0]
|
|
||||||
w1b = v[1]
|
|
||||||
if v[2] is not None:
|
|
||||||
alpha = v[2] / w1b.shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
w2a = v[3]
|
|
||||||
w2b = v[4]
|
|
||||||
dora_scale = v[7]
|
|
||||||
if v[5] is not None: #cp decomposition
|
|
||||||
t1 = v[5]
|
|
||||||
t2 = v[6]
|
|
||||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
|
||||||
|
|
||||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
|
||||||
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "glora":
|
|
||||||
dora_scale = v[5]
|
|
||||||
|
|
||||||
old_glora = False
|
|
||||||
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
|
||||||
rank = v[0].shape[0]
|
|
||||||
old_glora = True
|
|
||||||
|
|
||||||
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
|
||||||
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
old_glora = False
|
|
||||||
rank = v[1].shape[0]
|
|
||||||
|
|
||||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
if v[4] is not None:
|
|
||||||
alpha = v[4] / rank
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
try:
|
|
||||||
if old_glora:
|
|
||||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
|
||||||
else:
|
|
||||||
if weight.dim() > 2:
|
|
||||||
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
|
||||||
else:
|
|
||||||
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
|
||||||
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
|
||||||
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
else:
|
else:
|
||||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||||
|
|
||||||
|
13
comfy/weight_adapter/__init__.py
Normal file
13
comfy/weight_adapter/__init__.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from .base import WeightAdapterBase
|
||||||
|
from .lora import LoRAAdapter
|
||||||
|
from .loha import LoHaAdapter
|
||||||
|
from .lokr import LoKrAdapter
|
||||||
|
from .glora import GLoRAAdapter
|
||||||
|
|
||||||
|
|
||||||
|
adapters: list[type[WeightAdapterBase]] = [
|
||||||
|
LoRAAdapter,
|
||||||
|
LoHaAdapter,
|
||||||
|
LoKrAdapter,
|
||||||
|
GLoRAAdapter,
|
||||||
|
]
|
94
comfy/weight_adapter/base.py
Normal file
94
comfy/weight_adapter/base.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
|
||||||
|
class WeightAdapterBase:
|
||||||
|
name: str
|
||||||
|
loaded_keys: set[str]
|
||||||
|
weights: list[torch.Tensor]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def to_train(self) -> "WeightAdapterTrainBase":
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class WeightAdapterTrainBase(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# [TODO] Collaborate with LoRA training PR #7032
|
||||||
|
|
||||||
|
|
||||||
|
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
||||||
|
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||||
|
lora_diff *= alpha
|
||||||
|
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||||
|
weight_norm = (
|
||||||
|
weight_calc.transpose(0, 1)
|
||||||
|
.reshape(weight_calc.shape[1], -1)
|
||||||
|
.norm(dim=1, keepdim=True)
|
||||||
|
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||||
|
if strength != 1.0:
|
||||||
|
weight_calc -= weight
|
||||||
|
weight += strength * (weight_calc)
|
||||||
|
else:
|
||||||
|
weight[:] = weight_calc
|
||||||
|
return weight
|
||||||
|
|
||||||
|
|
||||||
|
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Pad a tensor to a new shape with zeros.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): The original tensor to be padded.
|
||||||
|
new_shape (List[int]): The desired shape of the padded tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: A new tensor padded with zeros to the specified shape.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If the new shape is smaller than the original tensor in any dimension,
|
||||||
|
the original tensor will be truncated in that dimension.
|
||||||
|
"""
|
||||||
|
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
||||||
|
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
|
||||||
|
|
||||||
|
if len(new_shape) != len(tensor.shape):
|
||||||
|
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
|
||||||
|
|
||||||
|
# Create a new tensor filled with zeros
|
||||||
|
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||||
|
|
||||||
|
# Create slicing tuples for both tensors
|
||||||
|
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
|
||||||
|
# Copy the original tensor into the new tensor
|
||||||
|
padded_tensor[new_slices] = tensor[orig_slices]
|
||||||
|
|
||||||
|
return padded_tensor
|
93
comfy/weight_adapter/glora.py
Normal file
93
comfy/weight_adapter/glora.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
class GLoRAAdapter(WeightAdapterBase):
|
||||||
|
name = "glora"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["GLoRAAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
a1_name = "{}.a1.weight".format(x)
|
||||||
|
a2_name = "{}.a2.weight".format(x)
|
||||||
|
b1_name = "{}.b1.weight".format(x)
|
||||||
|
b2_name = "{}.b2.weight".format(x)
|
||||||
|
if a1_name in lora:
|
||||||
|
weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
|
||||||
|
loaded_keys.add(a1_name)
|
||||||
|
loaded_keys.add(a2_name)
|
||||||
|
loaded_keys.add(b1_name)
|
||||||
|
loaded_keys.add(b2_name)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
dora_scale = v[5]
|
||||||
|
|
||||||
|
old_glora = False
|
||||||
|
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||||
|
rank = v[0].shape[0]
|
||||||
|
old_glora = True
|
||||||
|
|
||||||
|
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||||
|
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
old_glora = False
|
||||||
|
rank = v[1].shape[0]
|
||||||
|
|
||||||
|
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if v[4] is not None:
|
||||||
|
alpha = v[4] / rank
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if old_glora:
|
||||||
|
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||||
|
else:
|
||||||
|
if weight.dim() > 2:
|
||||||
|
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
else:
|
||||||
|
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||||
|
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
100
comfy/weight_adapter/loha.py
Normal file
100
comfy/weight_adapter/loha.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
class LoHaAdapter(WeightAdapterBase):
|
||||||
|
name = "loha"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["LoHaAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
|
||||||
|
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
||||||
|
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
||||||
|
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
||||||
|
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
||||||
|
hada_t1_name = "{}.hada_t1".format(x)
|
||||||
|
hada_t2_name = "{}.hada_t2".format(x)
|
||||||
|
if hada_w1_a_name in lora.keys():
|
||||||
|
hada_t1 = None
|
||||||
|
hada_t2 = None
|
||||||
|
if hada_t1_name in lora.keys():
|
||||||
|
hada_t1 = lora[hada_t1_name]
|
||||||
|
hada_t2 = lora[hada_t2_name]
|
||||||
|
loaded_keys.add(hada_t1_name)
|
||||||
|
loaded_keys.add(hada_t2_name)
|
||||||
|
|
||||||
|
weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
|
||||||
|
loaded_keys.add(hada_w1_a_name)
|
||||||
|
loaded_keys.add(hada_w1_b_name)
|
||||||
|
loaded_keys.add(hada_w2_a_name)
|
||||||
|
loaded_keys.add(hada_w2_b_name)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
w1a = v[0]
|
||||||
|
w1b = v[1]
|
||||||
|
if v[2] is not None:
|
||||||
|
alpha = v[2] / w1b.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
w2a = v[3]
|
||||||
|
w2b = v[4]
|
||||||
|
dora_scale = v[7]
|
||||||
|
if v[5] is not None: #cp decomposition
|
||||||
|
t1 = v[5]
|
||||||
|
t2 = v[6]
|
||||||
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
||||||
|
|
||||||
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
||||||
|
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
133
comfy/weight_adapter/lokr.py
Normal file
133
comfy/weight_adapter/lokr.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
class LoKrAdapter(WeightAdapterBase):
|
||||||
|
name = "lokr"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["LoKrAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
lokr_w1_name = "{}.lokr_w1".format(x)
|
||||||
|
lokr_w2_name = "{}.lokr_w2".format(x)
|
||||||
|
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
||||||
|
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
||||||
|
lokr_t2_name = "{}.lokr_t2".format(x)
|
||||||
|
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
||||||
|
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
||||||
|
|
||||||
|
lokr_w1 = None
|
||||||
|
if lokr_w1_name in lora.keys():
|
||||||
|
lokr_w1 = lora[lokr_w1_name]
|
||||||
|
loaded_keys.add(lokr_w1_name)
|
||||||
|
|
||||||
|
lokr_w2 = None
|
||||||
|
if lokr_w2_name in lora.keys():
|
||||||
|
lokr_w2 = lora[lokr_w2_name]
|
||||||
|
loaded_keys.add(lokr_w2_name)
|
||||||
|
|
||||||
|
lokr_w1_a = None
|
||||||
|
if lokr_w1_a_name in lora.keys():
|
||||||
|
lokr_w1_a = lora[lokr_w1_a_name]
|
||||||
|
loaded_keys.add(lokr_w1_a_name)
|
||||||
|
|
||||||
|
lokr_w1_b = None
|
||||||
|
if lokr_w1_b_name in lora.keys():
|
||||||
|
lokr_w1_b = lora[lokr_w1_b_name]
|
||||||
|
loaded_keys.add(lokr_w1_b_name)
|
||||||
|
|
||||||
|
lokr_w2_a = None
|
||||||
|
if lokr_w2_a_name in lora.keys():
|
||||||
|
lokr_w2_a = lora[lokr_w2_a_name]
|
||||||
|
loaded_keys.add(lokr_w2_a_name)
|
||||||
|
|
||||||
|
lokr_w2_b = None
|
||||||
|
if lokr_w2_b_name in lora.keys():
|
||||||
|
lokr_w2_b = lora[lokr_w2_b_name]
|
||||||
|
loaded_keys.add(lokr_w2_b_name)
|
||||||
|
|
||||||
|
lokr_t2 = None
|
||||||
|
if lokr_t2_name in lora.keys():
|
||||||
|
lokr_t2 = lora[lokr_t2_name]
|
||||||
|
loaded_keys.add(lokr_t2_name)
|
||||||
|
|
||||||
|
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||||
|
weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
w1 = v[0]
|
||||||
|
w2 = v[1]
|
||||||
|
w1_a = v[3]
|
||||||
|
w1_b = v[4]
|
||||||
|
w2_a = v[5]
|
||||||
|
w2_b = v[6]
|
||||||
|
t2 = v[7]
|
||||||
|
dora_scale = v[8]
|
||||||
|
dim = None
|
||||||
|
|
||||||
|
if w1 is None:
|
||||||
|
dim = w1_b.shape[0]
|
||||||
|
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if w2 is None:
|
||||||
|
dim = w2_b.shape[0]
|
||||||
|
if t2 is None:
|
||||||
|
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if len(w2.shape) == 4:
|
||||||
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
|
if v[2] is not None and dim is not None:
|
||||||
|
alpha = v[2] / dim
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
142
comfy/weight_adapter/lora.py
Normal file
142
comfy/weight_adapter/lora.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAAdapter(WeightAdapterBase):
|
||||||
|
name = "lora"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["LoRAAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
|
||||||
|
reshape_name = "{}.reshape_weight".format(x)
|
||||||
|
regular_lora = "{}.lora_up.weight".format(x)
|
||||||
|
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||||
|
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||||
|
diffusers3_lora = "{}.lora.up.weight".format(x)
|
||||||
|
mochi_lora = "{}.lora_B".format(x)
|
||||||
|
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||||
|
A_name = None
|
||||||
|
|
||||||
|
if regular_lora in lora.keys():
|
||||||
|
A_name = regular_lora
|
||||||
|
B_name = "{}.lora_down.weight".format(x)
|
||||||
|
mid_name = "{}.lora_mid.weight".format(x)
|
||||||
|
elif diffusers_lora in lora.keys():
|
||||||
|
A_name = diffusers_lora
|
||||||
|
B_name = "{}_lora.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif diffusers2_lora in lora.keys():
|
||||||
|
A_name = diffusers2_lora
|
||||||
|
B_name = "{}.lora_A.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif diffusers3_lora in lora.keys():
|
||||||
|
A_name = diffusers3_lora
|
||||||
|
B_name = "{}.lora.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif mochi_lora in lora.keys():
|
||||||
|
A_name = mochi_lora
|
||||||
|
B_name = "{}.lora_A".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif transformers_lora in lora.keys():
|
||||||
|
A_name = transformers_lora
|
||||||
|
B_name = "{}.lora_linear_layer.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
|
||||||
|
if A_name is not None:
|
||||||
|
mid = None
|
||||||
|
if mid_name is not None and mid_name in lora.keys():
|
||||||
|
mid = lora[mid_name]
|
||||||
|
loaded_keys.add(mid_name)
|
||||||
|
reshape = None
|
||||||
|
if reshape_name in lora.keys():
|
||||||
|
try:
|
||||||
|
reshape = lora[reshape_name].tolist()
|
||||||
|
loaded_keys.add(reshape_name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
|
||||||
|
loaded_keys.add(A_name)
|
||||||
|
loaded_keys.add(B_name)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
mat1 = comfy.model_management.cast_to_device(
|
||||||
|
v[0], weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
mat2 = comfy.model_management.cast_to_device(
|
||||||
|
v[1], weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
dora_scale = v[4]
|
||||||
|
reshape = v[5]
|
||||||
|
|
||||||
|
if reshape is not None:
|
||||||
|
weight = pad_tensor_to_shape(weight, reshape)
|
||||||
|
|
||||||
|
if v[2] is not None:
|
||||||
|
alpha = v[2] / mat2.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
if v[3] is not None:
|
||||||
|
# locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||||
|
mat3 = comfy.model_management.cast_to_device(
|
||||||
|
v[3], weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||||
|
mat2 = (
|
||||||
|
torch.mm(
|
||||||
|
mat2.transpose(0, 1).flatten(start_dim=1),
|
||||||
|
mat3.transpose(0, 1).flatten(start_dim=1),
|
||||||
|
)
|
||||||
|
.reshape(final_shape)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
lora_diff = torch.mm(
|
||||||
|
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
|
||||||
|
).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(
|
||||||
|
dora_scale,
|
||||||
|
weight,
|
||||||
|
lora_diff,
|
||||||
|
alpha,
|
||||||
|
strength,
|
||||||
|
intermediate_dtype,
|
||||||
|
function,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
Loading…
Reference in New Issue
Block a user