Make lora code a bit cleaner.

This commit is contained in:
comfyanonymous 2023-12-09 14:15:09 -05:00
parent 9e411073e9
commit cb63e230b4
2 changed files with 18 additions and 10 deletions

View File

@ -43,7 +43,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys(): if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name] mid = lora[mid_name]
loaded_keys.add(mid_name) loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
loaded_keys.add(A_name) loaded_keys.add(A_name)
loaded_keys.add(B_name) loaded_keys.add(B_name)
@ -64,7 +64,7 @@ def load_lora(lora, to_load):
loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name) loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2) patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name) loaded_keys.add(hada_w2_a_name)
@ -116,7 +116,7 @@ def load_lora(lora, to_load):
loaded_keys.add(lokr_t2_name) loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
w_norm_name = "{}.w_norm".format(x) w_norm_name = "{}.w_norm".format(x)
@ -126,21 +126,21 @@ def load_lora(lora, to_load):
if w_norm is not None: if w_norm is not None:
loaded_keys.add(w_norm_name) loaded_keys.add(w_norm_name)
patch_dict[to_load[x]] = (w_norm,) patch_dict[to_load[x]] = ("diff", (w_norm,))
if b_norm is not None: if b_norm is not None:
loaded_keys.add(b_norm_name) loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,) patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
diff_name = "{}.diff".format(x) diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None) diff_weight = lora.get(diff_name, None)
if diff_weight is not None: if diff_weight is not None:
patch_dict[to_load[x]] = (diff_weight,) patch_dict[to_load[x]] = ("diff", (diff_weight,))
loaded_keys.add(diff_name) loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x) diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None) diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None: if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,) patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name) loaded_keys.add(diff_bias_name)
for x in lora.keys(): for x in lora.keys():

View File

@ -217,13 +217,19 @@ class ModelPatcher:
v = (self.calculate_weight(v[1:], v[0].clone(), key), ) v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1: if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
w1 = v[0] w1 = v[0]
if alpha != 0.0: if alpha != 0.0:
if w1.shape != weight.shape: if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else: else:
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
elif len(v) == 4: #lora/locon elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
if v[2] is not None: if v[2] is not None:
@ -237,7 +243,7 @@ class ModelPatcher:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e: except Exception as e:
print("ERROR", key, e) print("ERROR", key, e)
elif len(v) == 8: #lokr elif patch_type == "lokr":
w1 = v[0] w1 = v[0]
w2 = v[1] w2 = v[1]
w1_a = v[3] w1_a = v[3]
@ -276,7 +282,7 @@ class ModelPatcher:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e: except Exception as e:
print("ERROR", key, e) print("ERROR", key, e)
else: #loha elif patch_type == "loha":
w1a = v[0] w1a = v[0]
w1b = v[1] w1b = v[1]
if v[2] is not None: if v[2] is not None:
@ -305,6 +311,8 @@ class ModelPatcher:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e: except Exception as e:
print("ERROR", key, e) print("ERROR", key, e)
else:
print("patch type not recognized", patch_type, key)
return weight return weight