mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 05:57:20 +00:00
dora_scale support for lora file.
This commit is contained in:
parent
c6de09b02e
commit
ae77590b4e
@ -21,6 +21,12 @@ def load_lora(lora, to_load):
|
|||||||
alpha = lora[alpha_name].item()
|
alpha = lora[alpha_name].item()
|
||||||
loaded_keys.add(alpha_name)
|
loaded_keys.add(alpha_name)
|
||||||
|
|
||||||
|
dora_scale_name = "{}.dora_scale".format(x)
|
||||||
|
dora_scale = None
|
||||||
|
if dora_scale_name in lora.keys():
|
||||||
|
dora_scale = lora[dora_scale_name]
|
||||||
|
loaded_keys.add(dora_scale_name)
|
||||||
|
|
||||||
regular_lora = "{}.lora_up.weight".format(x)
|
regular_lora = "{}.lora_up.weight".format(x)
|
||||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||||
@ -44,7 +50,7 @@ def load_lora(lora, to_load):
|
|||||||
if mid_name is not None and mid_name in lora.keys():
|
if mid_name is not None and mid_name in lora.keys():
|
||||||
mid = lora[mid_name]
|
mid = lora[mid_name]
|
||||||
loaded_keys.add(mid_name)
|
loaded_keys.add(mid_name)
|
||||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
|
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
|
||||||
loaded_keys.add(A_name)
|
loaded_keys.add(A_name)
|
||||||
loaded_keys.add(B_name)
|
loaded_keys.add(B_name)
|
||||||
|
|
||||||
@ -65,7 +71,7 @@ def load_lora(lora, to_load):
|
|||||||
loaded_keys.add(hada_t1_name)
|
loaded_keys.add(hada_t1_name)
|
||||||
loaded_keys.add(hada_t2_name)
|
loaded_keys.add(hada_t2_name)
|
||||||
|
|
||||||
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
|
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
|
||||||
loaded_keys.add(hada_w1_a_name)
|
loaded_keys.add(hada_w1_a_name)
|
||||||
loaded_keys.add(hada_w1_b_name)
|
loaded_keys.add(hada_w1_b_name)
|
||||||
loaded_keys.add(hada_w2_a_name)
|
loaded_keys.add(hada_w2_a_name)
|
||||||
@ -117,7 +123,7 @@ def load_lora(lora, to_load):
|
|||||||
loaded_keys.add(lokr_t2_name)
|
loaded_keys.add(lokr_t2_name)
|
||||||
|
|
||||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||||
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
|
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
|
||||||
|
|
||||||
#glora
|
#glora
|
||||||
a1_name = "{}.a1.weight".format(x)
|
a1_name = "{}.a1.weight".format(x)
|
||||||
@ -125,7 +131,7 @@ def load_lora(lora, to_load):
|
|||||||
b1_name = "{}.b1.weight".format(x)
|
b1_name = "{}.b1.weight".format(x)
|
||||||
b2_name = "{}.b2.weight".format(x)
|
b2_name = "{}.b2.weight".format(x)
|
||||||
if a1_name in lora:
|
if a1_name in lora:
|
||||||
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
|
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
|
||||||
loaded_keys.add(a1_name)
|
loaded_keys.add(a1_name)
|
||||||
loaded_keys.add(a2_name)
|
loaded_keys.add(a2_name)
|
||||||
loaded_keys.add(b1_name)
|
loaded_keys.add(b1_name)
|
||||||
|
@ -7,6 +7,18 @@ import uuid
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
|
def apply_weight_decompose(dora_scale, weight):
|
||||||
|
weight_norm = (
|
||||||
|
weight.transpose(0, 1)
|
||||||
|
.reshape(weight.shape[1], -1)
|
||||||
|
.norm(dim=1, keepdim=True)
|
||||||
|
.reshape(weight.shape[1], *[1] * (weight.dim() - 1))
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
return weight * (dora_scale / weight_norm)
|
||||||
|
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||||
self.size = size
|
self.size = size
|
||||||
@ -309,6 +321,7 @@ class ModelPatcher:
|
|||||||
elif patch_type == "lora": #lora/locon
|
elif patch_type == "lora": #lora/locon
|
||||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
||||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
||||||
|
dora_scale = v[4]
|
||||||
if v[2] is not None:
|
if v[2] is not None:
|
||||||
alpha *= v[2] / mat2.shape[0]
|
alpha *= v[2] / mat2.shape[0]
|
||||||
if v[3] is not None:
|
if v[3] is not None:
|
||||||
@ -318,6 +331,8 @@ class ModelPatcher:
|
|||||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||||
try:
|
try:
|
||||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "lokr":
|
elif patch_type == "lokr":
|
||||||
@ -328,6 +343,7 @@ class ModelPatcher:
|
|||||||
w2_a = v[5]
|
w2_a = v[5]
|
||||||
w2_b = v[6]
|
w2_b = v[6]
|
||||||
t2 = v[7]
|
t2 = v[7]
|
||||||
|
dora_scale = v[8]
|
||||||
dim = None
|
dim = None
|
||||||
|
|
||||||
if w1 is None:
|
if w1 is None:
|
||||||
@ -357,6 +373,8 @@ class ModelPatcher:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "loha":
|
elif patch_type == "loha":
|
||||||
@ -366,6 +384,7 @@ class ModelPatcher:
|
|||||||
alpha *= v[2] / w1b.shape[0]
|
alpha *= v[2] / w1b.shape[0]
|
||||||
w2a = v[3]
|
w2a = v[3]
|
||||||
w2b = v[4]
|
w2b = v[4]
|
||||||
|
dora_scale = v[7]
|
||||||
if v[5] is not None: #cp decomposition
|
if v[5] is not None: #cp decomposition
|
||||||
t1 = v[5]
|
t1 = v[5]
|
||||||
t2 = v[6]
|
t2 = v[6]
|
||||||
@ -386,12 +405,16 @@ class ModelPatcher:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "glora":
|
elif patch_type == "glora":
|
||||||
if v[4] is not None:
|
if v[4] is not None:
|
||||||
alpha *= v[4] / v[0].shape[0]
|
alpha *= v[4] / v[0].shape[0]
|
||||||
|
|
||||||
|
dora_scale = v[5]
|
||||||
|
|
||||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
|
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
|
||||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
|
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
|
||||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
|
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
|
||||||
@ -399,6 +422,8 @@ class ModelPatcher:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
|
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user