import logging
import torch
import comfy.utils
import comfy.model_management
import comfy.model_base
from comfy.lora import weight_decompose, pad_tensor_to_shape
from .base import WeightAdapterBase


class LoKrAdapter(WeightAdapterBase):
    """Weight adapter for LoKr patches.

    The patch delta is ``kron(w1, w2)`` reshaped to the target weight's shape.
    Each factor is either stored in full (``lokr_w1`` / ``lokr_w2``) or as a
    product of two smaller matrices (``lokr_w*_a`` @ ``lokr_w*_b``); ``w2`` may
    instead be rebuilt from a core tensor ``lokr_t2`` contracted with
    ``lokr_w2_a`` / ``lokr_w2_b``.
    """

    name = "lokr"

    def __init__(self, loaded_keys, weights):
        # loaded_keys: state-dict keys consumed while building this adapter.
        # weights: ("lokr", (w1, w2, alpha, w1_a, w1_b, w2_a, w2_b, t2, dora_scale)),
        #          the tuple assembled by ``load``.
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: "set[str] | None" = None,
    ) -> "LoKrAdapter | None":
        """Build a LoKr adapter for prefix ``x`` from the state dict ``lora``.

        Every ``"<x>.lokr_*"`` key found in ``lora`` is added to ``loaded_keys``
        (the caller's set is updated in place when one is passed).

        Args:
            x: key prefix of the layer being patched.
            lora: mapping of state-dict key -> tensor.
            alpha: alpha value stored alongside the factors.
            dora_scale: DoRA magnitude tensor, or None when DoRA is not used.
            loaded_keys: set that collects consumed keys; a new one is
                created when omitted.

        Returns:
            A ``LoKrAdapter``, or None when none of ``lokr_w1``, ``lokr_w2``,
            ``lokr_w1_a`` or ``lokr_w2_a`` is present for ``x``.
        """
        if loaded_keys is None:
            loaded_keys = set()

        def take(suffix):
            # Return lora["<x>.<suffix>"] (recording the key as consumed), or None.
            full_key = "{}.{}".format(x, suffix)
            if full_key not in lora:
                return None
            loaded_keys.add(full_key)
            return lora[full_key]

        lokr_w1 = take("lokr_w1")
        lokr_w2 = take("lokr_w2")
        lokr_w1_a = take("lokr_w1_a")
        lokr_w1_b = take("lokr_w1_b")
        lokr_w2_a = take("lokr_w2_a")
        lokr_w2_b = take("lokr_w2_b")
        lokr_t2 = take("lokr_t2")

        # Only these four tensors mark the prefix as a LoKr patch.
        if all(t is None for t in (lokr_w1, lokr_w2, lokr_w1_a, lokr_w2_a)):
            return None

        weights = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply this LoKr patch to ``weight`` and return the patched tensor.

        Args:
            weight: tensor being patched; updated in place on the non-DoRA path.
            key: patch key, used only in the error log message.
            strength: multiplier applied to the LoKr delta.
            strength_model, offset, original_weight: accepted for interface
                compatibility; not used by this adapter.
            function: callable applied to the scaled delta before it is added
                (forwarded to ``weight_decompose`` on the DoRA path).
            intermediate_dtype: dtype the factors are cast to for the math.

        Returns:
            The patched weight. If building the Kronecker delta or applying it
            raises, the error is logged and ``weight`` is returned as-is.
        """
        _, (w1, w2, alpha, w1_a, w1_b, w2_a, w2_b, t2, dora_scale) = self.weights
        device = weight.device

        def cast(t):
            # Move a stored factor to the weight's device in the working dtype.
            return comfy.model_management.cast_to_device(t, device, intermediate_dtype)

        # Rank of the factored form, used to normalise alpha; stays None when
        # both factors are stored in full.
        dim = None

        if w1 is None:
            dim = w1_b.shape[0]
            w1 = torch.mm(cast(w1_a), cast(w1_b))
        else:
            w1 = cast(w1)

        if w2 is None:
            dim = w2_b.shape[0]
            if t2 is None:
                w2 = torch.mm(cast(w2_a), cast(w2_b))
            else:
                # Contract the 4-D core t2 with both factor matrices to rebuild w2.
                w2 = torch.einsum("i j k l, j r, i p -> p r k l", cast(t2), cast(w2_b), cast(w2_a))
        else:
            w2 = cast(w2)

        if len(w2.shape) == 4:
            # Give w1 two trailing singleton dims so kron produces a 4-D kernel.
            w1 = w1.unsqueeze(2).unsqueeze(2)

        scale = alpha / dim if (alpha is not None and dim is not None) else 1.0

        try:
            lora_diff = torch.kron(w1, w2).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, scale, strength, intermediate_dtype, function)
            else:
                weight += function(((strength * scale) * lora_diff).type(weight.dtype))
        except Exception as e:
            # Best-effort: a malformed patch is logged and skipped rather than
            # aborting the whole model patching pass.
            logging.error("ERROR %s %s %s", self.name, key, e)
        return weight