Fix issue with fp8 ops on some models. (#8045)

torch._scaled_mm errors out when an input is non-contiguous.
Author: comfyanonymous, 2025-05-10 04:52:56 -07:00 (committed by GitHub)
Commit: d42613686f
Parent: 1b3bf0a5da

@@ -308,10 +308,10 @@ def fp8_linear(self, input):
         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
             input = torch.clamp(input, min=-448, max=448, out=input)
-            input = input.reshape(-1, input_shape[2]).to(dtype)
+            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
         else:
             scale_input = scale_input.to(input.device)
-            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
+            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
         if bias is not None:
             o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
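
For context: torch._scaled_mm rejects a non-contiguous first operand, and the reshape above can hand it a strided view rather than a compact copy (for example when the incoming activation is itself a transposed view). The sketch below is illustrative only, with made-up shapes, and assumes a PyTorch build that ships the float8_e4m3fn dtype; it shows how such a layout can arise and why the appended .contiguous() call resolves it.

    import torch

    # Illustrative shapes only; real activations come from the model's forward pass.
    x = torch.randn(64, 32)                # some (features, batch) buffer
    inp = x.t().unsqueeze(1)               # (batch, 1, features) view with swapped strides

    flat = inp.reshape(-1, inp.shape[2])   # reshape can return a non-contiguous view here
    print(flat.is_contiguous())            # False

    fp8 = flat.to(torch.float8_e4m3fn)     # .to() keeps the strided layout (preserve_format)
    print(fp8.is_contiguous())             # still False: the layout _scaled_mm errors on
    print(fp8.contiguous().is_contiguous())  # True: .contiguous() forces a row-major copy

The extra .contiguous() costs nothing when the tensor is already row-major, since PyTorch returns the same tensor in that case, so the fix only copies in the problematic layouts.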