From 6fb4cc0179805245f136fe5d44bf8cf6de2c83bb Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 2 Apr 2025 09:21:17 +0800 Subject: [PATCH] Weight Adapter Scheme --- comfy/weight_adapter/__init__.py | 13 +++++++++++ comfy/weight_adapter/base.py | 37 ++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 comfy/weight_adapter/__init__.py create mode 100644 comfy/weight_adapter/base.py diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py new file mode 100644 index 00000000..a021cfd0 --- /dev/null +++ b/comfy/weight_adapter/__init__.py @@ -0,0 +1,13 @@ +from .base import WeightAdapterBase +from .lora import LoRAAdapter +from .loha import LoHaAdapter +from .lokr import LoKrAdapter +from .glora import GLoRAAdapter + + +adapters: list[type[WeightAdapterBase]] = [ + LoRAAdapter, + LoHaAdapter, + LoKrAdapter, + GLoRAAdapter, +] \ No newline at end of file diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py new file mode 100644 index 00000000..2ba2850e --- /dev/null +++ b/comfy/weight_adapter/base.py @@ -0,0 +1,37 @@ +from typing import Optional +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class WeightAdapterBase: + name: str + loaded_keys: set[str] + weights: list[torch.Tensor] + + @classmethod + def load(cls, x: str, lora: dict[str, torch.Tensor]) -> "WeightAdapterBase" | None: + raise NotImplementedError + + def to_train(self) -> "WeightAdapterTrainBase": + raise NotImplementedError + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + raise NotImplementedError + + +class WeightAdapterTrainBase(nn.Module): + def __init__(self): + super().__init__() + + # [TODO] Collaborate with LoRA training PR #7032