mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-13 15:03:33 +00:00
Finalize the modularized weight adapter impl
* LoRA/LoHa/LoKr/GLoRA working well * Removed TONS of code in lora.py
This commit is contained in:
parent
889f94773a
commit
e8f3bc5ab7
180
comfy/lora.py
180
comfy/lora.py
@ -282,26 +282,6 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|
||||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
|
||||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
|
||||||
lora_diff *= alpha
|
|
||||||
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
|
||||||
weight_norm = (
|
|
||||||
weight_calc.transpose(0, 1)
|
|
||||||
.reshape(weight_calc.shape[1], -1)
|
|
||||||
.norm(dim=1, keepdim=True)
|
|
||||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
|
||||||
.transpose(0, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
|
||||||
if strength != 1.0:
|
|
||||||
weight_calc -= weight
|
|
||||||
weight += strength * (weight_calc)
|
|
||||||
else:
|
|
||||||
weight[:] = weight_calc
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Pad a tensor to a new shape with zeros.
|
Pad a tensor to a new shape with zeros.
|
||||||
@ -358,14 +338,13 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
|||||||
|
|
||||||
if isinstance(v, weight_adapter.WeightAdapterBase):
|
if isinstance(v, weight_adapter.WeightAdapterBase):
|
||||||
output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
|
output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
|
||||||
if output is not None:
|
if output is None:
|
||||||
|
logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
|
||||||
|
else:
|
||||||
weight = output
|
weight = output
|
||||||
if old_weight is not None:
|
if old_weight is not None:
|
||||||
weight = old_weight
|
weight = old_weight
|
||||||
continue
|
continue
|
||||||
else:
|
|
||||||
#Fallback when calculate_weight haven't implemented
|
|
||||||
v = v.weights
|
|
||||||
|
|
||||||
if len(v) == 1:
|
if len(v) == 1:
|
||||||
patch_type = "diff"
|
patch_type = "diff"
|
||||||
@ -393,157 +372,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
|||||||
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
|
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
|
||||||
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
|
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
|
||||||
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
|
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
|
||||||
elif patch_type == "lora": #lora/locon
|
|
||||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
|
||||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
|
||||||
dora_scale = v[4]
|
|
||||||
reshape = v[5]
|
|
||||||
|
|
||||||
if reshape is not None:
|
|
||||||
weight = pad_tensor_to_shape(weight, reshape)
|
|
||||||
|
|
||||||
if v[2] is not None:
|
|
||||||
alpha = v[2] / mat2.shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
if v[3] is not None:
|
|
||||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
|
||||||
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
|
|
||||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
|
||||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
|
||||||
try:
|
|
||||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "lokr":
|
|
||||||
w1 = v[0]
|
|
||||||
w2 = v[1]
|
|
||||||
w1_a = v[3]
|
|
||||||
w1_b = v[4]
|
|
||||||
w2_a = v[5]
|
|
||||||
w2_b = v[6]
|
|
||||||
t2 = v[7]
|
|
||||||
dora_scale = v[8]
|
|
||||||
dim = None
|
|
||||||
|
|
||||||
if w1 is None:
|
|
||||||
dim = w1_b.shape[0]
|
|
||||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
if w2 is None:
|
|
||||||
dim = w2_b.shape[0]
|
|
||||||
if t2 is None:
|
|
||||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
|
||||||
if v[2] is not None and dim is not None:
|
|
||||||
alpha = v[2] / dim
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "loha":
|
|
||||||
w1a = v[0]
|
|
||||||
w1b = v[1]
|
|
||||||
if v[2] is not None:
|
|
||||||
alpha = v[2] / w1b.shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
w2a = v[3]
|
|
||||||
w2b = v[4]
|
|
||||||
dora_scale = v[7]
|
|
||||||
if v[5] is not None: #cp decomposition
|
|
||||||
t1 = v[5]
|
|
||||||
t2 = v[6]
|
|
||||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
|
||||||
|
|
||||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
|
||||||
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "glora":
|
|
||||||
dora_scale = v[5]
|
|
||||||
|
|
||||||
old_glora = False
|
|
||||||
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
|
||||||
rank = v[0].shape[0]
|
|
||||||
old_glora = True
|
|
||||||
|
|
||||||
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
|
||||||
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
old_glora = False
|
|
||||||
rank = v[1].shape[0]
|
|
||||||
|
|
||||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
if v[4] is not None:
|
|
||||||
alpha = v[4] / rank
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
try:
|
|
||||||
if old_glora:
|
|
||||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
|
||||||
else:
|
|
||||||
if weight.dim() > 2:
|
|
||||||
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
|
||||||
else:
|
|
||||||
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
|
||||||
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
|
||||||
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
else:
|
else:
|
||||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||||
|
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from .base import WeightAdapterBase
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
class GLoRAAdapter(WeightAdapterBase):
|
class GLoRAAdapter(WeightAdapterBase):
|
||||||
@ -27,7 +29,7 @@ class GLoRAAdapter(WeightAdapterBase):
|
|||||||
b1_name = "{}.b1.weight".format(x)
|
b1_name = "{}.b1.weight".format(x)
|
||||||
b2_name = "{}.b2.weight".format(x)
|
b2_name = "{}.b2.weight".format(x)
|
||||||
if a1_name in lora:
|
if a1_name in lora:
|
||||||
weights = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
|
weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
|
||||||
loaded_keys.add(a1_name)
|
loaded_keys.add(a1_name)
|
||||||
loaded_keys.add(a2_name)
|
loaded_keys.add(a2_name)
|
||||||
loaded_keys.add(b1_name)
|
loaded_keys.add(b1_name)
|
||||||
@ -47,4 +49,45 @@ class GLoRAAdapter(WeightAdapterBase):
|
|||||||
intermediate_dtype=torch.float32,
|
intermediate_dtype=torch.float32,
|
||||||
original_weight=None,
|
original_weight=None,
|
||||||
):
|
):
|
||||||
pass
|
v = self.weights
|
||||||
|
dora_scale = v[5]
|
||||||
|
|
||||||
|
old_glora = False
|
||||||
|
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||||
|
rank = v[0].shape[0]
|
||||||
|
old_glora = True
|
||||||
|
|
||||||
|
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||||
|
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
old_glora = False
|
||||||
|
rank = v[1].shape[0]
|
||||||
|
|
||||||
|
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if v[4] is not None:
|
||||||
|
alpha = v[4] / rank
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if old_glora:
|
||||||
|
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||||
|
else:
|
||||||
|
if weight.dim() > 2:
|
||||||
|
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
else:
|
||||||
|
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||||
|
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from .base import WeightAdapterBase
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
class LoHaAdapter(WeightAdapterBase):
|
class LoHaAdapter(WeightAdapterBase):
|
||||||
@ -38,7 +40,7 @@ class LoHaAdapter(WeightAdapterBase):
|
|||||||
loaded_keys.add(hada_t1_name)
|
loaded_keys.add(hada_t1_name)
|
||||||
loaded_keys.add(hada_t2_name)
|
loaded_keys.add(hada_t2_name)
|
||||||
|
|
||||||
weights = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
|
weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
|
||||||
loaded_keys.add(hada_w1_a_name)
|
loaded_keys.add(hada_w1_a_name)
|
||||||
loaded_keys.add(hada_w1_b_name)
|
loaded_keys.add(hada_w1_b_name)
|
||||||
loaded_keys.add(hada_w2_a_name)
|
loaded_keys.add(hada_w2_a_name)
|
||||||
@ -58,4 +60,41 @@ class LoHaAdapter(WeightAdapterBase):
|
|||||||
intermediate_dtype=torch.float32,
|
intermediate_dtype=torch.float32,
|
||||||
original_weight=None,
|
original_weight=None,
|
||||||
):
|
):
|
||||||
pass
|
v = self.weights
|
||||||
|
w1a = v[0]
|
||||||
|
w1b = v[1]
|
||||||
|
if v[2] is not None:
|
||||||
|
alpha = v[2] / w1b.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
w2a = v[3]
|
||||||
|
w2b = v[4]
|
||||||
|
dora_scale = v[7]
|
||||||
|
if v[5] is not None: #cp decomposition
|
||||||
|
t1 = v[5]
|
||||||
|
t2 = v[6]
|
||||||
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
||||||
|
|
||||||
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
||||||
|
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from .base import WeightAdapterBase
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
class LoKrAdapter(WeightAdapterBase):
|
class LoKrAdapter(WeightAdapterBase):
|
||||||
@ -66,7 +68,7 @@ class LoKrAdapter(WeightAdapterBase):
|
|||||||
loaded_keys.add(lokr_t2_name)
|
loaded_keys.add(lokr_t2_name)
|
||||||
|
|
||||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||||
weights = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
|
weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
|
||||||
return cls(loaded_keys, weights)
|
return cls(loaded_keys, weights)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@ -82,4 +84,50 @@ class LoKrAdapter(WeightAdapterBase):
|
|||||||
intermediate_dtype=torch.float32,
|
intermediate_dtype=torch.float32,
|
||||||
original_weight=None,
|
original_weight=None,
|
||||||
):
|
):
|
||||||
pass
|
v = self.weights
|
||||||
|
w1 = v[0]
|
||||||
|
w2 = v[1]
|
||||||
|
w1_a = v[3]
|
||||||
|
w1_b = v[4]
|
||||||
|
w2_a = v[5]
|
||||||
|
w2_b = v[6]
|
||||||
|
t2 = v[7]
|
||||||
|
dora_scale = v[8]
|
||||||
|
dim = None
|
||||||
|
|
||||||
|
if w1 is None:
|
||||||
|
dim = w1_b.shape[0]
|
||||||
|
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if w2 is None:
|
||||||
|
dim = w2_b.shape[0]
|
||||||
|
if t2 is None:
|
||||||
|
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if len(w2.shape) == 4:
|
||||||
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
|
if v[2] is not None and dim is not None:
|
||||||
|
alpha = v[2] / dim
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
|
Loading…
Reference in New Issue
Block a user