mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Optimizations to --fast and scaled fp8.
This commit is contained in:
parent
f82314fcfc
commit
8ce2a1052c
16
comfy/ops.py
16
comfy/ops.py
@ -250,6 +250,12 @@ def fp8_linear(self, input):
|
|||||||
if dtype not in [torch.float8_e4m3fn]:
|
if dtype not in [torch.float8_e4m3fn]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
tensor_2d = False
|
||||||
|
if len(input.shape) == 2:
|
||||||
|
tensor_2d = True
|
||||||
|
input = input.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
if len(input.shape) == 3:
|
if len(input.shape) == 3:
|
||||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
||||||
w = w.t()
|
w = w.t()
|
||||||
@ -272,7 +278,11 @@ def fp8_linear(self, input):
|
|||||||
if isinstance(o, tuple):
|
if isinstance(o, tuple):
|
||||||
o = o[0]
|
o = o[0]
|
||||||
|
|
||||||
|
if tensor_2d:
|
||||||
|
return o.reshape(input.shape[0], -1)
|
||||||
|
|
||||||
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class fp8_ops(manual_cast):
|
class fp8_ops(manual_cast):
|
||||||
@ -316,7 +326,11 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias = cast_bias_weight(self, input)
|
||||||
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
|
||||||
|
if weight.numel() < input.numel(): #TODO: optimize
|
||||||
|
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||||
|
else:
|
||||||
|
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||||
|
|
||||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||||
if inplace:
|
if inplace:
|
||||||
|
Loading…
Reference in New Issue
Block a user