Optimizations to --fast and scaled fp8.

This commit is contained in:
comfyanonymous 2024-10-22 02:12:28 -04:00
parent f82314fcfc
commit 8ce2a1052c

View File

@ -250,6 +250,12 @@ def fp8_linear(self, input):
if dtype not in [torch.float8_e4m3fn]: if dtype not in [torch.float8_e4m3fn]:
return None return None
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
if len(input.shape) == 3: if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype) w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t() w = w.t()
@ -272,7 +278,11 @@ def fp8_linear(self, input):
if isinstance(o, tuple): if isinstance(o, tuple):
o = o[0] o = o[0]
if tensor_2d:
return o.reshape(input.shape[0], -1)
return o.reshape((-1, input.shape[1], self.weight.shape[0])) return o.reshape((-1, input.shape[1], self.weight.shape[0]))
return None return None
class fp8_ops(manual_cast): class fp8_ops(manual_cast):
@ -316,7 +326,11 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
return out return out
weight, bias = cast_bias_weight(self, input) weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
if weight.numel() < input.numel(): #TODO: optimize
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
def convert_weight(self, weight, inplace=False, **kwargs): def convert_weight(self, weight, inplace=False, **kwargs):
if inplace: if inplace: