ComfyUI/comfy/weight_adapter/glora.py

import logging
from typing import Optional

import torch

import comfy.utils
import comfy.model_management
import comfy.model_base
from comfy.lora import weight_decompose, pad_tensor_to_shape
from .base import WeightAdapterBase


class GLoRAAdapter(WeightAdapterBase):
    name = "glora"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["GLoRAAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        # GLoRA checkpoints store four low-rank factors per layer key:
        # a1/a2 (the weight-dependent term) and b1/b2 (the additive term).
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            weights = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        # Applying GLoRA patches is not implemented yet in this initial version;
        # return the base weight unchanged so callers still receive a valid tensor.
        return weight
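

# --- Sketch: how the GLoRA delta could be computed -------------------------
# calculate_weight() above is still a stub in this initial implementation.
# The helper below is only an illustrative sketch (its name and the assumed
# factor layout are not part of this file's API). It follows the LyCORIS
# GLoRA formulation, where the patched weight is
#     W' = W + strength * scale * (W @ (a2 @ a1) + b2 @ b1)
# i.e. a1/a2 form a low-rank matrix that remixes the original weight and
# b1/b2 form a low-rank additive term.
def _glora_diff_sketch(weight, a1, a2, b1, b2, intermediate_dtype=torch.float32):
    # Flatten conv kernels to 2D so the matrix products below are well defined,
    # and move the factors onto the weight's device in the working dtype.
    w = weight.flatten(start_dim=1).to(dtype=intermediate_dtype)
    a1 = a1.flatten(start_dim=1).to(device=weight.device, dtype=intermediate_dtype)
    a2 = a2.flatten(start_dim=1).to(device=weight.device, dtype=intermediate_dtype)
    b1 = b1.flatten(start_dim=1).to(device=weight.device, dtype=intermediate_dtype)
    b2 = b2.flatten(start_dim=1).to(device=weight.device, dtype=intermediate_dtype)
    # Additive low-rank term plus the weight-dependent low-rank term.
    diff = torch.mm(b2, b1) + torch.mm(torch.mm(w, a2), a1)
    return diff.reshape(weight.shape)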