Skip to content

Commit

Permalink
Merge pull request #8 from graphcore-research/Add-E8M0
Browse files Browse the repository at this point in the history
Add block formats (e.g. OCP MX)
Add OCP E8M0 type. This is an unsigned format, so we add is_signed to FormatInfo.
Add OCP INT8 type. This represents the significand as twos-complement, so we add is_twos_complement to FormatInfo.
  • Loading branch information
awf authored May 1, 2024
2 parents e16ac85 + 9273680 commit 6c13686
Show file tree
Hide file tree
Showing 14 changed files with 488 additions and 77 deletions.
44 changes: 22 additions & 22 deletions docs/source/01-decode.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -39,16 +39,16 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"FormatInfo(name='ocp_e5m2', k=8, precision=3, emax=15, has_nz=True, has_infs=True, num_high_nans=3, has_subnormals=True)"
"FormatInfo(name='ocp_e5m2', k=8, precision=3, emax=15, has_nz=True, has_infs=True, num_high_nans=3, has_subnormals=True, is_signed=True, is_twos_complement=False)"
]
},
"execution_count": 2,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -67,7 +67,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 15,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -258,7 +258,7 @@
"[256 rows x 7 columns]"
]
},
"execution_count": 3,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -281,22 +281,22 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exponent bias 7 15\n",
"emax 8 15\n",
"Infinities 0 2\n",
"Number of NaNs 2 6\n",
"Number of zeros 2 2\n",
"Max normal number 448.0 57344.0\n",
"Min normal number 0.015625 6.103515625e-05\n",
"Min subnormal number 0.001953125 1.52587890625e-05\n",
"Dynamic range (binades) 18 32\n"
"Exponent bias 7 15 16\n",
"emax 8 15 15\n",
"Infinities 0 2 2\n",
"Number of NaNs 2 6 1\n",
"Number of zeros 2 2 1\n",
"Max normal number 448.0 57344.0 49152.0\n",
"Min normal number 0.015625 6.103515625e-05 3.0517578125e-05\n",
"Min subnormal number 0.001953125 1.52587890625e-05 7.62939453125e-06\n",
"Dynamic range (binades) 18 32 33\n"
]
}
],
Expand All @@ -306,17 +306,17 @@
"\n",
"\n",
"for prop, probe in (\n",
" (\"Max exponent (emax) \", lambda fi: fi.emax),\n",
" (\"Exponent bias \", lambda fi: fi.expBias),\n",
" (\"emax \", lambda fi: fi.emax),\n",
" (\"Infinities \", lambda fi: 2 * fi.has_infs),\n",
" (\"Infinities \", lambda fi: 2 * int(fi.has_infs)),\n",
" (\"Number of NaNs \", lambda fi: fi.num_nans),\n",
" (\"Number of zeros \", lambda fi: 1 + fi.has_nz),\n",
" (\"Number of zeros \", lambda fi: int(fi.has_zero) + int(fi.has_nz)),\n",
" (\"Max normal number \", lambda fi: fi.max),\n",
" (\"Min normal number \", lambda fi: fi.smallest_normal),\n",
" (\"Min subnormal number \", lambda fi: fi.smallest_subnormal),\n",
" (\"Dynamic range (binades)\", lambda x: round(compute_dynamic_range(x))),\n",
"):\n",
" print(f\"{prop} {probe(format_info_ocp_e4m3):<20} {probe(format_info_ocp_e5m2)}\")"
" print(f\"{prop} {probe(format_info_ocp_e4m3):<20} {probe(format_info_ocp_e5m2):<20} {probe(format_info_p3109(3))}\")"
]
},
{
Expand All @@ -331,7 +331,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -350,7 +350,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 18,
"metadata": {},
"outputs": [
{
Expand Down
4 changes: 4 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ API
.. autofunction:: round_float
.. autofunction:: encode_float

.. autofunction:: decode_block
.. autofunction:: encode_block


.. autoclass:: FormatInfo()
:members:
.. autoclass:: FloatClass()
Expand Down
19 changes: 18 additions & 1 deletion docs/source/formats.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,29 @@ Defined Formats

.. module:: gfloat.formats

IEEE 754 Formats
----------------

.. autodata:: format_info_binary32
.. autodata:: format_info_binary16

BFloat16
----------------

.. autodata:: format_info_bfloat16

Open Compute Platform (OCP) Formats
-----------------------------------

.. autodata:: format_info_ocp_e5m2
.. autodata:: format_info_ocp_e4m3
.. autofunction:: format_info_p3109
.. autodata:: format_info_ocp_e3m2
.. autodata:: format_info_ocp_e2m3
.. autodata:: format_info_ocp_e2m1
.. autodata:: format_info_ocp_e8m0
.. autodata:: format_info_ocp_int8

IEEE WG P3109 Formats
---------------------

.. autofunction:: format_info_p3109
12 changes: 8 additions & 4 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
.. note::

Check the version number of this documentation against the `gfloat` version
you are using. "Latest" refers to the head on https://github.com/graphcore-research/gfloat,
while pypi versions installed using `pip install` will have corresponding `vX.Y.Z` tags.

GFloat: Generic floating point formats in Python
================================================


GFloat is designed to allow experimentation with a variety of floating-point
formats in Python. Formats are parameterized by the primary IEEE-754 parameters
of:
Expand All @@ -12,16 +16,16 @@ of:
* Maximum exponent (emax)

with additional fields defining the encoding of infinities, Not-a-number (NaN) values,
and negative zero, among others.
and negative zero, among others (see :class:`gfloat.FormatInfo`.)

This allows an implementation of generic floating point encode/decode logic,
handling various current and proposed floating point types:

- `IEEE 754 <https://en.wikipedia.org/wiki/IEEE_754>`_: Binary16, Binary32
- `OCP Float8 <https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf>`_: E5M2, E4M3
- `OCP Float8 <https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf>`_: E5M2, E4M3, and MX formats
- `IEEE WG P3109 <https://github.com/awf/P3109-Public/blob/main/Shared%20Reports/P3109%20WG%20Interim%20report.pdf>`_: P{p} for p in 1..7

The library strongly favours readability and extensibility over speed - for fast
The library favours readability and extensibility over speed - for fast
implementations of these datatypes see, for example,
`ml_dtypes <https://github.com/jax-ml/ml_dtypes>`_,
`bitstring <https://github.com/scott-griffiths/bitstring>`_,
Expand Down
1 change: 1 addition & 0 deletions src/gfloat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.

from .block import BlockFormatInfo, decode_block, encode_block
from .decode import decode_float
from .round import encode_float, round_float
from .types import FloatClass, FloatValue, FormatInfo, RoundMode
Expand Down
120 changes: 120 additions & 0 deletions src/gfloat/block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.

# Block floating point formats
# https://en.wikipedia.org/wiki/Block_floating_point

from dataclasses import dataclass
from typing import Iterable, Iterator

from .decode import decode_float
from .round import encode_float, round_float
from .types import FloatValue, FormatInfo


@dataclass
class BlockFormatInfo:

#: Short name for the format, e.g. BlockFP8
name: str

#: Element data type
etype: FormatInfo

#: Scaling block size
k: int

#: Scale datatype
stype: FormatInfo

#: ## Derived values

@property
def element_bits(self) -> int:
"""The number of bits in each element, d"""
return self.etype.k

@property
def scale_bits(self) -> int:
"""The number of bits in the scale, w"""
return self.stype.k

@property
def block_size_bytes(self) -> int:
"""The number of bytes in a block"""
bits = self.element_bits * self.k + self.scale_bits
assert bits % 8 == 0
return bits // 8

def __str__(self):
return f"{self.name}"


def decode_block(fi: BlockFormatInfo, block: Iterable[int]) -> Iterable[float]:
"""
Decode a :paramref:`block` of integer codepoints in Block Format :paramref:`fi`
The scale is encoded in the first value of :paramref:`block`,
with the remaining values encoding the block elements.
The size of the iterable is not checked against the format descriptor.
:param fi: Describes the block format
:type fi: BlockFormatInfo
:param block: Input block
:type block: Iterable[int]
:return: A sequence of floats representing the encoded values.
:rtype: Iterable[float]
"""
it = iter(block)

scale_encoding = next(it)
scale = decode_float(fi.stype, scale_encoding).fval

for val_encoding in it:
val = scale * decode_float(fi.etype, val_encoding).fval
yield val

# TODO: Assert length of block was k+1? Messy unless block is len()able


def encode_block(
fi: BlockFormatInfo, scale: float, vals: Iterable[float]
) -> Iterable[int]:
"""
Encode a :paramref:`block` of bytes into block Format descibed by :paramref:`fi`
The :paramref:`scale` is explicitly passed, and is converted to `1/(1/scale)`
before rounding to the target format.
It is checked for overflow in the target format,
and will raise an exception if it does.
:param fi: Describes the target block format
:type fi: BlockFormatInfo
:param scale: Scale to be recorded in the block
:type scale: float
:param vals: Input block
:type vals: Iterable[int]
:return: A sequence of ints representing the encoded values.
:rtype: Iterable[int]
:raises ValueError: The scale overflows the target scale encoding format.
"""
recip_scale = 1 / scale
scale = 1 / recip_scale

if scale > fi.stype.max:
raise ValueError(f"Scaled {scale} too large for {fi.stype}")

enc = lambda ty, x: encode_float(ty, round_float(ty, x))

yield enc(fi.stype, scale)

for val in vals:
yield enc(fi.etype, recip_scale * val)
26 changes: 17 additions & 9 deletions src/gfloat/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,30 @@ def decode_float(fi: FormatInfo, i: int) -> FloatValue:
"""
k = fi.k
p = fi.precision
t = p - 1 # trailing significand field width
w = k - p
t = p - 1 # Trailing significand field width
num_signbits = 1 if fi.is_signed else 0
w = k - t - num_signbits # Exponent field width

if i < 0 or i >= 2**k:
raise ValueError(f"Code point {i} not in range [0, 2**{k})")

signmask = 1 << (k - 1)
signbit = 1 if i & signmask else 0
sign = -1 if signbit else 1
if fi.is_signed:
signmask = 1 << (k - 1)
signbit = 1 if i & signmask else 0
sign = -1 if signbit else 1
else:
signmask = None
signbit = 0
sign = 1

exp = (i & (signmask - 1)) >> t
exp = (i >> t) & ((1 << w) - 1)
significand = i & ((1 << t) - 1)
if fi.is_twos_complement and signbit:
significand = (1 << t) - significand

expBias = fi.expBias

iszero = exp == 0 and significand == 0
iszero = exp == 0 and significand == 0 and fi.has_zero
issubnormal = fi.has_subnormals and (exp == 0) and (significand != 0)
isnormal = not iszero and not issubnormal
if iszero or issubnormal:
Expand All @@ -56,15 +64,15 @@ def decode_float(fi: FormatInfo, i: int) -> FloatValue:

fval = val
# All-bits-special exponent (ABSE)
if exp == 2**w - 1:
if w > 0 and exp == 2**w - 1:
min_i_with_nan = 2 ** (p - 1) - fi.num_high_nans
if significand >= min_i_with_nan:
fval = np.nan
if fi.has_infs and significand == min_i_with_nan - 1:
fval = signed_infinity

# Negative zero or NaN
if i == signmask:
if iszero and i == signmask and not fi.is_twos_complement:
if fi.has_nz:
fval = -0.0
else:
Expand Down
Loading

0 comments on commit 6c13686

Please sign in to comment.