mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-15 16:13:29 +00:00
Weight Adapter Scheme
This commit is contained in:
parent
50614f1b79
commit
6fb4cc0179
13
comfy/weight_adapter/__init__.py
Normal file
13
comfy/weight_adapter/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
from .base import WeightAdapterBase
|
||||
from .lora import LoRAAdapter
|
||||
from .loha import LoHaAdapter
|
||||
from .lokr import LoKrAdapter
|
||||
from .glora import GLoRAAdapter
|
||||
|
||||
|
||||
adapters: list[type[WeightAdapterBase]] = [
|
||||
LoRAAdapter,
|
||||
LoHaAdapter,
|
||||
LoKrAdapter,
|
||||
GLoRAAdapter,
|
||||
]
|
37
comfy/weight_adapter/base.py
Normal file
37
comfy/weight_adapter/base.py
Normal file
@ -0,0 +1,37 @@
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class WeightAdapterBase:
|
||||
name: str
|
||||
loaded_keys: set[str]
|
||||
weights: list[torch.Tensor]
|
||||
|
||||
@classmethod
|
||||
def load(cls, x: str, lora: dict[str, torch.Tensor]) -> "WeightAdapterBase" | None:
|
||||
raise NotImplementedError
|
||||
|
||||
def to_train(self) -> "WeightAdapterTrainBase":
|
||||
raise NotImplementedError
|
||||
|
||||
def calculate_weight(
|
||||
self,
|
||||
weight,
|
||||
key,
|
||||
strength,
|
||||
strength_model,
|
||||
offset,
|
||||
function,
|
||||
intermediate_dtype=torch.float32,
|
||||
original_weight=None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WeightAdapterTrainBase(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# [TODO] Collaborate with LoRA training PR #7032
|
Loading…
Reference in New Issue
Block a user