From cb63e230b41193601e48778111eff045391cfbe2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 9 Dec 2023 14:15:09 -0500 Subject: [PATCH] Make lora code a bit cleaner. --- comfy/lora.py | 14 +++++++------- comfy/model_patcher.py | 14 +++++++++++--- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index 29c59d893..ecd518084 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -43,7 +43,7 @@ def load_lora(lora, to_load): if mid_name is not None and mid_name in lora.keys(): mid = lora[mid_name] loaded_keys.add(mid_name) - patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) + patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid)) loaded_keys.add(A_name) loaded_keys.add(B_name) @@ -64,7 +64,7 @@ def load_lora(lora, to_load): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2) + patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -116,7 +116,7 @@ def load_lora(lora, to_load): loaded_keys.add(lokr_t2_name) if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) + patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)) w_norm_name = "{}.w_norm".format(x) @@ -126,21 +126,21 @@ def load_lora(lora, to_load): if w_norm is not None: loaded_keys.add(w_norm_name) - patch_dict[to_load[x]] = (w_norm,) + patch_dict[to_load[x]] = ("diff", (w_norm,)) if b_norm is not None: loaded_keys.add(b_norm_name) - patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,) + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,)) diff_name = "{}.diff".format(x) diff_weight = lora.get(diff_name, None) if diff_weight is not None: - patch_dict[to_load[x]] = (diff_weight,) + patch_dict[to_load[x]] = ("diff", (diff_weight,)) loaded_keys.add(diff_name) diff_bias_name = "{}.diff_b".format(x) diff_bias = lora.get(diff_bias_name, None) if diff_bias is not None: - patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,) + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,)) loaded_keys.add(diff_bias_name) for x in lora.keys(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a3cffc3be..d78cdfd4d 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -217,13 +217,19 @@ class ModelPatcher: v = (self.calculate_weight(v[1:], v[0].clone(), key), ) if len(v) == 1: + patch_type = "diff" + elif len(v) == 2: + patch_type = v[0] + v = v[1] + + if patch_type == "diff": w1 = v[0] if alpha != 0.0: if w1.shape != weight.shape: print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) - elif len(v) == 4: #lora/locon + elif patch_type == "lora": #lora/locon mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) if v[2] is not None: @@ -237,7 +243,7 @@ class ModelPatcher: weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) except Exception as e: print("ERROR", key, e) - elif len(v) == 8: #lokr + elif patch_type == "lokr": w1 = v[0] w2 = v[1] w1_a = v[3] @@ -276,7 +282,7 @@ class ModelPatcher: weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) except Exception as e: print("ERROR", key, e) - else: #loha + elif patch_type == "loha": w1a = v[0] w1b = v[1] if v[2] is not None: @@ -305,6 +311,8 @@ class ModelPatcher: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) except Exception as e: print("ERROR", key, e) + else: + print("patch type not recognized", patch_type, key) return weight