From e8f3bc5ab766ba0149ac8ba8f67c328f8bb3da7d Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 9 Apr 2025 09:16:52 +0800 Subject: [PATCH] Finalize the modularized weight adapter impl * LoRA/LoHa/LoKr/GLoRA working well * Removed TONS of code in lora.py --- comfy/lora.py | 180 +--------------------------------- comfy/weight_adapter/glora.py | 49 ++++++++- comfy/weight_adapter/loha.py | 45 ++++++++- comfy/weight_adapter/lokr.py | 54 +++++++++- 4 files changed, 143 insertions(+), 185 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index ab053e7d3..8760a21fb 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -282,26 +282,6 @@ def model_lora_keys_unet(model, key_map={}): return key_map -def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): - dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) - lora_diff *= alpha - weight_calc = weight + function(lora_diff).type(weight.dtype) - weight_norm = ( - weight_calc.transpose(0, 1) - .reshape(weight_calc.shape[1], -1) - .norm(dim=1, keepdim=True) - .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) - .transpose(0, 1) - ) - - weight_calc *= (dora_scale / weight_norm).type(weight.dtype) - if strength != 1.0: - weight_calc -= weight - weight += strength * (weight_calc) - else: - weight[:] = weight_calc - return weight - def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor: """ Pad a tensor to a new shape with zeros. @@ -358,14 +338,13 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori if isinstance(v, weight_adapter.WeightAdapterBase): output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights) - if output is not None: + if output is None: + logging.warning("Calculate Weight Failed: {} {}".format(v.name, key)) + else: weight = output if old_weight is not None: weight = old_weight - continue - else: - #Fallback when calculate_weight haven't implemented - v = v.weights + continue if len(v) == 1: patch_type = "diff" @@ -393,157 +372,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \ comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype) weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype)) - elif patch_type == "lora": #lora/locon - mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype) - mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype) - dora_scale = v[4] - reshape = v[5] - - if reshape is not None: - weight = pad_tensor_to_shape(weight, reshape) - - if v[2] is not None: - alpha = v[2] / mat2.shape[0] - else: - alpha = 1.0 - - if v[3] is not None: - #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype) - final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] - mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) - try: - lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "lokr": - w1 = v[0] - w2 = v[1] - w1_a = v[3] - w1_b = v[4] - w2_a = v[5] - w2_b = v[6] - t2 = v[7] - dora_scale = v[8] - dim = None - - if w1 is None: - dim = w1_b.shape[0] - w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) - else: - w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) - - if w2 is None: - dim = w2_b.shape[0] - if t2 is None: - w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) - else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) - else: - w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) - - if len(w2.shape) == 4: - w1 = w1.unsqueeze(2).unsqueeze(2) - if v[2] is not None and dim is not None: - alpha = v[2] / dim - else: - alpha = 1.0 - - try: - lora_diff = torch.kron(w1, w2).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "loha": - w1a = v[0] - w1b = v[1] - if v[2] is not None: - alpha = v[2] / w1b.shape[0] - else: - alpha = 1.0 - - w2a = v[3] - w2b = v[4] - dora_scale = v[7] - if v[5] is not None: #cp decomposition - t1 = v[5] - t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) - - m2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) - else: - m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) - m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) - - try: - lora_diff = (m1 * m2).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "glora": - dora_scale = v[5] - - old_glora = False - if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]: - rank = v[0].shape[0] - old_glora = True - - if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: - if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: - pass - else: - old_glora = False - rank = v[1].shape[0] - - a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) - a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) - b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) - b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) - - if v[4] is not None: - alpha = v[4] / rank - else: - alpha = 1.0 - - try: - if old_glora: - lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora - else: - if weight.dim() > 2: - lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) - else: - lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) - lora_diff += torch.mm(b1, b2).reshape(weight.shape) - - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) else: logging.warning("patch type not recognized {} {}".format(patch_type, key)) diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py index 206fc6e28..939abbba5 100644 --- a/comfy/weight_adapter/glora.py +++ b/comfy/weight_adapter/glora.py @@ -1,7 +1,9 @@ +import logging from typing import Optional import torch -from .base import WeightAdapterBase +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose class GLoRAAdapter(WeightAdapterBase): @@ -27,7 +29,7 @@ class GLoRAAdapter(WeightAdapterBase): b1_name = "{}.b1.weight".format(x) b2_name = "{}.b2.weight".format(x) if a1_name in lora: - weights = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) + weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale) loaded_keys.add(a1_name) loaded_keys.add(a2_name) loaded_keys.add(b1_name) @@ -47,4 +49,45 @@ class GLoRAAdapter(WeightAdapterBase): intermediate_dtype=torch.float32, original_weight=None, ): - pass + v = self.weights + dora_scale = v[5] + + old_glora = False + if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]: + rank = v[0].shape[0] + old_glora = True + + if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: + if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: + pass + else: + old_glora = False + rank = v[1].shape[0] + + a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) + a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) + b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) + b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) + + if v[4] is not None: + alpha = v[4] / rank + else: + alpha = 1.0 + + try: + if old_glora: + lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora + else: + if weight.dim() > 2: + lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + else: + lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff += torch.mm(b1, b2).reshape(weight.shape) + + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index 7a6cab243..ce79abad5 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -1,7 +1,9 @@ +import logging from typing import Optional import torch -from .base import WeightAdapterBase +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose class LoHaAdapter(WeightAdapterBase): @@ -38,7 +40,7 @@ class LoHaAdapter(WeightAdapterBase): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - weights = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) + weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -58,4 +60,41 @@ class LoHaAdapter(WeightAdapterBase): intermediate_dtype=torch.float32, original_weight=None, ): - pass + v = self.weights + w1a = v[0] + w1b = v[1] + if v[2] is not None: + alpha = v[2] / w1b.shape[0] + else: + alpha = 1.0 + + w2a = v[3] + w2b = v[4] + dora_scale = v[7] + if v[5] is not None: #cp decomposition + t1 = v[5] + t2 = v[6] + m1 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) + + m2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) + else: + m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) + m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) + + try: + lora_diff = (m1 * m2).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index da1b15519..51233db2d 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -1,7 +1,9 @@ +import logging from typing import Optional import torch -from .base import WeightAdapterBase +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose class LoKrAdapter(WeightAdapterBase): @@ -66,7 +68,7 @@ class LoKrAdapter(WeightAdapterBase): loaded_keys.add(lokr_t2_name) if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - weights = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) + weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) return cls(loaded_keys, weights) else: return None @@ -82,4 +84,50 @@ class LoKrAdapter(WeightAdapterBase): intermediate_dtype=torch.float32, original_weight=None, ): - pass + v = self.weights + w1 = v[0] + w2 = v[1] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + dora_scale = v[8] + dim = None + + if w1 is None: + dim = w1_b.shape[0] + w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) + else: + w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) + + if w2 is None: + dim = w2_b.shape[0] + if t2 is None: + w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) + else: + w2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) + else: + w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + if v[2] is not None and dim is not None: + alpha = v[2] / dim + else: + alpha = 1.0 + + try: + lora_diff = torch.kron(w1, w2).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight