From 1955e198636a0bf547a324bdf8bcd0b8327159c5 Mon Sep 17 00:00:00 2001 From: Andrew Fitzgibbon Date: Fri, 22 Mar 2024 18:05:06 +0000 Subject: [PATCH 1/3] Remove unused fields --- src/gfloat/decode.py | 4 +-- src/gfloat/types.py | 71 ++++++++++++++++++++------------------------ 2 files changed, 33 insertions(+), 42 deletions(-) diff --git a/src/gfloat/decode.py b/src/gfloat/decode.py index 4c30325..efc8d1f 100644 --- a/src/gfloat/decode.py +++ b/src/gfloat/decode.py @@ -84,6 +84,4 @@ def decode_float(fi: FormatInfo, i: int) -> FloatValue: else: fclass = FloatClass.INFINITE - return FloatValue( - i, fval, val, exp, expval, significand, fsignificand, signbit, fclass, fi - ) + return FloatValue(i, fval, exp, expval, significand, fsignificand, signbit, fclass) diff --git a/src/gfloat/types.py b/src/gfloat/types.py index 13ca2c2..671f839 100644 --- a/src/gfloat/types.py +++ b/src/gfloat/types.py @@ -19,6 +19,38 @@ class RoundMode(Enum): TiesToAway = 5 #: Round to nearest, ties away from zero +class FloatClass(Enum): + """ + Enum for the classification of a FloatValue. + """ + + NORMAL = 1 #: A positive or negative normalized non-zero value + SUBNORMAL = 2 #: A positive or negative subnormal value + ZERO = 3 #: A positive or negative zero value + INFINITE = 4 #: A positive or negative infinity (+/-Inf) + NAN = 5 #: Not a Number (NaN) + + +@dataclass +class FloatValue: + """ + A floating-point value decoded in great detail. + """ + + ival: int #: Integer code point + + #: Value. Assumed to be exactly round-trippable to python float. + #: This is true for all <64bit formats known in 2023. 
+ fval: float + + exp: int #: Raw exponent without bias + expval: int #: Exponent, bias subtracted + significand: int #: Significand as an integer + fsignificand: float #: Significand as a float in the range [0,2) + signbit: int #: Sign bit: 1 => negative, 0 => positive + fclass: FloatClass #: See FloatClass + + @dataclass class FormatInfo: """ @@ -233,42 +265,3 @@ def min(self) -> float: # The smallest positive floating point number with 0 as leading bit in # the mantissa following IEEE-754. # """ - - -class FloatClass(Enum): - """ - Enum for the classification of a FloatValue. - """ - - NORMAL = 1 #: A positive or negative normalized non-zero value - SUBNORMAL = 2 #: A positive or negative subnormal value - ZERO = 3 #: A positive or negative zero value - INFINITE = 4 #: A positive or negative infinity (+/-Inf) - NAN = 5 #: Not a Number (NaN) - - -@dataclass -class FloatValue: - """ - A floating-point value decoded in great detail. - """ - - ival: int #: Integer code point - - #: Value. Assumed to be exactly round-trippable to python float. - #: This is true for all <64bit formats known in 2023. 
- fval: float - - val_raw: float #: Value, assuming all code points finite - exp: int #: Raw exponent without bias - expval: int #: Exponent, bias subtracted - significand: int #: Significand as an integer - fsignificand: float #: Significand as a float in the range [0,2) - signbit: int #: Sign bit: 1 => negative, 0 => positive - fclass: FloatClass #: See FloatClass - fi: FormatInfo # Backlink to FormatInfo - - @property - def signstr(self): - """Return "+" or "-" according to signbit""" - return "-" if self.signbit else "+" From 878726adea24ec748bddda31878a033501326f00 Mon Sep 17 00:00:00 2001 From: Andrew Fitzgibbon Date: Mon, 25 Mar 2024 10:30:00 +0000 Subject: [PATCH 2/3] Add encode --- README.md | 3 +- src/gfloat/__init__.py | 2 +- src/gfloat/decode.py | 4 +- src/gfloat/formats.py | 21 +++++++++ src/gfloat/round.py | 104 +++++++++++++++++++++++++++++++++++++---- src/gfloat/types.py | 94 +++++++++++++++++++++++++++++++------ test/test_decode.py | 28 +++++++++++ test/test_encode.py | 27 +++++++++++ test/test_finfo.py | 6 +++ test/test_round.py | 5 +- 10 files changed, 266 insertions(+), 28 deletions(-) create mode 100644 test/test_encode.py diff --git a/README.md b/README.md index 65b0e1f..7806895 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,8 @@ cd .. #### Pushing ``` rm -rf dist -python3 -m build +pip install build twine +python -m build echo __token__ | twine upload --repository pypi dist/* --verbose ``` diff --git a/src/gfloat/__init__.py b/src/gfloat/__init__.py index b4463bc..dc33b85 100644 --- a/src/gfloat/__init__.py +++ b/src/gfloat/__init__.py @@ -2,7 +2,7 @@ from .types import FormatInfo, FloatClass, FloatValue, RoundMode from .decode import decode_float -from .round import round_float +from .round import round_float, encode_float import gfloat.formats # Don't automatically import from .formats. 
diff --git a/src/gfloat/decode.py b/src/gfloat/decode.py index efc8d1f..8b2d78a 100644 --- a/src/gfloat/decode.py +++ b/src/gfloat/decode.py @@ -49,13 +49,13 @@ def decode_float(fi: FormatInfo, i: int) -> FloatValue: fsignificand = 1.0 + significand * 2**-t # val: the raw value excluding specials - val = sign * fsignificand * 2**expval + val = sign * fsignificand * 2.0**expval # Now overwrite the raw value with specials: Infs, NaN, -0, NaN_0 signed_infinity = -np.inf if signbit else np.inf fval = val - # All-bits-one exponent (ABOE) + # All-bits-special exponent (ABSE) if exp == 2**w - 1: min_i_with_nan = 2 ** (p - 1) - fi.num_high_nans if significand >= min_i_with_nan: diff --git a/src/gfloat/formats.py b/src/gfloat/formats.py index 589d1cb..a179e53 100644 --- a/src/gfloat/formats.py +++ b/src/gfloat/formats.py @@ -91,3 +91,24 @@ def format_info_p3109(precision: int) -> FormatInfo: num_high_nans=0, has_subnormals=True, ) + + +## Collections of formats +p3109_formats = [format_info_p3109(p) for p in range(1, 7)] + +fp8_formats = [ + format_info_ocp_e4m3, + format_info_ocp_e5m2, + *p3109_formats, +] + +fp16_formats = [ + format_info_binary16, + format_info_bfloat16, +] + +all_formats = [ + *fp8_formats, + *fp16_formats, + format_info_binary32, +] diff --git a/src/gfloat/round.py b/src/gfloat/round.py index d70f4f8..2f18035 100644 --- a/src/gfloat/round.py +++ b/src/gfloat/round.py @@ -4,7 +4,8 @@ import numpy as np import math -from .types import FormatInfo, RoundMode +from .types import FormatInfo, RoundMode, FloatValue +from .decode import decode_float def _isodd(v: int): @@ -55,19 +56,21 @@ def round_float(fi: FormatInfo, v: float, rnd=RoundMode.TiesToEven, sat=False) - # Extract sign sign = np.signbit(v) + vpos = -v if sign else v - if v == 0: + if vpos < fi.smallest_subnormal / 2: + # Test against smallest_subnormal to avoid subnormals in frexp below + # Note that this restricts us to types narrower than float64 result = 0 - elif np.isinf(v): + elif 
np.isinf(vpos): result = np.inf else: # Extract significand (mantissa) and exponent - fsignificand, expval = np.frexp(np.abs(v)) - + fsignificand, expval = np.frexp(vpos) assert fsignificand >= 0.5 and fsignificand < 1.0 - # move significand to [1.0, 2.0) + # Bring significand into range [1.0, 2.0) fsignificand *= 2 expval -= 1 @@ -100,8 +103,14 @@ def round_float(fi: FormatInfo, v: float, rnd=RoundMode.TiesToEven, sat=False) - isignificand += 1 else: # All other modes tie to even - if _isodd(isignificand): - isignificand += 1 + if fi.precision == 1: + # No significand bits + assert (isignificand == 1) or (isignificand == 0) + if _isodd(biased_exp): + expval += 1 + else: + if _isodd(isignificand): + isignificand += 1 result = isignificand * (2.0**expval) @@ -128,3 +137,82 @@ def round_float(fi: FormatInfo, v: float, rnd=RoundMode.TiesToEven, sat=False) - result = -result return result + + +def encode_float(fi: FormatInfo, v: float) -> int: + """ + Encode input to the given :py:class:`FormatInfo`. + + Will round toward zero if v is not in the value set. + Will saturate to inf, nan, fi.max in order of precedence. + Encode -0 to 0 if not fi.has_nz + For other roundings, and saturations, call round_float first. 
+ + :return: The integer code point + :rtype: int + """ + + # Format Constants + k = fi.bits + p = fi.precision + t = p - 1 + + # Encode + if np.isnan(v): + return fi.code_of_nan + + # Overflow/underflow + if v > fi.max: + return ( + fi.code_of_posinf + if fi.has_infs + else fi.code_of_nan if fi.num_nans > 0 else fi.max + ) + if v < fi.min: + return ( + fi.code_of_neginf + if fi.has_infs + else fi.code_of_nan if fi.num_nans > 0 else fi.min + ) + + # Finite values + sign = np.signbit(v) + vpos = -v if sign else v + + if vpos <= fi.smallest_subnormal / 2: + isig = 0 + biased_exp = 0 + else: + assert fi.bits < 64 # TODO: check implementation if fi is binary64 + sig, exp = np.frexp(vpos) + # sig in range [0.5, 1) + sig *= 2 + exp -= 1 + # now sig in range [1, 2) + + biased_exp = exp + fi.expBias + if biased_exp < 1: + # subnormal + sig *= 2.0 ** (biased_exp - 1) + biased_exp = 0 + assert vpos == sig * 2 ** (1 - fi.expBias) + else: + if sig > 0: + sig -= 1.0 + + isig = math.floor(sig * 2**t) + + # Zero + if isig == 0 and biased_exp == 0: + if sign and fi.has_nz: + return fi.code_of_negzero + else: + return fi.code_of_zero + + # Nonzero + assert isig < 2**t + assert biased_exp < 2**fi.expBits + + ival = (sign << (k - 1)) | (biased_exp << t) | (isig << 0) + + return ival diff --git a/src/gfloat/types.py b/src/gfloat/types.py index 671f839..640eedd 100644 --- a/src/gfloat/types.py +++ b/src/gfloat/types.py @@ -117,14 +117,6 @@ def expBias(self): exp_for_emax = 2**self.expBits - (2 if all_bits_one_full else 1) return exp_for_emax - self.emax - @property - def num_nans(self): - """The number of code points which decode to NaN""" - return (0 if self.has_nz else 1) + 2 * self.num_high_nans - - def __str__(self): - return f"{self.name}" - # numpy finfo properties @property def bits(self) -> int: @@ -207,6 +199,75 @@ def min(self) -> float: """ return -self.max + @property + def num_nans(self): + """ + The number of code points which decode to NaN + """ + return (0 if 
+ self.has_nz else 1) + 2 * self.num_high_nans + + @property + def code_of_nan(self) -> int: + """ + Return a codepoint for a NaN + """ + if self.num_high_nans > 0: + return 2 ** (self.k) - 1 + if not self.has_nz: + return 2 ** (self.k - 1) + raise ValueError(f"No NaN in {self}") + + @property + def code_of_posinf(self) -> int: + """ + Return a codepoint for positive infinity + """ + if not self.has_infs: + raise ValueError(f"No Inf in {self}") + + return 2 ** (self.k - 1) - 1 - self.num_high_nans + + @property + def code_of_neginf(self) -> int: + """ + Return a codepoint for negative infinity + """ + if not self.has_infs: + raise ValueError(f"No Inf in {self}") + + return 2**self.k - 1 - self.num_high_nans + + @property + def code_of_zero(self) -> int: + """ + Return a codepoint for (non-negative) zero + """ + return 0 + + @property + def code_of_negzero(self) -> int: + """ + Return a codepoint for negative zero + """ + if not self.has_nz: + raise ValueError(f"No negative zero in {self}") + + return 2 ** (self.k - 1) + + @property + def code_of_max(self) -> int: + """ + Return a codepoint for fi.max + """ + return 2 ** (self.k - 1) - self.num_high_nans - self.has_infs - 1 + + @property + def code_of_min(self) -> int: + """ + Return a codepoint for fi.min + """ + return 2**self.k - self.num_high_nans - self.has_infs - 1 + # @property # def minexp(self) -> int: # """ @@ -259,9 +320,14 @@ def min(self) -> float: # the mantissa following IEEE-754 (see Notes). # """ - # @property - # def smallest_subnormal(self) -> float: - # """ - # The smallest positive floating point number with 0 as leading bit in - # the mantissa following IEEE-754. 
+ """ + assert self.has_subnormals, "not implemented" + return 2 ** -(self.expBias + self.tSignificandBits - 1) + + def __str__(self): + return f"{self.name}" diff --git a/test/test_decode.py b/test/test_decode.py index 0518f28..6cc82ad 100644 --- a/test/test_decode.py +++ b/test/test_decode.py @@ -84,6 +84,34 @@ def test_spot_check_bfloat16(): assert np.isnan(dec(0x7FFF)) +@pytest.mark.parametrize("fi", p3109_formats, ids=str) +def test_specials(fi): + assert fi.code_of_nan == 0x80 + assert fi.code_of_zero == 0x00 + assert fi.code_of_posinf == 0x7F + assert fi.code_of_neginf == 0xFF + + +@pytest.mark.parametrize("fi", all_formats, ids=str) +def test_specials_decode(fi): + dec = lambda v: decode_float(fi, v).fval + + assert dec(fi.code_of_zero) == 0 + + if fi.num_nans > 0: + assert np.isnan(dec(fi.code_of_nan)) + + if fi.has_infs: + assert dec(fi.code_of_posinf) == np.inf + assert dec(fi.code_of_neginf) == -np.inf + + assert dec(fi.code_of_max) == fi.max + assert dec(fi.code_of_min) == fi.min + + if fi.has_subnormals: + assert dec(1) == fi.smallest_subnormal + + @pytest.mark.parametrize( "fmt,npfmt,int_dtype", [ diff --git a/test/test_encode.py b/test/test_encode.py new file mode 100644 index 0000000..eeb8419 --- /dev/null +++ b/test/test_encode.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. 
+ +import pytest +import ml_dtypes +import numpy as np + +from gfloat import decode_float, encode_float +from gfloat.formats import * + + +@pytest.mark.parametrize("fi", all_formats, ids=str) +def test_encode(fi): + dec = lambda v: decode_float(fi, v).fval + + if fi.bits <= 8: + step = 1 + elif fi.bits <= 16: + step = 13 + elif fi.bits <= 32: + step = 73013 + + for i in range(0, 2**fi.bits, step): + fv = decode_float(fi, i) + ival = encode_float(fi, fv.fval) + fv2 = decode_float(fi, ival) + assert (i == ival) or np.isnan(fv.fval) + np.testing.assert_equal(fv2.fval, fv.fval) diff --git a/test/test_finfo.py b/test/test_finfo.py index 720b79f..40c9ecc 100644 --- a/test/test_finfo.py +++ b/test/test_finfo.py @@ -24,3 +24,9 @@ def test_finfo(fmt, npfmt): assert fmt.epsneg == ml_dtypes.finfo(npfmt).epsneg assert fmt.max == ml_dtypes.finfo(npfmt).max assert fmt.maxexp == ml_dtypes.finfo(npfmt).maxexp + + +def test_constants(): + assert format_info_p3109(1).smallest_subnormal == 2.0**-62 + assert format_info_p3109(4).smallest_subnormal == 2.0**-10 + assert format_info_p3109(7).smallest_subnormal == 2.0**-6 diff --git a/test/test_round.py b/test/test_round.py index 5e438a7..82f5465 100644 --- a/test/test_round.py +++ b/test/test_round.py @@ -17,7 +17,10 @@ def _mlround(v, dty): def test_round_p3109(): fi = format_info_p3109(4) + assert round_float(fi, 0.0068359375) == 0.0068359375 assert round_float(fi, 0.0029296875) == 0.0029296875 + assert round_float(fi, 0.0078125) == 0.0078125 + assert round_float(fi, 0.017578125) == 0.017578125 assert round_float(fi, 224.0) == 224.0 assert round_float(fi, 240.0) == np.inf @@ -88,8 +91,6 @@ def test_round_e4m3(): assert np.isnan(round_float(fi, np.nan, sat=True)) -p3109_formats = [format_info_p3109(p) for p in range(2, 7)] - some_positive_codepoints = ( 0x00, 0x01, From 2ef20bcb9454a301bc900e51b35a90dad6a90226 Mon Sep 17 00:00:00 2001 From: Andrew Fitzgibbon Date: Mon, 25 Mar 2024 12:13:57 +0000 Subject: [PATCH 3/3] Add encode to index 
--- docs/source/index.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/index.rst b/docs/source/index.rst index c73eca1..120c13f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -33,6 +33,8 @@ API .. autofunction:: decode_float .. autofunction:: round_float +.. autofunction:: encode_float + .. autoclass:: FormatInfo() :members: .. autoclass:: FloatClass()