mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Use less memory in float8 lora patching by doing calculations in fp16.
This commit is contained in:
parent
c6812947e9
commit
7985ff88b9
@ -1,4 +1,15 @@
|
||||
import torch
|
||||
import math
|
||||
|
||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS):
|
||||
mantissa_scaled = torch.where(
|
||||
normal_mask,
|
||||
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
||||
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
||||
)
|
||||
|
||||
mantissa_scaled += torch.rand_like(mantissa_scaled)
|
||||
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
|
||||
|
||||
#Not 100% sure about this
|
||||
def manual_stochastic_round_to_float8(x, dtype):
|
||||
@ -9,40 +20,30 @@ def manual_stochastic_round_to_float8(x, dtype):
|
||||
else:
|
||||
raise ValueError("Unsupported dtype")
|
||||
|
||||
x = x.half()
|
||||
sign = torch.sign(x)
|
||||
abs_x = x.abs()
|
||||
sign = torch.where(abs_x == 0, 0, sign)
|
||||
|
||||
# Combine exponent calculation and clamping
|
||||
exponent = torch.clamp(
|
||||
torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
|
||||
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
|
||||
0, 2**EXPONENT_BITS - 1
|
||||
)
|
||||
|
||||
# Combine mantissa calculation and rounding
|
||||
# min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
|
||||
# zero_mask = (abs_x == 0)
|
||||
# subnormal_mask = (exponent == 0) & (abs_x != 0)
|
||||
normal_mask = ~(exponent == 0)
|
||||
|
||||
mantissa_scaled = torch.where(
|
||||
normal_mask,
|
||||
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
||||
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
||||
)
|
||||
mantissa_floor = mantissa_scaled.floor()
|
||||
mantissa = torch.where(
|
||||
torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
|
||||
(mantissa_floor + 1) / (2**MANTISSA_BITS),
|
||||
mantissa_floor / (2**MANTISSA_BITS)
|
||||
)
|
||||
result = torch.where(
|
||||
normal_mask,
|
||||
sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
|
||||
sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
|
||||
)
|
||||
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS)
|
||||
|
||||
result = torch.where(abs_x == 0, 0, result)
|
||||
return result.to(dtype=dtype)
|
||||
sign *= torch.where(
|
||||
normal_mask,
|
||||
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||
)
|
||||
del abs_x
|
||||
|
||||
return sign.to(dtype=dtype)
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user