diff --git a/comfy/lora.py b/comfy/lora.py index bc9f3022a..8760a21fb 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -20,6 +20,7 @@ from __future__ import annotations import comfy.utils import comfy.model_management import comfy.model_base +import comfy.weight_adapter as weight_adapter import logging import torch @@ -49,139 +50,12 @@ def load_lora(lora, to_load, log_missing=True): dora_scale = lora[dora_scale_name] loaded_keys.add(dora_scale_name) - reshape_name = "{}.reshape_weight".format(x) - reshape = None - if reshape_name in lora.keys(): - try: - reshape = lora[reshape_name].tolist() - loaded_keys.add(reshape_name) - except: - pass - - regular_lora = "{}.lora_up.weight".format(x) - diffusers_lora = "{}_lora.up.weight".format(x) - diffusers2_lora = "{}.lora_B.weight".format(x) - diffusers3_lora = "{}.lora.up.weight".format(x) - mochi_lora = "{}.lora_B".format(x) - transformers_lora = "{}.lora_linear_layer.up.weight".format(x) - A_name = None - - if regular_lora in lora.keys(): - A_name = regular_lora - B_name = "{}.lora_down.weight".format(x) - mid_name = "{}.lora_mid.weight".format(x) - elif diffusers_lora in lora.keys(): - A_name = diffusers_lora - B_name = "{}_lora.down.weight".format(x) - mid_name = None - elif diffusers2_lora in lora.keys(): - A_name = diffusers2_lora - B_name = "{}.lora_A.weight".format(x) - mid_name = None - elif diffusers3_lora in lora.keys(): - A_name = diffusers3_lora - B_name = "{}.lora.down.weight".format(x) - mid_name = None - elif mochi_lora in lora.keys(): - A_name = mochi_lora - B_name = "{}.lora_A".format(x) - mid_name = None - elif transformers_lora in lora.keys(): - A_name = transformers_lora - B_name ="{}.lora_linear_layer.down.weight".format(x) - mid_name = None - - if A_name is not None: - mid = None - if mid_name is not None and mid_name in lora.keys(): - mid = lora[mid_name] - loaded_keys.add(mid_name) - patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)) - loaded_keys.add(A_name) - loaded_keys.add(B_name) - - - ######## loha - hada_w1_a_name = "{}.hada_w1_a".format(x) - hada_w1_b_name = "{}.hada_w1_b".format(x) - hada_w2_a_name = "{}.hada_w2_a".format(x) - hada_w2_b_name = "{}.hada_w2_b".format(x) - hada_t1_name = "{}.hada_t1".format(x) - hada_t2_name = "{}.hada_t2".format(x) - if hada_w1_a_name in lora.keys(): - hada_t1 = None - hada_t2 = None - if hada_t1_name in lora.keys(): - hada_t1 = lora[hada_t1_name] - hada_t2 = lora[hada_t2_name] - loaded_keys.add(hada_t1_name) - loaded_keys.add(hada_t2_name) - - patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) - loaded_keys.add(hada_w1_a_name) - loaded_keys.add(hada_w1_b_name) - loaded_keys.add(hada_w2_a_name) - loaded_keys.add(hada_w2_b_name) - - - ######## lokr - lokr_w1_name = "{}.lokr_w1".format(x) - lokr_w2_name = "{}.lokr_w2".format(x) - lokr_w1_a_name = "{}.lokr_w1_a".format(x) - lokr_w1_b_name = "{}.lokr_w1_b".format(x) - lokr_t2_name = "{}.lokr_t2".format(x) - lokr_w2_a_name = "{}.lokr_w2_a".format(x) - lokr_w2_b_name = "{}.lokr_w2_b".format(x) - - lokr_w1 = None - if lokr_w1_name in lora.keys(): - lokr_w1 = lora[lokr_w1_name] - loaded_keys.add(lokr_w1_name) - - lokr_w2 = None - if lokr_w2_name in lora.keys(): - lokr_w2 = lora[lokr_w2_name] - loaded_keys.add(lokr_w2_name) - - lokr_w1_a = None - if lokr_w1_a_name in lora.keys(): - lokr_w1_a = lora[lokr_w1_a_name] - loaded_keys.add(lokr_w1_a_name) - - lokr_w1_b = None - if lokr_w1_b_name in lora.keys(): - lokr_w1_b = lora[lokr_w1_b_name] - loaded_keys.add(lokr_w1_b_name) - - lokr_w2_a = None - if lokr_w2_a_name in lora.keys(): - lokr_w2_a = lora[lokr_w2_a_name] - loaded_keys.add(lokr_w2_a_name) - - lokr_w2_b = None - if lokr_w2_b_name in lora.keys(): - lokr_w2_b = lora[lokr_w2_b_name] - loaded_keys.add(lokr_w2_b_name) - - lokr_t2 = None - if lokr_t2_name in lora.keys(): - lokr_t2 = lora[lokr_t2_name] - loaded_keys.add(lokr_t2_name) - - if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) - - #glora - a1_name = "{}.a1.weight".format(x) - a2_name = "{}.a2.weight".format(x) - b1_name = "{}.b1.weight".format(x) - b2_name = "{}.b2.weight".format(x) - if a1_name in lora: - patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) - loaded_keys.add(a1_name) - loaded_keys.add(a2_name) - loaded_keys.add(b1_name) - loaded_keys.add(b2_name) + for adapter_cls in weight_adapter.adapters: + adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys) + if adapter is not None: + patch_dict[to_load[x]] = adapter + loaded_keys.update(adapter.loaded_keys) + continue w_norm_name = "{}.w_norm".format(x) b_norm_name = "{}.b_norm".format(x) @@ -408,26 +282,6 @@ def model_lora_keys_unet(model, key_map={}): return key_map -def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): - dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) - lora_diff *= alpha - weight_calc = weight + function(lora_diff).type(weight.dtype) - weight_norm = ( - weight_calc.transpose(0, 1) - .reshape(weight_calc.shape[1], -1) - .norm(dim=1, keepdim=True) - .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) - .transpose(0, 1) - ) - - weight_calc *= (dora_scale / weight_norm).type(weight.dtype) - if strength != 1.0: - weight_calc -= weight - weight += strength * (weight_calc) - else: - weight[:] = weight_calc - return weight - def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor: """ Pad a tensor to a new shape with zeros. @@ -482,6 +336,16 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori if isinstance(v, list): v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), ) + if isinstance(v, weight_adapter.WeightAdapterBase): + output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights) + if output is None: + logging.warning("Calculate Weight Failed: {} {}".format(v.name, key)) + else: + weight = output + if old_weight is not None: + weight = old_weight + continue + if len(v) == 1: patch_type = "diff" elif len(v) == 2: @@ -508,157 +372,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \ comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype) weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype)) - elif patch_type == "lora": #lora/locon - mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype) - mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype) - dora_scale = v[4] - reshape = v[5] - - if reshape is not None: - weight = pad_tensor_to_shape(weight, reshape) - - if v[2] is not None: - alpha = v[2] / mat2.shape[0] - else: - alpha = 1.0 - - if v[3] is not None: - #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype) - final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] - mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) - try: - lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "lokr": - w1 = v[0] - w2 = v[1] - w1_a = v[3] - w1_b = v[4] - w2_a = v[5] - w2_b = v[6] - t2 = v[7] - dora_scale = v[8] - dim = None - - if w1 is None: - dim = w1_b.shape[0] - w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) - else: - w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) - - if w2 is None: - dim = w2_b.shape[0] - if t2 is None: - w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) - else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) - else: - w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) - - if len(w2.shape) == 4: - w1 = w1.unsqueeze(2).unsqueeze(2) - if v[2] is not None and dim is not None: - alpha = v[2] / dim - else: - alpha = 1.0 - - try: - lora_diff = torch.kron(w1, w2).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "loha": - w1a = v[0] - w1b = v[1] - if v[2] is not None: - alpha = v[2] / w1b.shape[0] - else: - alpha = 1.0 - - w2a = v[3] - w2b = v[4] - dora_scale = v[7] - if v[5] is not None: #cp decomposition - t1 = v[5] - t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) - - m2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) - else: - m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) - m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) - - try: - lora_diff = (m1 * m2).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "glora": - dora_scale = v[5] - - old_glora = False - if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]: - rank = v[0].shape[0] - old_glora = True - - if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: - if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: - pass - else: - old_glora = False - rank = v[1].shape[0] - - a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) - a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) - b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) - b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) - - if v[4] is not None: - alpha = v[4] / rank - else: - alpha = 1.0 - - try: - if old_glora: - lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora - else: - if weight.dim() > 2: - lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) - else: - lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) - lora_diff += torch.mm(b1, b2).reshape(weight.shape) - - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) else: logging.warning("patch type not recognized {} {}".format(patch_type, key)) diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py new file mode 100644 index 000000000..e6cd805b6 --- /dev/null +++ b/comfy/weight_adapter/__init__.py @@ -0,0 +1,13 @@ +from .base import WeightAdapterBase +from .lora import LoRAAdapter +from .loha import LoHaAdapter +from .lokr import LoKrAdapter +from .glora import GLoRAAdapter + + +adapters: list[type[WeightAdapterBase]] = [ + LoRAAdapter, + LoHaAdapter, + LoKrAdapter, + GLoRAAdapter, +] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py new file mode 100644 index 000000000..54af3babe --- /dev/null +++ b/comfy/weight_adapter/base.py @@ -0,0 +1,94 @@ +from typing import Optional + +import torch +import torch.nn as nn + +import comfy.model_management + + +class WeightAdapterBase: + name: str + loaded_keys: set[str] + weights: list[torch.Tensor] + + @classmethod + def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]: + raise NotImplementedError + + def to_train(self) -> "WeightAdapterTrainBase": + raise NotImplementedError + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + raise NotImplementedError + + +class WeightAdapterTrainBase(nn.Module): + def __init__(self): + super().__init__() + + # [TODO] Collaborate with LoRA training PR #7032 + + +def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): + dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) + lora_diff *= alpha + weight_calc = weight + function(lora_diff).type(weight.dtype) + weight_norm = ( + weight_calc.transpose(0, 1) + .reshape(weight_calc.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) + .transpose(0, 1) + ) + + weight_calc *= (dora_scale / weight_norm).type(weight.dtype) + if strength != 1.0: + weight_calc -= weight + weight += strength * (weight_calc) + else: + weight[:] = weight_calc + return weight + + +def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor: + """ + Pad a tensor to a new shape with zeros. + + Args: + tensor (torch.Tensor): The original tensor to be padded. + new_shape (List[int]): The desired shape of the padded tensor. + + Returns: + torch.Tensor: A new tensor padded with zeros to the specified shape. + + Note: + If the new shape is smaller than the original tensor in any dimension, + the original tensor will be truncated in that dimension. + """ + if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]): + raise ValueError("The new shape must be larger than the original tensor in all dimensions") + + if len(new_shape) != len(tensor.shape): + raise ValueError("The new shape must have the same number of dimensions as the original tensor") + + # Create a new tensor filled with zeros + padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) + + # Create slicing tuples for both tensors + orig_slices = tuple(slice(0, dim) for dim in tensor.shape) + new_slices = tuple(slice(0, dim) for dim in tensor.shape) + + # Copy the original tensor into the new tensor + padded_tensor[new_slices] = tensor[orig_slices] + + return padded_tensor diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py new file mode 100644 index 000000000..939abbba5 --- /dev/null +++ b/comfy/weight_adapter/glora.py @@ -0,0 +1,93 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose + + +class GLoRAAdapter(WeightAdapterBase): + name = "glora" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["GLoRAAdapter"]: + if loaded_keys is None: + loaded_keys = set() + a1_name = "{}.a1.weight".format(x) + a2_name = "{}.a2.weight".format(x) + b1_name = "{}.b1.weight".format(x) + b2_name = "{}.b2.weight".format(x) + if a1_name in lora: + weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale) + loaded_keys.add(a1_name) + loaded_keys.add(a2_name) + loaded_keys.add(b1_name) + loaded_keys.add(b2_name) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + dora_scale = v[5] + + old_glora = False + if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]: + rank = v[0].shape[0] + old_glora = True + + if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: + if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: + pass + else: + old_glora = False + rank = v[1].shape[0] + + a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) + a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) + b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) + b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) + + if v[4] is not None: + alpha = v[4] / rank + else: + alpha = 1.0 + + try: + if old_glora: + lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora + else: + if weight.dim() > 2: + lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + else: + lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff += torch.mm(b1, b2).reshape(weight.shape) + + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py new file mode 100644 index 000000000..ce79abad5 --- /dev/null +++ b/comfy/weight_adapter/loha.py @@ -0,0 +1,100 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose + + +class LoHaAdapter(WeightAdapterBase): + name = "loha" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["LoHaAdapter"]: + if loaded_keys is None: + loaded_keys = set() + + hada_w1_a_name = "{}.hada_w1_a".format(x) + hada_w1_b_name = "{}.hada_w1_b".format(x) + hada_w2_a_name = "{}.hada_w2_a".format(x) + hada_w2_b_name = "{}.hada_w2_b".format(x) + hada_t1_name = "{}.hada_t1".format(x) + hada_t2_name = "{}.hada_t2".format(x) + if hada_w1_a_name in lora.keys(): + hada_t1 = None + hada_t2 = None + if hada_t1_name in lora.keys(): + hada_t1 = lora[hada_t1_name] + hada_t2 = lora[hada_t2_name] + loaded_keys.add(hada_t1_name) + loaded_keys.add(hada_t2_name) + + weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale) + loaded_keys.add(hada_w1_a_name) + loaded_keys.add(hada_w1_b_name) + loaded_keys.add(hada_w2_a_name) + loaded_keys.add(hada_w2_b_name) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + w1a = v[0] + w1b = v[1] + if v[2] is not None: + alpha = v[2] / w1b.shape[0] + else: + alpha = 1.0 + + w2a = v[3] + w2b = v[4] + dora_scale = v[7] + if v[5] is not None: #cp decomposition + t1 = v[5] + t2 = v[6] + m1 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) + + m2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) + else: + m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) + m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) + + try: + lora_diff = (m1 * m2).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py new file mode 100644 index 000000000..51233db2d --- /dev/null +++ b/comfy/weight_adapter/lokr.py @@ -0,0 +1,133 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose + + +class LoKrAdapter(WeightAdapterBase): + name = "lokr" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["LoKrAdapter"]: + if loaded_keys is None: + loaded_keys = set() + lokr_w1_name = "{}.lokr_w1".format(x) + lokr_w2_name = "{}.lokr_w2".format(x) + lokr_w1_a_name = "{}.lokr_w1_a".format(x) + lokr_w1_b_name = "{}.lokr_w1_b".format(x) + lokr_t2_name = "{}.lokr_t2".format(x) + lokr_w2_a_name = "{}.lokr_w2_a".format(x) + lokr_w2_b_name = "{}.lokr_w2_b".format(x) + + lokr_w1 = None + if lokr_w1_name in lora.keys(): + lokr_w1 = lora[lokr_w1_name] + loaded_keys.add(lokr_w1_name) + + lokr_w2 = None + if lokr_w2_name in lora.keys(): + lokr_w2 = lora[lokr_w2_name] + loaded_keys.add(lokr_w2_name) + + lokr_w1_a = None + if lokr_w1_a_name in lora.keys(): + lokr_w1_a = lora[lokr_w1_a_name] + loaded_keys.add(lokr_w1_a_name) + + lokr_w1_b = None + if lokr_w1_b_name in lora.keys(): + lokr_w1_b = lora[lokr_w1_b_name] + loaded_keys.add(lokr_w1_b_name) + + lokr_w2_a = None + if lokr_w2_a_name in lora.keys(): + lokr_w2_a = lora[lokr_w2_a_name] + loaded_keys.add(lokr_w2_a_name) + + lokr_w2_b = None + if lokr_w2_b_name in lora.keys(): + lokr_w2_b = lora[lokr_w2_b_name] + loaded_keys.add(lokr_w2_b_name) + + lokr_t2 = None + if lokr_t2_name in lora.keys(): + lokr_t2 = lora[lokr_t2_name] + loaded_keys.add(lokr_t2_name) + + if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): + weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + w1 = v[0] + w2 = v[1] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + dora_scale = v[8] + dim = None + + if w1 is None: + dim = w1_b.shape[0] + w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) + else: + w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) + + if w2 is None: + dim = w2_b.shape[0] + if t2 is None: + w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) + else: + w2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) + else: + w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + if v[2] is not None and dim is not None: + alpha = v[2] / dim + else: + alpha = 1.0 + + try: + lora_diff = torch.kron(w1, w2).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py new file mode 100644 index 000000000..b2e623924 --- /dev/null +++ b/comfy/weight_adapter/lora.py @@ -0,0 +1,142 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape + + +class LoRAAdapter(WeightAdapterBase): + name = "lora" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["LoRAAdapter"]: + if loaded_keys is None: + loaded_keys = set() + + reshape_name = "{}.reshape_weight".format(x) + regular_lora = "{}.lora_up.weight".format(x) + diffusers_lora = "{}_lora.up.weight".format(x) + diffusers2_lora = "{}.lora_B.weight".format(x) + diffusers3_lora = "{}.lora.up.weight".format(x) + mochi_lora = "{}.lora_B".format(x) + transformers_lora = "{}.lora_linear_layer.up.weight".format(x) + A_name = None + + if regular_lora in lora.keys(): + A_name = regular_lora + B_name = "{}.lora_down.weight".format(x) + mid_name = "{}.lora_mid.weight".format(x) + elif diffusers_lora in lora.keys(): + A_name = diffusers_lora + B_name = "{}_lora.down.weight".format(x) + mid_name = None + elif diffusers2_lora in lora.keys(): + A_name = diffusers2_lora + B_name = "{}.lora_A.weight".format(x) + mid_name = None + elif diffusers3_lora in lora.keys(): + A_name = diffusers3_lora + B_name = "{}.lora.down.weight".format(x) + mid_name = None + elif mochi_lora in lora.keys(): + A_name = mochi_lora + B_name = "{}.lora_A".format(x) + mid_name = None + elif transformers_lora in lora.keys(): + A_name = transformers_lora + B_name = "{}.lora_linear_layer.down.weight".format(x) + mid_name = None + + if A_name is not None: + mid = None + if mid_name is not None and mid_name in lora.keys(): + mid = lora[mid_name] + loaded_keys.add(mid_name) + reshape = None + if reshape_name in lora.keys(): + try: + reshape = lora[reshape_name].tolist() + loaded_keys.add(reshape_name) + except: + pass + weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape) + loaded_keys.add(A_name) + loaded_keys.add(B_name) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + mat1 = comfy.model_management.cast_to_device( + v[0], weight.device, intermediate_dtype + ) + mat2 = comfy.model_management.cast_to_device( + v[1], weight.device, intermediate_dtype + ) + dora_scale = v[4] + reshape = v[5] + + if reshape is not None: + weight = pad_tensor_to_shape(weight, reshape) + + if v[2] is not None: + alpha = v[2] / mat2.shape[0] + else: + alpha = 1.0 + + if v[3] is not None: + # locon mid weights, hopefully the math is fine because I didn't properly test it + mat3 = comfy.model_management.cast_to_device( + v[3], weight.device, intermediate_dtype + ) + final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] + mat2 = ( + torch.mm( + mat2.transpose(0, 1).flatten(start_dim=1), + mat3.transpose(0, 1).flatten(start_dim=1), + ) + .reshape(final_shape) + .transpose(0, 1) + ) + try: + lora_diff = torch.mm( + mat1.flatten(start_dim=1), mat2.flatten(start_dim=1) + ).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight