ComfyUI/comfy/weight_adapter/lora.py
2025-04-08 22:01:43 +08:00

143 lines
4.6 KiB
Python

import logging
from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
class LoRAAdapter(WeightAdapterBase):
name = "lora"
def __init__(self, loaded_keys, weights):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def load(
cls,
x: str,
lora: dict[str, torch.Tensor],
alpha: float,
dora_scale: torch.Tensor,
loaded_keys: set[str] = None,
) -> Optional["LoRAAdapter"]:
if loaded_keys is None:
loaded_keys = set()
reshape_name = "{}.reshape_weight".format(x)
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
diffusers2_lora = "{}.lora_B.weight".format(x)
diffusers3_lora = "{}.lora.up.weight".format(x)
mochi_lora = "{}.lora_B".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
if regular_lora in lora.keys():
A_name = regular_lora
B_name = "{}.lora_down.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x)
elif diffusers_lora in lora.keys():
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
elif diffusers2_lora in lora.keys():
A_name = diffusers2_lora
B_name = "{}.lora_A.weight".format(x)
mid_name = None
elif diffusers3_lora in lora.keys():
A_name = diffusers3_lora
B_name = "{}.lora.down.weight".format(x)
mid_name = None
elif mochi_lora in lora.keys():
A_name = mochi_lora
B_name = "{}.lora_A".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name = "{}.lora_linear_layer.down.weight".format(x)
mid_name = None
if A_name is not None:
mid = None
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
reshape = None
if reshape_name in lora.keys():
try:
reshape = lora[reshape_name].tolist()
loaded_keys.add(reshape_name)
except:
pass
weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
loaded_keys.add(A_name)
loaded_keys.add(B_name)
return cls(loaded_keys, weights)
else:
return None
def calculate_weight(
self,
weight,
key,
strength,
strength_model,
offset,
function,
intermediate_dtype=torch.float32,
original_weight=None,
):
v = self.weights
mat1 = comfy.model_management.cast_to_device(
v[0], weight.device, intermediate_dtype
)
mat2 = comfy.model_management.cast_to_device(
v[1], weight.device, intermediate_dtype
)
dora_scale = v[4]
reshape = v[5]
if reshape is not None:
weight = pad_tensor_to_shape(weight, reshape)
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
# locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(
v[3], weight.device, intermediate_dtype
)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = (
torch.mm(
mat2.transpose(0, 1).flatten(start_dim=1),
mat3.transpose(0, 1).flatten(start_dim=1),
)
.reshape(final_shape)
.transpose(0, 1)
)
try:
lora_diff = torch.mm(
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(
dora_scale,
weight,
lora_diff,
alpha,
strength,
intermediate_dtype,
function,
)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight