Weight Adapter Scheme

This commit is contained in:
Kohaku-Blueleaf 2025-04-02 09:21:17 +08:00
parent 50614f1b79
commit 6fb4cc0179
2 changed files with 50 additions and 0 deletions

View File

@ -0,0 +1,13 @@
from .base import WeightAdapterBase
from .lora import LoRAAdapter
from .loha import LoHaAdapter
from .lokr import LoKrAdapter
from .glora import GLoRAAdapter
# Registry of every weight-adapter implementation known to the loader.
# Order is preserved exactly as declared.
adapters: list[type[WeightAdapterBase]] = [LoRAAdapter, LoHaAdapter, LoKrAdapter, GLoRAAdapter]

View File

@ -0,0 +1,37 @@
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class WeightAdapterBase:
    """Abstract base for weight adapters (LoRA-family weight patches).

    Subclasses (e.g. LoRAAdapter, LoHaAdapter) implement loading their
    tensors from a state dict and applying the patch to a model weight.
    Every method here is abstract and raises NotImplementedError.
    """

    # Adapter scheme name (set by each subclass).
    name: str
    # State-dict keys this adapter consumed while loading.
    loaded_keys: set[str]
    # Raw adapter tensors, in subclass-defined order.
    weights: list[torch.Tensor]

    @classmethod
    def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
        """Build an adapter for key prefix ``x`` from state dict ``lora``.

        Presumably returns None when the expected keys are absent — confirm
        against subclass implementations. Abstract here.

        BUGFIX: the original annotation ``"WeightAdapterBase" | None`` is
        evaluated at class-definition time and raises
        ``TypeError: unsupported operand type(s) for |: 'str' and 'NoneType'``
        because PEP 604 ``|`` cannot combine a string forward reference with
        None; ``Optional["WeightAdapterBase"]`` (already imported) is the
        correct runtime-safe spelling.
        """
        raise NotImplementedError

    def to_train(self) -> "WeightAdapterTrainBase":
        """Convert this adapter into its trainable nn.Module counterpart."""
        raise NotImplementedError

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply this adapter's patch to ``weight`` and return the result.

        Abstract; parameter semantics (strength scaling, offset slicing,
        post-processing ``function``) are defined by subclasses — confirm
        against a concrete implementation before relying on them.
        """
        raise NotImplementedError
class WeightAdapterTrainBase(nn.Module):
    """Base nn.Module for the trainable form of a weight adapter.

    Currently a stub: only initializes the nn.Module machinery. Concrete
    training behavior is deferred (see TODO below).
    """
    def __init__(self):
        # Required nn.Module initialization; no adapter state is set up yet.
        super().__init__()
        # [TODO] Collaborate with LoRA training PR #7032