mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Do lora cast on GPU instead of CPU for higher performance.
This commit is contained in:
parent
0109431626
commit
b92bf8196e
@ -187,13 +187,13 @@ class ModelPatcher:
|
|||||||
else:
|
else:
|
||||||
weight += alpha * w1.type(weight.dtype).to(weight.device)
|
weight += alpha * w1.type(weight.dtype).to(weight.device)
|
||||||
elif len(v) == 4: #lora/locon
|
elif len(v) == 4: #lora/locon
|
||||||
mat1 = v[0].float().to(weight.device)
|
mat1 = v[0].to(weight.device).float()
|
||||||
mat2 = v[1].float().to(weight.device)
|
mat2 = v[1].to(weight.device).float()
|
||||||
if v[2] is not None:
|
if v[2] is not None:
|
||||||
alpha *= v[2] / mat2.shape[0]
|
alpha *= v[2] / mat2.shape[0]
|
||||||
if v[3] is not None:
|
if v[3] is not None:
|
||||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||||
mat3 = v[3].float().to(weight.device)
|
mat3 = v[3].to(weight.device).float()
|
||||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||||
try:
|
try:
|
||||||
@ -212,18 +212,18 @@ class ModelPatcher:
|
|||||||
|
|
||||||
if w1 is None:
|
if w1 is None:
|
||||||
dim = w1_b.shape[0]
|
dim = w1_b.shape[0]
|
||||||
w1 = torch.mm(w1_a.float(), w1_b.float())
|
w1 = torch.mm(w1_a.to(weight.device).float(), w1_b.to(weight.device).float())
|
||||||
else:
|
else:
|
||||||
w1 = w1.float().to(weight.device)
|
w1 = w1.to(weight.device).float()
|
||||||
|
|
||||||
if w2 is None:
|
if w2 is None:
|
||||||
dim = w2_b.shape[0]
|
dim = w2_b.shape[0]
|
||||||
if t2 is None:
|
if t2 is None:
|
||||||
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
|
w2 = torch.mm(w2_a.to(weight.device).float(), w2_b.to(weight.device).float())
|
||||||
else:
|
else:
|
||||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2_b.to(weight.device).float(), w2_a.to(weight.device).float())
|
||||||
else:
|
else:
|
||||||
w2 = w2.float().to(weight.device)
|
w2 = w2.to(weight.device).float()
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
if len(w2.shape) == 4:
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
@ -244,11 +244,11 @@ class ModelPatcher:
|
|||||||
if v[5] is not None: #cp decomposition
|
if v[5] is not None: #cp decomposition
|
||||||
t1 = v[5]
|
t1 = v[5]
|
||||||
t2 = v[6]
|
t2 = v[6]
|
||||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.to(weight.device).float(), w1b.to(weight.device).float(), w1a.to(weight.device).float())
|
||||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2b.to(weight.device).float(), w2a.to(weight.device).float())
|
||||||
else:
|
else:
|
||||||
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
|
m1 = torch.mm(w1a.to(weight.device).float(), w1b.to(weight.device).float())
|
||||||
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
|
m2 = torch.mm(w2a.to(weight.device).float(), w2b.to(weight.device).float())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
||||||
|
Loading…
Reference in New Issue
Block a user