mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Make lora code a bit cleaner.
This commit is contained in:
parent
9e411073e9
commit
cb63e230b4
@ -43,7 +43,7 @@ def load_lora(lora, to_load):
|
||||
if mid_name is not None and mid_name in lora.keys():
|
||||
mid = lora[mid_name]
|
||||
loaded_keys.add(mid_name)
|
||||
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
|
||||
@ -64,7 +64,7 @@ def load_lora(lora, to_load):
|
||||
loaded_keys.add(hada_t1_name)
|
||||
loaded_keys.add(hada_t2_name)
|
||||
|
||||
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
|
||||
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
|
||||
loaded_keys.add(hada_w1_a_name)
|
||||
loaded_keys.add(hada_w1_b_name)
|
||||
loaded_keys.add(hada_w2_a_name)
|
||||
@ -116,7 +116,7 @@ def load_lora(lora, to_load):
|
||||
loaded_keys.add(lokr_t2_name)
|
||||
|
||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
|
||||
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
|
||||
|
||||
|
||||
w_norm_name = "{}.w_norm".format(x)
|
||||
@ -126,21 +126,21 @@ def load_lora(lora, to_load):
|
||||
|
||||
if w_norm is not None:
|
||||
loaded_keys.add(w_norm_name)
|
||||
patch_dict[to_load[x]] = (w_norm,)
|
||||
patch_dict[to_load[x]] = ("diff", (w_norm,))
|
||||
if b_norm is not None:
|
||||
loaded_keys.add(b_norm_name)
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
|
||||
|
||||
diff_name = "{}.diff".format(x)
|
||||
diff_weight = lora.get(diff_name, None)
|
||||
if diff_weight is not None:
|
||||
patch_dict[to_load[x]] = (diff_weight,)
|
||||
patch_dict[to_load[x]] = ("diff", (diff_weight,))
|
||||
loaded_keys.add(diff_name)
|
||||
|
||||
diff_bias_name = "{}.diff_b".format(x)
|
||||
diff_bias = lora.get(diff_bias_name, None)
|
||||
if diff_bias is not None:
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
for x in lora.keys():
|
||||
|
@ -217,13 +217,19 @@ class ModelPatcher:
|
||||
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
|
||||
|
||||
if len(v) == 1:
|
||||
patch_type = "diff"
|
||||
elif len(v) == 2:
|
||||
patch_type = v[0]
|
||||
v = v[1]
|
||||
|
||||
if patch_type == "diff":
|
||||
w1 = v[0]
|
||||
if alpha != 0.0:
|
||||
if w1.shape != weight.shape:
|
||||
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
else:
|
||||
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||
elif len(v) == 4: #lora/locon
|
||||
elif patch_type == "lora": #lora/locon
|
||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
||||
if v[2] is not None:
|
||||
@ -237,7 +243,7 @@ class ModelPatcher:
|
||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
elif len(v) == 8: #lokr
|
||||
elif patch_type == "lokr":
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
w1_a = v[3]
|
||||
@ -276,7 +282,7 @@ class ModelPatcher:
|
||||
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
else: #loha
|
||||
elif patch_type == "loha":
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
if v[2] is not None:
|
||||
@ -305,6 +311,8 @@ class ModelPatcher:
|
||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
else:
|
||||
print("patch type not recognized", patch_type, key)
|
||||
|
||||
return weight
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user