import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape


class LoRAAdapter(WeightAdapterBase):
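    """Weight adapter for plain LoRA weights.

    load() looks for the up/down weight pair of a given key under the
    naming conventions handled in this file (the regular lora_up/lora_down
    layout plus the diffusers, mochi and transformers variants) and keeps
    the raw tensors; calculate_weight() applies the resulting low-rank
    update to a model weight, handling optional LoCon mid weights, reshape
    padding and a DoRA scale.
    """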
name = "lora"
|
|
|
|
    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["LoRAAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()

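        # Candidate key names for this weight's "up" matrix under the layouts
        # handled below: the regular lora_up/lora_down pair, several diffusers
        # variants, mochi, and transformers lora_linear_layer weights.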
        reshape_name = "{}.reshape_weight".format(x)
        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        diffusers2_lora = "{}.lora_B.weight".format(x)
        diffusers3_lora = "{}.lora.up.weight".format(x)
        mochi_lora = "{}.lora_B".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif diffusers2_lora in lora.keys():
            A_name = diffusers2_lora
            B_name = "{}.lora_A.weight".format(x)
            mid_name = None
        elif diffusers3_lora in lora.keys():
            A_name = diffusers3_lora
            B_name = "{}.lora.down.weight".format(x)
            mid_name = None
        elif mochi_lora in lora.keys():
            A_name = mochi_lora
            B_name = "{}.lora_A".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            reshape = None
            if reshape_name in lora.keys():
                try:
                    reshape = lora[reshape_name].tolist()
                    loaded_keys.add(reshape_name)
                except Exception:
                    pass
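            # Pack everything calculate_weight() needs:
            # (up, down, alpha, mid, dora_scale, reshape)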
            weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
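        """Apply this adapter's low-rank update to `weight` and return it.

        The delta is (strength * alpha / rank) * (up @ down); `function` is
        applied to the delta before it is added. When a DoRA scale was
        loaded, weight_decompose() performs the update instead.
        """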
        v = self.weights
        mat1 = comfy.model_management.cast_to_device(
            v[0], weight.device, intermediate_dtype
        )
        mat2 = comfy.model_management.cast_to_device(
            v[1], weight.device, intermediate_dtype
        )
        dora_scale = v[4]
        reshape = v[5]

        if reshape is not None:
            weight = pad_tensor_to_shape(weight, reshape)

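        # Standard LoRA scaling: divide alpha by the rank (mat2.shape[0],
        # the first dimension of the down matrix).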
        if v[2] is not None:
            alpha = v[2] / mat2.shape[0]
        else:
            alpha = 1.0

        if v[3] is not None:
            # locon mid weights, hopefully the math is fine because I didn't properly test it
            mat3 = comfy.model_management.cast_to_device(
                v[3], weight.device, intermediate_dtype
            )
            final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
            mat2 = (
                torch.mm(
                    mat2.transpose(0, 1).flatten(start_dim=1),
                    mat3.transpose(0, 1).flatten(start_dim=1),
                )
                .reshape(final_shape)
                .transpose(0, 1)
            )
        try:
            lora_diff = torch.mm(
                mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
            ).reshape(weight.shape)
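            # DoRA path: weight_decompose() applies the low-rank delta and then
            # renormalizes the result against the stored dora_scale magnitudes.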
            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight