diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 54a64aa4..837c64b0 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -28,8 +28,8 @@ import comfy.model_management
 from comfy.types import UnetWrapperFunction
 
 
-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
-    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
+    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
     lora_diff *= alpha
     weight_calc = weight + lora_diff.type(weight.dtype)
     weight_norm = (
@@ -426,7 +426,7 @@ class ModelPatcher:
         self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
         return self.model
 
-    def calculate_weight(self, patches, weight, key):
+    def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
         for p in patches:
             strength = p[0]
             v = p[1]
@@ -445,7 +445,7 @@ class ModelPatcher:
                 weight *= strength_model
 
             if isinstance(v, list):
-                v = (self.calculate_weight(v[1:], v[0].clone(), key), )
+                v = (self.calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
 
             if len(v) == 1:
                 patch_type = "diff"
@@ -461,8 +461,8 @@ class ModelPatcher:
                     else:
                         weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
             elif patch_type == "lora": #lora/locon
-                mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
-                mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
+                mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
+                mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
                 dora_scale = v[4]
                 if v[2] is not None:
                     alpha = v[2] / mat2.shape[0]
@@ -471,13 +471,13 @@ class ModelPatcher:
 
                 if v[3] is not None:
                     #locon mid weights, hopefully the math is fine because I didn't properly test it
-                    mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
+                    mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
                     final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                     mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                 try:
                     lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                     else:
                         weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
@@ -495,23 +495,23 @@ class ModelPatcher:
 
                 if w1 is None:
                     dim = w1_b.shape[0]
-                    w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
-                                  comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
+                    w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
                 else:
-                    w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)
+                    w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
 
                 if w2 is None:
                     dim = w2_b.shape[0]
                     if t2 is None:
-                        w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
-                                      comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
+                        w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
                     else:
                         w2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                          comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
-                                          comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
-                                          comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
+                                          comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
+                                          comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
+                                          comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
                 else:
-                    w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)
+                    w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
 
                 if len(w2.shape) == 4:
                     w1 = w1.unsqueeze(2).unsqueeze(2)
@@ -523,7 +523,7 @@ class ModelPatcher:
                 try:
                     lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                     else:
                         weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
@@ -543,24 +543,24 @@ class ModelPatcher:
                     t1 = v[5]
                     t2 = v[6]
                     m1 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
-                                      comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
-                                      comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))
+                                      comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
 
                     m2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
-                                      comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
-                                      comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
+                                      comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
                 else:
-                    m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
-                                  comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
-                    m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
-                                  comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
+                    m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
+                    m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
 
                 try:
                     lora_diff = (m1 * m2).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                     else:
                         weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
@@ -573,15 +573,15 @@ class ModelPatcher:
 
                 dora_scale = v[5]
 
-                a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
-                a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
-                b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
-                b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
+                a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
+                a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
+                b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
+                b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
 
                 try:
                     lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                     else:
                         weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
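
Note: the following is a minimal standalone sketch, not part of the patch, showing what the new intermediate_dtype knob changes. It mirrors the math of the "lora" branch of calculate_weight above: cast the low-rank factors to the intermediate dtype, compute the product there, then fold the scaled diff back in the weight's own dtype. The helper name apply_lora_sketch and the tensor shapes are illustrative assumptions; only the intermediate_dtype parameter and the flatten/mm/reshape/type sequence come from the diff itself.

    import torch

    # Hypothetical helper (not ComfyUI API); mirrors the "lora" branch above.
    def apply_lora_sketch(weight, mat1, mat2, alpha, strength, intermediate_dtype=torch.float32):
        # Cast the LoRA factors to the intermediate dtype. The patch uses
        # comfy.model_management.cast_to_device for this; plain .to() suffices here.
        m1 = mat1.to(weight.device, intermediate_dtype)
        m2 = mat2.to(weight.device, intermediate_dtype)
        # Low-rank product computed in the intermediate precision, then
        # reshaped to the weight's shape, exactly as in calculate_weight.
        lora_diff = torch.mm(m1.flatten(start_dim=1), m2.flatten(start_dim=1)).reshape(weight.shape)
        # Fold the scaled diff back in the weight's own dtype.
        return weight + ((strength * alpha) * lora_diff).type(weight.dtype)

    weight = torch.randn(768, 768, dtype=torch.float16)
    mat1 = torch.randn(768, 16, dtype=torch.float16)   # LoRA "up" factor (v[0])
    mat2 = torch.randn(16, 768, dtype=torch.float16)   # LoRA "down" factor (v[1])

    w_fp32 = apply_lora_sketch(weight, mat1, mat2, alpha=1.0, strength=1.0)
    w_bf16 = apply_lora_sketch(weight, mat1, mat2, alpha=1.0, strength=1.0,
                               intermediate_dtype=torch.bfloat16)
    print((w_fp32 - w_bf16).abs().max())  # small but nonzero: the precision cost of a cheaper intermediate dtype

Because the parameter defaults to torch.float32, existing callers of calculate_weight keep their current behavior; passing a lower-precision dtype such as torch.bfloat16 trades some accuracy in the intermediate matmuls for less memory and cheaper casts.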