From 4774c3244eaa93c0f089c85e6967be7d76f93342 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Wed, 2 Apr 2025 09:21:39 +0800
Subject: [PATCH] Initial impl

LoRA: load and calculate_weight
LoHa/LoKr/GLoRA: load only (calculate_weight is still a stub)
---
 comfy/weight_adapter/glora.py |  54 +++++++++++++
 comfy/weight_adapter/loha.py  |  65 +++++++++++++++
 comfy/weight_adapter/lokr.py  |  89 +++++++++++++++++++++
 comfy/weight_adapter/lora.py  | 144 ++++++++++++++++++++++++++++++++++
 4 files changed, 352 insertions(+)
 create mode 100644 comfy/weight_adapter/glora.py
 create mode 100644 comfy/weight_adapter/loha.py
 create mode 100644 comfy/weight_adapter/lokr.py
 create mode 100644 comfy/weight_adapter/lora.py

diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py
new file mode 100644
index 00000000..bdb9220e
--- /dev/null
+++ b/comfy/weight_adapter/glora.py
@@ -0,0 +1,54 @@
+import logging
+import torch
+import comfy.utils
+import comfy.model_management
+import comfy.model_base
+from comfy.lora import weight_decompose, pad_tensor_to_shape
+
+from .base import WeightAdapterBase
+
+
+class GLoRAAdapter(WeightAdapterBase):
+    name = "glora"
+
+    def __init__(self, loaded_keys, weights):
+        self.loaded_keys = loaded_keys
+        self.weights = weights
+
+    @classmethod
+    def load(
+        cls,
+        x: str,
+        lora: dict[str, torch.Tensor],
+        alpha: float,
+        dora_scale: torch.Tensor,
+        loaded_keys: set[str] = None,
+    ) -> "GLoRAAdapter" | None:
+        if loaded_keys is None:
+            loaded_keys = set()
+        a1_name = "{}.a1.weight".format(x)
+        a2_name = "{}.a2.weight".format(x)
+        b1_name = "{}.b1.weight".format(x)
+        b2_name = "{}.b2.weight".format(x)
+        if a1_name in lora:
+            weights = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
+            loaded_keys.add(a1_name)
+            loaded_keys.add(a2_name)
+            loaded_keys.add(b1_name)
+            loaded_keys.add(b2_name)
+            return cls(loaded_keys, weights)
+        else:
+            return None
+
+    def calculate_weight(
+        self,
+        weight,
+        key,
+        strength,
+        strength_model,
+        offset,
+        function,
+        intermediate_dtype=torch.float32,
+        original_weight=None,
+    ):
+        raise NotImplementedError("GLoRA calculate_weight is not implemented yet")
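GLoRAAdapter.calculate_weight is left as a stub in this initial patch. For reference, the following is a minimal standalone sketch of one common GLoRA delta formulation, W' = W + W @ A + B with low-rank A and B; the shapes, factor names and the alpha/rank scaling below are illustrative assumptions and are not taken from this patch.

import torch

# One common GLoRA formulation: W' = W + W @ A + B, with
#   A = a2 @ a1  (in_dim x in_dim, rank-limited)
#   B = b2 @ b1  (out_dim x in_dim, rank-limited)
out_dim, in_dim, rank = 8, 16, 4
weight = torch.randn(out_dim, in_dim)

a1 = torch.randn(rank, in_dim)
a2 = torch.randn(in_dim, rank)
b1 = torch.randn(rank, in_dim)
b2 = torch.randn(out_dim, rank)

alpha, strength = 1.0, 1.0
scale = alpha / rank

lora_diff = weight @ (a2 @ a1) + b2 @ b1        # low-rank delta
patched = weight + strength * scale * lora_diff
print(patched.shape)                            # torch.Size([8, 16])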
diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py
new file mode 100644
index 00000000..b3267bc0
--- /dev/null
+++ b/comfy/weight_adapter/loha.py
@@ -0,0 +1,65 @@
+import logging
+import torch
+import comfy.utils
+import comfy.model_management
+import comfy.model_base
+from comfy.lora import weight_decompose, pad_tensor_to_shape
+
+from .base import WeightAdapterBase
+
+
+class LoHaAdapter(WeightAdapterBase):
+    name = "loha"
+
+    def __init__(self, loaded_keys, weights):
+        self.loaded_keys = loaded_keys
+        self.weights = weights
+
+    @classmethod
+    def load(
+        cls,
+        x: str,
+        lora: dict[str, torch.Tensor],
+        alpha: float,
+        dora_scale: torch.Tensor,
+        loaded_keys: set[str] = None,
+    ) -> "LoHaAdapter" | None:
+        if loaded_keys is None:
+            loaded_keys = set()
+
+        hada_w1_a_name = "{}.hada_w1_a".format(x)
+        hada_w1_b_name = "{}.hada_w1_b".format(x)
+        hada_w2_a_name = "{}.hada_w2_a".format(x)
+        hada_w2_b_name = "{}.hada_w2_b".format(x)
+        hada_t1_name = "{}.hada_t1".format(x)
+        hada_t2_name = "{}.hada_t2".format(x)
+        if hada_w1_a_name in lora.keys():
+            hada_t1 = None
+            hada_t2 = None
+            if hada_t1_name in lora.keys():
+                hada_t1 = lora[hada_t1_name]
+                hada_t2 = lora[hada_t2_name]
+                loaded_keys.add(hada_t1_name)
+                loaded_keys.add(hada_t2_name)
+
+            weights = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
+            loaded_keys.add(hada_w1_a_name)
+            loaded_keys.add(hada_w1_b_name)
+            loaded_keys.add(hada_w2_a_name)
+            loaded_keys.add(hada_w2_b_name)
+            return cls(loaded_keys, weights)
+        else:
+            return None
+
+    def calculate_weight(
+        self,
+        weight,
+        key,
+        strength,
+        strength_model,
+        offset,
+        function,
+        intermediate_dtype=torch.float32,
+        original_weight=None,
+    ):
+        raise NotImplementedError("LoHa calculate_weight is not implemented yet")
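LoHaAdapter.calculate_weight is also still a stub here. For reference, a minimal standalone sketch of the LoHa delta, an element-wise (Hadamard) product of two low-rank products; the optional Tucker factors hada_t1/hada_t2 are omitted, and the shapes, names and alpha/rank scaling are illustrative assumptions.

import torch

# LoHa delta: delta_W = (w1_a @ w1_b) * (w2_a @ w2_b), scaled by alpha / rank.
out_dim, in_dim, rank = 8, 16, 4
weight = torch.randn(out_dim, in_dim)

w1_a, w1_b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
w2_a, w2_b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)

alpha, strength = 1.0, 1.0
scale = alpha / rank

lora_diff = (w1_a @ w1_b) * (w2_a @ w2_b)       # Hadamard product of two rank-r terms
patched = weight + strength * scale * lora_diff
print(patched.shape)                            # torch.Size([8, 16])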
diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py
new file mode 100644
index 00000000..206c80ae
--- /dev/null
+++ b/comfy/weight_adapter/lokr.py
@@ -0,0 +1,89 @@
+import logging
+import torch
+import comfy.utils
+import comfy.model_management
+import comfy.model_base
+from comfy.lora import weight_decompose, pad_tensor_to_shape
+
+from .base import WeightAdapterBase
+
+
+class LoKrAdapter(WeightAdapterBase):
+    name = "lokr"
+
+    def __init__(self, loaded_keys, weights):
+        self.loaded_keys = loaded_keys
+        self.weights = weights
+
+    @classmethod
+    def load(
+        cls,
+        x: str,
+        lora: dict[str, torch.Tensor],
+        alpha: float,
+        dora_scale: torch.Tensor,
+        loaded_keys: set[str] = None,
+    ) -> "LoKrAdapter" | None:
+        if loaded_keys is None:
+            loaded_keys = set()
+        lokr_w1_name = "{}.lokr_w1".format(x)
+        lokr_w2_name = "{}.lokr_w2".format(x)
+        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
+        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
+        lokr_t2_name = "{}.lokr_t2".format(x)
+        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
+        lokr_w2_b_name = "{}.lokr_w2_b".format(x)
+
+        lokr_w1 = None
+        if lokr_w1_name in lora.keys():
+            lokr_w1 = lora[lokr_w1_name]
+            loaded_keys.add(lokr_w1_name)
+
+        lokr_w2 = None
+        if lokr_w2_name in lora.keys():
+            lokr_w2 = lora[lokr_w2_name]
+            loaded_keys.add(lokr_w2_name)
+
+        lokr_w1_a = None
+        if lokr_w1_a_name in lora.keys():
+            lokr_w1_a = lora[lokr_w1_a_name]
+            loaded_keys.add(lokr_w1_a_name)
+
+        lokr_w1_b = None
+        if lokr_w1_b_name in lora.keys():
+            lokr_w1_b = lora[lokr_w1_b_name]
+            loaded_keys.add(lokr_w1_b_name)
+
+        lokr_w2_a = None
+        if lokr_w2_a_name in lora.keys():
+            lokr_w2_a = lora[lokr_w2_a_name]
+            loaded_keys.add(lokr_w2_a_name)
+
+        lokr_w2_b = None
+        if lokr_w2_b_name in lora.keys():
+            lokr_w2_b = lora[lokr_w2_b_name]
+            loaded_keys.add(lokr_w2_b_name)
+
+        lokr_t2 = None
+        if lokr_t2_name in lora.keys():
+            lokr_t2 = lora[lokr_t2_name]
+            loaded_keys.add(lokr_t2_name)
+
+        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
+            weights = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
+            return cls(loaded_keys, weights)
+        else:
+            return None
+
+    def calculate_weight(
+        self,
+        weight,
+        key,
+        strength,
+        strength_model,
+        offset,
+        function,
+        intermediate_dtype=torch.float32,
+        original_weight=None,
+    ):
+        raise NotImplementedError("LoKr calculate_weight is not implemented yet")
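LoKrAdapter.calculate_weight is likewise still a stub. For reference, a minimal standalone sketch of the LoKr delta, a Kronecker product of two factors where either factor may itself be stored as a low-rank pair (the optional lokr_t2 Tucker decomposition is omitted); shapes, names and scaling are illustrative assumptions.

import torch

# LoKr delta: delta_W = kron(w1, w2); here the second factor is itself low-rank.
weight = torch.randn(8, 16)

w1 = torch.randn(2, 4)                                  # small Kronecker factor
w2_a, w2_b = torch.randn(4, 2), torch.randn(2, 4)
w2 = w2_a @ w2_b                                        # low-rank second factor (4 x 4)

alpha, rank, strength = 1.0, 2, 1.0
scale = alpha / rank

lora_diff = torch.kron(w1, w2).reshape(weight.shape)    # (2*4, 4*4) -> (8, 16)
patched = weight + strength * scale * lora_diff
print(patched.shape)                                    # torch.Size([8, 16])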
diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py
new file mode 100644
index 00000000..b79bfd82
--- /dev/null
+++ b/comfy/weight_adapter/lora.py
@@ -0,0 +1,144 @@
+import logging
+import torch
+import comfy.utils
+import comfy.model_management
+import comfy.model_base
+from comfy.lora import weight_decompose, pad_tensor_to_shape
+
+from .base import WeightAdapterBase
+
+
+class LoRAAdapter(WeightAdapterBase):
+    name = "lora"
+
+    def __init__(self, loaded_keys, weights):
+        self.loaded_keys = loaded_keys
+        self.weights = weights
+
+    @classmethod
+    def load(
+        cls,
+        x: str,
+        lora: dict[str, torch.Tensor],
+        alpha: float,
+        dora_scale: torch.Tensor,
+        loaded_keys: set[str] = None,
+    ) -> "LoRAAdapter" | None:
+        if loaded_keys is None:
+            loaded_keys = set()
+
+        reshape_name = "{}.reshape_weight".format(x)
+        regular_lora = "{}.lora_up.weight".format(x)
+        diffusers_lora = "{}_lora.up.weight".format(x)
+        diffusers2_lora = "{}.lora_B.weight".format(x)
+        diffusers3_lora = "{}.lora.up.weight".format(x)
+        mochi_lora = "{}.lora_B".format(x)
+        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
+        A_name = None
+
+        if regular_lora in lora.keys():
+            A_name = regular_lora
+            B_name = "{}.lora_down.weight".format(x)
+            mid_name = "{}.lora_mid.weight".format(x)
+        elif diffusers_lora in lora.keys():
+            A_name = diffusers_lora
+            B_name = "{}_lora.down.weight".format(x)
+            mid_name = None
+        elif diffusers2_lora in lora.keys():
+            A_name = diffusers2_lora
+            B_name = "{}.lora_A.weight".format(x)
+            mid_name = None
+        elif diffusers3_lora in lora.keys():
+            A_name = diffusers3_lora
+            B_name = "{}.lora.down.weight".format(x)
+            mid_name = None
+        elif mochi_lora in lora.keys():
+            A_name = mochi_lora
+            B_name = "{}.lora_A".format(x)
+            mid_name = None
+        elif transformers_lora in lora.keys():
+            A_name = transformers_lora
+            B_name = "{}.lora_linear_layer.down.weight".format(x)
+            mid_name = None
+
+        if A_name is not None:
+            mid = None
+            if mid_name is not None and mid_name in lora.keys():
+                mid = lora[mid_name]
+                loaded_keys.add(mid_name)
+            reshape = None
+            if reshape_name in lora.keys():
+                try:
+                    reshape = lora[reshape_name].tolist()
+                    loaded_keys.add(reshape_name)
+                except Exception:
+                    pass
+            weights = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
+            loaded_keys.add(A_name)
+            loaded_keys.add(B_name)
+            return cls(loaded_keys, weights)
+        else:
+            return None
+
+    def calculate_weight(
+        self,
+        weight,
+        key,
+        strength,
+        strength_model,
+        offset,
+        function,
+        intermediate_dtype=torch.float32,
+        original_weight=None,
+    ):
+        v = self.weights[1]
+        mat1 = comfy.model_management.cast_to_device(
+            v[0], weight.device, intermediate_dtype
+        )
+        mat2 = comfy.model_management.cast_to_device(
+            v[1], weight.device, intermediate_dtype
+        )
+        dora_scale = v[4]
+        reshape = v[5]
+
+        if reshape is not None:
+            weight = pad_tensor_to_shape(weight, reshape)
+
+        if v[2] is not None:
+            alpha = v[2] / mat2.shape[0]
+        else:
+            alpha = 1.0
+
+        if v[3] is not None:
+            # locon mid weights, hopefully the math is fine because I didn't properly test it
+            mat3 = comfy.model_management.cast_to_device(
+                v[3], weight.device, intermediate_dtype
+            )
+            final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+            mat2 = (
+                torch.mm(
+                    mat2.transpose(0, 1).flatten(start_dim=1),
+                    mat3.transpose(0, 1).flatten(start_dim=1),
+                )
+                .reshape(final_shape)
+                .transpose(0, 1)
+            )
+        try:
+            lora_diff = torch.mm(
+                mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
+            ).reshape(weight.shape)
+            if dora_scale is not None:
+                weight = weight_decompose(
+                    dora_scale,
+                    weight,
+                    lora_diff,
+                    alpha,
+                    strength,
+                    intermediate_dtype,
+                    function,
+                )
+            else:
+                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+        except Exception as e:
+            logging.error("ERROR {} {} {}".format(self.name, key, e))
+        return weight
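LoRAAdapter is the one adapter fully implemented in this patch. Below is a minimal usage sketch of load followed by calculate_weight; the key prefix, tensor shapes, strength values and the identity function passed as `function` are illustrative assumptions, and it presumes comfy.weight_adapter.base (not included in this patch) is importable.

import torch
from comfy.weight_adapter.lora import LoRAAdapter

prefix = "diffusion_model.some_layer"                        # hypothetical patch key
state_dict = {
    "{}.lora_up.weight".format(prefix): torch.randn(8, 4),
    "{}.lora_down.weight".format(prefix): torch.randn(4, 16),
}

adapter = LoRAAdapter.load(prefix, state_dict, alpha=1.0, dora_scale=None)
if adapter is not None:
    weight = torch.zeros(8, 16)
    patched = adapter.calculate_weight(
        weight, prefix, strength=1.0, strength_model=1.0,
        offset=None, function=lambda x: x,
    )
    print(patched.shape)                                     # torch.Size([8, 16])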