clamp input

This commit is contained in:
Haoming 2024-12-05 17:32:22 +08:00
parent 9a616b81c1
commit e21b688e72

View File

@ -269,7 +269,7 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
inn = input.reshape(-1, input.shape[2]).to(dtype)
inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype)
else:
scale_input = scale_input.to(input.device)
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)