Try to fix some lora issues.

This commit is contained in:
comfyanonymous 2024-08-22 15:15:47 -04:00
parent 7b70b266d8
commit c7ee4b37a1

View File

@ -20,31 +20,36 @@ import torch
import comfy.model_management
from comfy.cli_args import args
def cast_to(weight, dtype=None, device=None, non_blocking=False):
if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=True):
if not copy and (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
return weight
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
def cast_to_input(weight, input, non_blocking=False):
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
def cast_to_input(weight, input, non_blocking=False, copy=True):
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
if s.bias_function is not None:
has_function = s.bias_function is not None
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
bias = s.bias_function(bias)
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
if s.weight_function is not None:
has_function = s.weight_function is not None
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
weight = s.weight_function(weight)
return weight, bias
@ -252,7 +257,8 @@ def fp8_linear(self, input):
if len(input.shape) == 3:
inn = input.reshape(-1, input.shape[2]).to(dtype)
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t()
scale_weight = self.scale_weight
scale_input = self.scale_input
@ -263,8 +269,8 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
if self.bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking), scale_a=scale_input, scale_b=scale_weight)
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)