mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Better subnormal fp8 stochastic rounding. Thanks Ashen.
This commit is contained in:
parent
20ace7c853
commit
4506ddc86a
@ -19,25 +19,29 @@ def manual_stochastic_round_to_float8(x, dtype):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Combine mantissa calculation and rounding
|
# Combine mantissa calculation and rounding
|
||||||
mantissa = abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0
|
# min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
|
||||||
mantissa_scaled = mantissa * (2**MANTISSA_BITS)
|
# zero_mask = (abs_x == 0)
|
||||||
|
# subnormal_mask = (exponent == 0) & (abs_x != 0)
|
||||||
|
normal_mask = ~(exponent == 0)
|
||||||
|
|
||||||
|
mantissa_scaled = torch.where(
|
||||||
|
normal_mask,
|
||||||
|
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
||||||
|
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
||||||
|
)
|
||||||
mantissa_floor = mantissa_scaled.floor()
|
mantissa_floor = mantissa_scaled.floor()
|
||||||
mantissa = torch.where(
|
mantissa = torch.where(
|
||||||
torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
|
torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
|
||||||
(mantissa_floor + 1) / (2**MANTISSA_BITS),
|
(mantissa_floor + 1) / (2**MANTISSA_BITS),
|
||||||
mantissa_floor / (2**MANTISSA_BITS)
|
mantissa_floor / (2**MANTISSA_BITS)
|
||||||
)
|
)
|
||||||
|
result = torch.where(
|
||||||
|
normal_mask,
|
||||||
|
sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
|
||||||
|
sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
|
||||||
|
)
|
||||||
|
|
||||||
# Combine final result calculation
|
result = torch.where(abs_x == 0, 0, result)
|
||||||
result = sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa)
|
|
||||||
|
|
||||||
# Handle zero case
|
|
||||||
zero_mask = (abs_x == 0)
|
|
||||||
result = torch.where(zero_mask, torch.zeros_like(result), result)
|
|
||||||
|
|
||||||
# Handle subnormal numbers
|
|
||||||
min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
|
|
||||||
result = torch.where((abs_x < min_normal) & (~zero_mask), torch.round(x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) * (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)), result)
|
|
||||||
return result.to(dtype=dtype)
|
return result.to(dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user