Skip to content

Commit

Permalink
Merge pull request #24 from graphcore-research/simplify-round
Browse files (browse the repository at this point in the history)
Simplify round, fix directed rounding
  • Loading branch information
awf authored Jun 6, 2024
2 parents b82ad80 + 1ecf411 commit aea875d
Show file tree
Hide file tree
Showing 3 changed files with 315 additions and 57 deletions.
97 changes: 50 additions & 47 deletions src/gfloat/round.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def round_float(
# Constants
p = fi.precision
bias = fi.expBias
t = p - 1

if np.isnan(v):
if fi.num_nans == 0:
Expand All @@ -56,59 +55,57 @@ def round_float(
if np.isinf(vpos):
result = np.inf

elif fi.has_subnormals and vpos < fi.smallest_subnormal / 2:
# Test against smallest_subnormal to avoid subnormals in frexp below
# Note that this restricts us to types narrower than float64
result = 0.0
elif vpos == 0:
result = 0

else:
# Extract significand (mantissa) and exponent
fsignificand, expval = np.frexp(vpos)
assert fsignificand >= 0.5 and fsignificand < 1.0
# Bring significand into range [1.0, 2.0)
fsignificand *= 2
expval -= 1
# Extract exponent
expval = int(np.floor(np.log2(vpos)))

assert expval > -1024 + p # not yet tested for float64 near-subnormals

# Effective precision, accounting for right shift for subnormal values
biased_exp = expval + bias
if fi.has_subnormals:
effective_precision = t + min(biased_exp - 1, 0)
else:
effective_precision = t
expval = max(expval, 1 - bias)

# Lift to "integer * 2^e"
fsignificand *= 2.0**effective_precision
expval -= effective_precision
expval = expval - p + 1

fsignificand = vpos * 2.0**-expval

# Round
isignificand = math.floor(fsignificand)
if isignificand != fsignificand:
# Need to round
if rnd == RoundMode.TowardZero:
pass
elif rnd == RoundMode.TowardPositive:
isignificand += 1 if not sign else 0
elif rnd == RoundMode.TowardNegative:
isignificand += 1 if sign else 0
else:
# Round to nearest
d = fsignificand - isignificand
if d > 0.5:
isignificand += 1
elif d == 0.5:
# Tie
if rnd == RoundMode.TiesToAway:
isignificand += 1
else:
# All other modes tie to even
if fi.precision == 1:
# No significand bits
assert (isignificand == 1) or (isignificand == 0)
if _isodd(biased_exp):
expval += 1
else:
if _isodd(isignificand):
isignificand += 1
delta = fsignificand - isignificand
if (
(rnd == RoundMode.TowardPositive and not sign and delta > 0)
or (rnd == RoundMode.TowardNegative and sign and delta > 0)
or (rnd == RoundMode.TiesToAway and delta >= 0.5)
or (rnd == RoundMode.TiesToEven and delta > 0.5)
or (rnd == RoundMode.TiesToEven and delta == 0.5 and _isodd(isignificand))
):
isignificand += 1

## Special case for Precision=1, all-log format with zero.
if fi.precision == 1:
# The logic is simply duplicated for clarity of reading.
isignificand = math.floor(fsignificand)
code_is_odd = isignificand != 0 and _isodd(expval + bias)
if (
(rnd == RoundMode.TowardPositive and not sign and delta > 0)
or (rnd == RoundMode.TowardNegative and sign and delta > 0)
or (rnd == RoundMode.TiesToAway and delta >= 0.5)
or (rnd == RoundMode.TiesToEven and delta > 0.5)
or (rnd == RoundMode.TiesToEven and delta == 0.5 and code_is_odd)
):
# Go to nextUp.
# Increment isignificand if zero,
# else increment exponent
if isignificand == 0:
isignificand = 1
else:
assert isignificand == 1
expval += 1
## End special case for Precision=1.

result = isignificand * (2.0**expval)

Expand All @@ -119,9 +116,15 @@ def round_float(
return 0.0

# Overflow
if result > (-fi.min if sign else fi.max):
if sat:
result = fi.max
amax = -fi.min if sign else fi.max
if result > amax:
if (
sat
or (rnd == RoundMode.TowardNegative and not sign and np.isfinite(v))
or (rnd == RoundMode.TowardPositive and sign and np.isfinite(v))
or (rnd == RoundMode.TowardZero and np.isfinite(v))
):
result = amax
else:
if fi.has_infs:
result = np.inf
Expand Down
4 changes: 2 additions & 2 deletions src/gfloat/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def bits(self) -> int:
@property
def eps(self) -> float:
"""
The difference between 1.0 and the next smallest representable float
The difference between 1.0 and the smallest representable float
larger than 1.0. For example, for 64-bit binary floats in the IEEE-754
standard, ``eps = 2**-52``, approximately 2.22e-16.
"""
Expand All @@ -156,7 +156,7 @@ def eps(self) -> float:
@property
def epsneg(self) -> float:
"""
The difference between 1.0 and the next smallest representable float
The difference between 1.0 and the largest representable float
less than 1.0. For example, for 64-bit binary floats in the IEEE-754
standard, ``epsneg = 2**-53``, approximately 1.11e-16.
"""
Expand Down
Loading

0 comments on commit aea875d

Please sign in to comment.