Handle subnormal numbers in float8 rounding.

This commit is contained in:
comfyanonymous 2024-08-19 05:19:59 -04:00
parent 39f114c44b
commit 22ec02afc0

View File

@ -32,8 +32,12 @@ def manual_stochastic_round_to_float8(x, dtype):
result = sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa)
# Handle zero case
result = torch.where(abs_x == 0, torch.zeros_like(result), result)
zero_mask = (abs_x == 0)
result = torch.where(zero_mask, torch.zeros_like(result), result)
# Handle subnormal numbers
min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
result = torch.where((abs_x < min_normal) & (~zero_mask), torch.round(x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) * (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)), result)
return result.to(dtype=dtype)