Mirror of https://github.com/comfyanonymous/ComfyUI.git
Commit cd5017c1c9 (parent 83f343146a): calculate_weight function to use a different dtype.
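In short: the commit threads a new `intermediate_dtype` parameter (defaulting to `torch.float32`, the previously hardcoded value) through `weight_decompose` and `ModelPatcher.calculate_weight`, so the lora/locon, lokr, loha, and glora patch branches, plus the DoRA helper, cast their intermediate tensors to the caller's chosen dtype instead of always `torch.float32`.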
@@ -28,8 +28,8 @@ import comfy.model_management
 from comfy.types import UnetWrapperFunction


-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
-    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
+    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
     lora_diff *= alpha
     weight_calc = weight + lora_diff.type(weight.dtype)
     weight_norm = (
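A note on the mechanics, not part of the commit: `cast_to_device(t, device, dtype)` moves a tensor and converts its dtype in one step, so the DoRA magnitude vector now lives in whatever intermediate precision the caller picked. A minimal sketch of the precision trade-off, using a stand-alone tensor rather than real model weights:

import torch

# Hypothetical illustration: the rounding error introduced by choosing a
# 16-bit intermediate dtype instead of the float32 default.
dora_scale = torch.randn(320, 1, dtype=torch.float32)
as_bf16 = dora_scale.to(torch.bfloat16)            # what intermediate_dtype=torch.bfloat16 would do
print((dora_scale - as_bf16.float()).abs().max())  # worst-case per-element cast error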
@@ -426,7 +426,7 @@ class ModelPatcher:
         self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
         return self.model

-    def calculate_weight(self, patches, weight, key):
+    def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
         for p in patches:
             strength = p[0]
             v = p[1]
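Because the new keyword defaults to `torch.float32`, existing call sites keep the old behavior unchanged. A toy stand-in (not the repo's method) showing what the parameter means for a caller:

import torch

def calculate_weight_toy(weight, up, down, strength,
                         intermediate_dtype=torch.float32):
    # Toy sketch of the patched code path: factors are cast to
    # intermediate_dtype before the matmul, and the result is folded
    # back into weight.dtype at the end.
    m1 = up.to(weight.device, intermediate_dtype)
    m2 = down.to(weight.device, intermediate_dtype)
    lora_diff = torch.mm(m1, m2)
    return weight + (strength * lora_diff).type(weight.dtype)

weight = torch.randn(64, 64)
up, down = torch.randn(64, 8), torch.randn(8, 64)
ref = calculate_weight_toy(weight, up, down, 1.0)  # float32 default, old behavior
alt = calculate_weight_toy(weight, up, down, 1.0, intermediate_dtype=torch.bfloat16)
print((ref - alt).abs().max())  # precision cost of the cheaper intermediate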
@@ -445,7 +445,7 @@ class ModelPatcher:
                 weight *= strength_model

             if isinstance(v, list):
-                v = (self.calculate_weight(v[1:], v[0].clone(), key), )
+                v = (self.calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )

             if len(v) == 1:
                 patch_type = "diff"
@@ -461,8 +461,8 @@ class ModelPatcher:
                     else:
                         weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
             elif patch_type == "lora": #lora/locon
-                mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
-                mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
+                mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
+                mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
                 dora_scale = v[4]
                 if v[2] is not None:
                     alpha = v[2] / mat2.shape[0]
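For readers new to the branch above: `alpha = v[2] / mat2.shape[0]` is the standard LoRA alpha-over-rank scaling, which keeps the update magnitude roughly independent of the chosen rank. A minimal sketch with assumed shapes:

import torch

rank = 8
mat1 = torch.randn(128, rank)   # "up" factor, analogous to v[0]
mat2 = torch.randn(rank, 128)   # "down" factor, analogous to v[1]
alpha = 4.0 / rank              # alpha parameter divided by rank (mat2.shape[0])
lora_diff = torch.mm(mat1, mat2)
print((alpha * lora_diff).shape)  # (128, 128): same shape as the weight it patches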
@@ -471,13 +471,13 @@ class ModelPatcher:

                 if v[3] is not None:
                     #locon mid weights, hopefully the math is fine because I didn't properly test it
-                    mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
+                    mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
                     final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                     mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                 try:
                     lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                     else:
                         weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
@@ -495,23 +495,23 @@ class ModelPatcher:

                 if w1 is None:
                     dim = w1_b.shape[0]
-                    w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
-                                  comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
+                    w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
                 else:
-                    w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)
+                    w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)

                 if w2 is None:
                     dim = w2_b.shape[0]
                     if t2 is None:
-                        w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
-                                      comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
+                        w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
                     else:
                         w2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                          comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
-                                          comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
-                                          comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
+                                          comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
+                                          comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
+                                          comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
                 else:
-                    w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)
+                    w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)

                 if len(w2.shape) == 4:
                     w1 = w1.unsqueeze(2).unsqueeze(2)
@@ -523,7 +523,7 @@ class ModelPatcher:
                 try:
                     lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                     else:
                         weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
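Context for the `torch.kron` line: LoKr builds the full-size delta as a Kronecker product of two small factors, which is why `w1` and `w2` can stay tiny. A shape-only sketch under assumed dimensions:

import torch

w1 = torch.randn(8, 8)     # small factor
w2 = torch.randn(40, 40)   # small factor
lora_diff = torch.kron(w1, w2)
print(lora_diff.shape)     # (320, 320): the Kronecker product covers the full weight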
@@ -543,24 +543,24 @@ class ModelPatcher:
                     t1 = v[5]
                     t2 = v[6]
                     m1 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
-                                      comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
-                                      comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))
+                                      comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))

                     m2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
-                                      comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
-                                      comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
+                                      comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
                 else:
-                    m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
-                                  comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
-                    m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
-                                  comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
+                    m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
+                    m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))

                 try:
                     lora_diff = (m1 * m2).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                     else:
                         weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
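Likewise for the LoHa branch above: the delta is the elementwise (Hadamard) product of two low-rank reconstructions `m1` and `m2`, each built by `torch.mm` (or the einsum path when the CP-decomposition tensors `t1`/`t2` are present). A sketch with assumed shapes:

import torch

m1 = torch.mm(torch.randn(128, 8), torch.randn(8, 128))  # w1a @ w1b
m2 = torch.mm(torch.randn(128, 8), torch.randn(8, 128))  # w2a @ w2b
lora_diff = m1 * m2   # Hadamard product, reshaped to weight.shape in the real code
print(lora_diff.shape)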
@@ -573,15 +573,15 @@ class ModelPatcher:

                 dora_scale = v[5]

-                a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
-                a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
-                b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
-                b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
+                a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
+                a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
+                b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
+                b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)

                 try:
                     lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                     else:
                         weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
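And the GLoRA branch above composes two terms: a plain low-rank delta `b2 @ b1` plus a weight-dependent term `(W @ a2) @ a1`. A minimal sketch with assumed shapes:

import torch

weight = torch.randn(128, 128)
a1, a2 = torch.randn(8, 128), torch.randn(128, 8)
b1, b2 = torch.randn(8, 128), torch.randn(128, 8)
lora_diff = torch.mm(b2, b1) + torch.mm(torch.mm(weight, a2), a1)
print(lora_diff.shape)  # matches weight.shape, as the reshape() requires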