diff --git a/comfy/ops.py b/comfy/ops.py index 2890cac0..5e7c668e 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -250,6 +250,12 @@ def fp8_linear(self, input): if dtype not in [torch.float8_e4m3fn]: return None + tensor_2d = False + if len(input.shape) == 2: + tensor_2d = True + input = input.unsqueeze(1) + + if len(input.shape) == 3: w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype) w = w.t() @@ -272,7 +278,11 @@ def fp8_linear(self, input): if isinstance(o, tuple): o = o[0] + if tensor_2d: + return o.reshape(input.shape[0], -1) + return o.reshape((-1, input.shape[1], self.weight.shape[0])) + return None class fp8_ops(manual_cast): @@ -316,7 +326,11 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None return out weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) + + if weight.numel() < input.numel(): #TODO: optimize + return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) + else: + return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) def convert_weight(self, weight, inplace=False, **kwargs): if inplace: