ComfyUI/comfy/weight_adapter/base.py
Kohaku-Blueleaf 84317474fd lint
2025-04-02 09:31:24 +08:00

36 lines
768 B
Python

import torch
import torch.nn as nn
class WeightAdapterBase:
    """Interface for weight adapters.

    Every method here raises ``NotImplementedError``; concrete adapters are
    expected to override them.
    """

    # Declared attributes (annotation only, no defaults are assigned here).
    # NOTE(review): presumably `name` identifies the adapter type,
    # `loaded_keys` records the state-dict keys consumed by `load`, and
    # `weights` holds the adapter tensors -- confirm against subclasses.
    name: str
    loaded_keys: set[str]
    weights: list[torch.Tensor]

    @classmethod
    def load(
        cls, x: str, lora: dict[str, torch.Tensor]
    ) -> "WeightAdapterBase | None":
        """Build an adapter from a mapping of named tensors.

        The return annotation is a single quoted forward reference:
        ``"WeightAdapterBase" | None`` would apply ``|`` to a ``str`` at
        definition time and raise ``TypeError`` on import.

        Args:
            x: String argument; presumably the key/prefix to look up in
                ``lora`` -- TODO confirm against subclasses.
            lora: Mapping from names to tensors.

        Returns:
            An adapter instance, or ``None``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def to_train(self) -> "WeightAdapterTrainBase":
        """Return the trainable counterpart of this adapter.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Compute the adapted weight; must be overridden by subclasses.

        NOTE(review): parameter semantics are not visible in this file.
        Presumably ``weight`` is the tensor being patched and ``strength``
        scales the adapter's contribution -- confirm against the
        implementations and the caller.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class WeightAdapterTrainBase(nn.Module):
    """Base ``nn.Module`` for the trainable form of a weight adapter."""

    def __init__(self) -> None:
        """Initialise the underlying ``nn.Module``; no other state is set."""
        super().__init__()

    # [TODO] Collaborate with LoRA training PR #7032