Skip to content

Commit

Permalink
Merge pull request #3 from graphcore-research/awf/encode
Browse files Browse the repository at this point in the history
Add Encode
  • Loading branch information
awf authored Apr 2, 2024
2 parents 4425e6f + 2ef20bc commit ce20d54
Show file tree
Hide file tree
Showing 11 changed files with 300 additions and 69 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ cd ..
#### Pushing
```
rm -rf dist
python3 -m build
pip install build twine
python -m build
echo __token__ | twine upload --repository pypi dist/* --verbose
```

Expand Down
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ API

.. autofunction:: decode_float
.. autofunction:: round_float
.. autofunction:: encode_float

.. autoclass:: FormatInfo()
:members:
.. autoclass:: FloatClass()
Expand Down
2 changes: 1 addition & 1 deletion src/gfloat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .types import FormatInfo, FloatClass, FloatValue, RoundMode
from .decode import decode_float
from .round import round_float
from .round import round_float, encode_float
import gfloat.formats

# Don't automatically import from .formats.
Expand Down
8 changes: 3 additions & 5 deletions src/gfloat/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ def decode_float(fi: FormatInfo, i: int) -> FloatValue:
fsignificand = 1.0 + significand * 2**-t

# val: the raw value excluding specials
val = sign * fsignificand * 2**expval
val = sign * fsignificand * 2.0**expval

# Now overwrite the raw value with specials: Infs, NaN, -0, NaN_0
signed_infinity = -np.inf if signbit else np.inf

fval = val
# All-bits-one exponent (ABOE)
# All-bits-special exponent (ABSE)
if exp == 2**w - 1:
min_i_with_nan = 2 ** (p - 1) - fi.num_high_nans
if significand >= min_i_with_nan:
Expand Down Expand Up @@ -84,6 +84,4 @@ def decode_float(fi: FormatInfo, i: int) -> FloatValue:
else:
fclass = FloatClass.INFINITE

return FloatValue(
i, fval, val, exp, expval, significand, fsignificand, signbit, fclass, fi
)
return FloatValue(i, fval, exp, expval, significand, fsignificand, signbit, fclass)
21 changes: 21 additions & 0 deletions src/gfloat/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,24 @@ def format_info_p3109(precision: int) -> FormatInfo:
num_high_nans=0,
has_subnormals=True,
)


## Collections of formats
p3109_formats = [format_info_p3109(p) for p in range(1, 7)]

fp8_formats = [
format_info_ocp_e4m3,
format_info_ocp_e5m2,
*p3109_formats,
]

fp16_formats = [
format_info_binary16,
format_info_bfloat16,
]

all_formats = [
*fp8_formats,
*fp16_formats,
format_info_binary32,
]
104 changes: 96 additions & 8 deletions src/gfloat/round.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import numpy as np
import math

from .types import FormatInfo, RoundMode
from .types import FormatInfo, RoundMode, FloatValue
from .decode import decode_float


def _isodd(v: int):
Expand Down Expand Up @@ -55,19 +56,21 @@ def round_float(fi: FormatInfo, v: float, rnd=RoundMode.TiesToEven, sat=False) -

# Extract sign
sign = np.signbit(v)
vpos = -v if sign else v

if v == 0:
if vpos < fi.smallest_subnormal / 2:
# Test against smallest_subnormal to avoid subnormals in frexp below
# Note that this restricts us to types narrower than float64
result = 0

elif np.isinf(v):
elif np.isinf(vpos):
result = np.inf

else:
# Extract significand (mantissa) and exponent
fsignificand, expval = np.frexp(np.abs(v))

fsignificand, expval = np.frexp(vpos)
assert fsignificand >= 0.5 and fsignificand < 1.0
# move significand to [1.0, 2.0)
# Bring significand into range [1.0, 2.0)
fsignificand *= 2
expval -= 1

Expand Down Expand Up @@ -100,8 +103,14 @@ def round_float(fi: FormatInfo, v: float, rnd=RoundMode.TiesToEven, sat=False) -
isignificand += 1
else:
# All other modes tie to even
if _isodd(isignificand):
isignificand += 1
if fi.precision == 1:
# No significand bits
assert (isignificand == 1) or (isignificand == 0)
if _isodd(biased_exp):
expval += 1
else:
if _isodd(isignificand):
isignificand += 1

result = isignificand * (2.0**expval)

Expand All @@ -128,3 +137,82 @@ def round_float(fi: FormatInfo, v: float, rnd=RoundMode.TiesToEven, sat=False) -
result = -result

return result


def encode_float(fi: FormatInfo, v: float) -> int:
"""
Encode input to the given :py:class:`FormatInfo`.
Will round toward zero if v is not in the value set.
Will saturate to inf, nan, fi.max in order of precedence.
Encode -0 to 0 if not fi.has_nz
For other roundings, and saturations, call round_float first.
:return: The integer code point
:rtype: int
"""

# Format Constants
k = fi.bits
p = fi.precision
t = p - 1

# Encode
if np.isnan(v):
return fi.code_of_nan

# Overflow/underflow
if v > fi.max:
return (
fi.code_of_posinf
if fi.has_infs
else fi.code_of_nan if fi.num_nans > 0 else fi.max
)
if v < fi.min:
return (
fi.code_of_neginf
if fi.has_infs
else fi.code_of_nan if fi.num_nans > 0 else fi.min
)

# Finite values
sign = np.signbit(v)
vpos = -v if sign else v

if vpos <= fi.smallest_subnormal / 2:
isig = 0
biased_exp = 0
else:
assert fi.bits < 64 # TODO: check implementation if fi is binary64
sig, exp = np.frexp(vpos)
# sig in range [0.5, 1)
sig *= 2
exp -= 1
# now sig in range [1, 2)

biased_exp = exp + fi.expBias
if biased_exp < 1:
# subnormal
sig *= 2.0 ** (biased_exp - 1)
biased_exp = 0
assert vpos == sig * 2 ** (1 - fi.expBias)
else:
if sig > 0:
sig -= 1.0

isig = math.floor(sig * 2**t)

# Zero
if isig == 0 and biased_exp == 0:
if sign and fi.has_nz:
return fi.code_of_negzero
else:
return fi.code_of_zero

# Nonzero
assert isig < 2**t
assert biased_exp < 2**fi.expBits

ival = (sign << (k - 1)) | (biased_exp << t) | (isig << 0)

return ival
Loading

0 comments on commit ce20d54

Please sign in to comment.