ComfyUI/comfy/weight_adapter/glora.py
Kohaku-Blueleaf e8f3bc5ab7 Finalize the modularized weight adapter impl
* LoRA/LoHa/LoKr/GLoRA working well
* Removed TONS of code in lora.py
2025-04-09 09:16:52 +08:00

94 lines
3.4 KiB
Python

import logging
from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose
class GLoRAAdapter(WeightAdapterBase):
    """Weight adapter that patches a weight with a GLoRA delta.

    The adapter holds the tuple ``(a1, a2, b1, b2, alpha, dora_scale)`` built
    by :meth:`load` and turns it into a weight delta in
    :meth:`calculate_weight`.
    """

    name = "glora"

    def __init__(self, loaded_keys, weights):
        """Store the set of consumed state-dict keys and the weight tuple."""
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["GLoRAAdapter"]:
        """Build an adapter from the ``{x}.a1/a2/b1/b2.weight`` entries of ``lora``.

        Args:
            x: Key prefix of the layer inside ``lora``.
            lora: Mapping of key names to tensors.
            alpha: Scaling numerator stored with the weights (may be None;
                ``calculate_weight`` then uses a scale of 1.0).
            dora_scale: Stored with the weights; when not None,
                ``calculate_weight`` routes through ``weight_decompose``.
            loaded_keys: Set that receives the four consumed key names. A new
                set is created when omitted.

        Returns:
            A new adapter, or None when ``{x}.a1.weight`` is not in ``lora``.

        Raises:
            KeyError: If ``a1`` is present but ``a2``, ``b1`` or ``b2`` is not.
        """
        if loaded_keys is None:
            loaded_keys = set()

        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)

        # Presence of a1 alone decides whether this prefix is a GLoRA entry.
        if a1_name not in lora:
            return None

        # Look up all four tensors before touching loaded_keys so a missing
        # key raises without leaving the set partially updated.
        weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
        loaded_keys.update((a1_name, a2_name, b1_name, b2_name))
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply the stored GLoRA delta to ``weight`` and return the result.

        Args:
            weight: Tensor to patch. Without ``dora_scale`` it is updated in
                place (``+=``).
            key: Name of the patched weight; only used in error messages.
            strength: Multiplier applied to the delta.
            strength_model: Not used by this adapter.
            offset: Not used by this adapter.
            function: Callable applied to the scaled delta before it is added
                (also forwarded to ``weight_decompose``).
            intermediate_dtype: dtype used for the intermediate matrix math.
            original_weight: Not used by this adapter.

        Returns:
            The patched weight. If the delta cannot be computed, the error is
            logged and ``weight`` is returned as it was.
        """
        v = self.weights
        a1_raw, a2_raw, b1_raw, b2_raw = v[0], v[1], v[2], v[3]
        alpha_value = v[4]
        dora_scale = v[5]

        # The two supported layouts are told apart purely by which dimensions
        # of the four matrices agree. rank stays None if neither one matches.
        rank = None
        old_glora = False
        if b2_raw.shape[1] == b1_raw.shape[0] == a1_raw.shape[0] == a2_raw.shape[1]:
            rank = a1_raw.shape[0]
            old_glora = True

        if b2_raw.shape[0] == b1_raw.shape[1] == a1_raw.shape[1] == a2_raw.shape[0]:
            # Both layouts can match at once. The old one is kept only when
            # a2's first dim equals the weight's and the weight is square in
            # its first two dims; otherwise the current layout wins.
            keep_old = (
                old_glora
                and a2_raw.shape[0] == weight.shape[0]
                and weight.shape[0] == weight.shape[1]
            )
            if not keep_old:
                old_glora = False
                rank = a2_raw.shape[0]

        if alpha_value is not None:
            if rank is None:
                # Neither layout matched, so there is no rank to scale alpha
                # by. Report it like any other failure instead of crashing
                # with an UnboundLocalError.
                logging.error(
                    "ERROR {} {} {}".format(
                        self.name, key, "unrecognized GLoRA weight shapes, cannot determine rank"
                    )
                )
                return weight
            alpha = alpha_value / rank
        else:
            alpha = 1.0

        # Collapse everything past dim 0 so the factors are plain matrices.
        a1 = comfy.model_management.cast_to_device(a1_raw.flatten(start_dim=1), weight.device, intermediate_dtype)
        a2 = comfy.model_management.cast_to_device(a2_raw.flatten(start_dim=1), weight.device, intermediate_dtype)
        b1 = comfy.model_management.cast_to_device(b1_raw.flatten(start_dim=1), weight.device, intermediate_dtype)
        b2 = comfy.model_management.cast_to_device(b2_raw.flatten(start_dim=1), weight.device, intermediate_dtype)

        try:
            if old_glora:
                # old lycoris glora: b2 @ b1 + (W @ a2) @ a1
                flat_weight = weight.flatten(start_dim=1).to(dtype=intermediate_dtype)
                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(flat_weight, a2), a1)).reshape(weight.shape)
            else:
                if weight.dim() > 2:
                    # Contract only dim 1 and carry any trailing dims along.
                    projected = torch.einsum(
                        "o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1
                    )
                    lora_diff = torch.einsum("o i ..., i j -> o j ...", projected, a2).reshape(weight.shape)
                else:
                    lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                # The b-term is added here only; the old branch already includes it.
                lora_diff += torch.mm(b1, b2).reshape(weight.shape)

            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            # Best effort: a failing patch is reported and the weight is
            # returned rather than aborting the caller.
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight