Support loha that use cp decomposition.

This commit is contained in:
comfyanonymous 2023-03-23 04:32:25 -04:00
parent 94a7c895f4
commit dd095efc2c

View File

@ -149,8 +149,18 @@ def load_lora(path, to_load):
hada_w1_b_name = "{}.hada_w1_b".format(x)
hada_w2_a_name = "{}.hada_w2_a".format(x)
hada_w2_b_name = "{}.hada_w2_b".format(x)
hada_t1_name = "{}.hada_t1".format(x)
hada_t2_name = "{}.hada_t2".format(x)
if hada_w1_a_name in lora.keys():
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name])
hada_t1 = None
hada_t2 = None
if hada_t1_name in lora.keys():
hada_t1 = lora[hada_t1_name]
hada_t2 = lora[hada_t2_name]
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
@ -312,7 +322,16 @@ class ModelPatcher:
alpha *= v[2] / w1b.shape[0]
w2a = v[3]
w2b = v[4]
weight += (alpha * torch.mm(w1a.float(), w1b.float()) * torch.mm(w2a.float(), w2b.float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
else:
m1 = torch.mm(w1a.float(), w1b.float())
m2 = torch.mm(w2a.float(), w2b.float())
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
return self.model
def unpatch_model(self):
model_sd = self.model.state_dict()