Skip to content

Commit

Permalink
Merge pull request #36 from graphcore-research/sr-bias-example
Browse files Browse the repository at this point in the history
Add comparison to "SRFast"
  • Loading branch information
awf authored Sep 19, 2024
2 parents ca391b8 + f8ca462 commit 0257255
Show file tree
Hide file tree
Showing 8 changed files with 1,170 additions and 80 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.10"
cache: "pip"

- name: Install requirements
Expand Down
1,019 changes: 994 additions & 25 deletions docs/source/05-stochastic-rounding.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
numpy
more_itertools
2 changes: 1 addition & 1 deletion src/gfloat/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def float_tilde_unless_roundtrip_str(v: float, width: int = 14, d: int = 8) -> s
# it is preceded by a "~" to indicate "approximately equal to"
s = f"{v}"
if len(s) > width:
if abs(v) < 1 and not "e" in s:
if abs(v) < 1 and "e" not in s:
s = f"{v:.{d}f}"
else:
s = f"{v:.{d}}"
Expand Down
47 changes: 34 additions & 13 deletions src/gfloat/round.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def round_float(
p = fi.precision
bias = fi.expBias

if rnd == RoundMode.Stochastic:
if rnd in (RoundMode.Stochastic, RoundMode.StochasticFast):
if srbits >= 2**srnumbits:
raise ValueError(f"srnumbits={srnumbits} >= 2**srnumbits={2**srnumbits}")

Expand Down Expand Up @@ -94,18 +94,39 @@ def round_float(
else (isignificand != 0 and _isodd(expval + bias))
)

if rnd == RoundMode.TowardZero:
should_round_away = False
if rnd == RoundMode.TowardPositive:
should_round_away = not sign and delta > 0
if rnd == RoundMode.TowardNegative:
should_round_away = sign and delta > 0
if rnd == RoundMode.TiesToAway:
should_round_away = delta >= 0.5
if rnd == RoundMode.TiesToEven:
should_round_away = delta > 0.5 or (delta == 0.5 and code_is_odd)
if rnd == RoundMode.Stochastic:
should_round_away = delta > (0.5 + srbits) * 2.0**-srnumbits
match rnd:
case RoundMode.TowardZero:
should_round_away = False
case RoundMode.TowardPositive:
should_round_away = not sign and delta > 0
case RoundMode.TowardNegative:
should_round_away = sign and delta > 0
case RoundMode.TiesToAway:
should_round_away = delta >= 0.5
case RoundMode.TiesToEven:
should_round_away = delta > 0.5 or (delta == 0.5 and code_is_odd)
case RoundMode.Stochastic:
## RTNE delta to srbits
d = delta * 2.0**srnumbits
floord = np.floor(d).astype(np.int64)
d = floord + (
(d - floord > 0.5) or ((d - floord == 0.5) and _isodd(floord))
)

should_round_away = d > srbits
case RoundMode.StochasticOdd:
## RTNO (round to nearest, ties to odd) delta to srbits
d = delta * 2.0**srnumbits
floord = np.floor(d).astype(np.int64)
d = floord + (
(d - floord > 0.5) or ((d - floord == 0.5) and ~_isodd(floord))
)

should_round_away = d > srbits
case RoundMode.StochasticFast:
should_round_away = delta > (0.5 + srbits) * 2.0**-srnumbits
case RoundMode.StochasticFastest:
should_round_away = delta > srbits * 2.0**-srnumbits

if should_round_away:
# This may increase isignificand to 2**p,
Expand Down
75 changes: 53 additions & 22 deletions src/gfloat/round_ndarray.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.

from typing import Optional
from types import ModuleType
from .types import FormatInfo, RoundMode
import numpy as np
import math


def _isodd(v: np.ndarray) -> np.ndarray:
Expand All @@ -15,6 +15,8 @@ def round_ndarray(
v: np.ndarray,
rnd: RoundMode = RoundMode.TiesToEven,
sat: bool = False,
srbits: Optional[np.ndarray] = None,
srnumbits: int = 0,
np: ModuleType = np,
) -> np.ndarray:
"""
Expand All @@ -30,9 +32,12 @@ def round_ndarray(
Args:
fi (FormatInfo): Describes the target format
v (float): Input value to be rounded
v (float array): Input values to be rounded
rnd (RoundMode): Rounding mode to use
sat (bool): Saturation flag: if True, round overflowed values to `fi.max`
srbits (int array): Bits to use for stochastic rounding if rnd is one of the Stochastic modes.
srnumbits (int): How many bits are in srbits. Implies srbits < 2**srnumbits.
np (Module): May be `numpy`, `jax.numpy` or another module cloning numpy
Returns:
Expand Down Expand Up @@ -70,18 +75,43 @@ def round_ndarray(
else:
code_is_odd = (isignificand != 0) & _isodd(expval + bias)

if rnd == RoundMode.TowardPositive:
round_up = ~is_negative & (delta > 0)
elif rnd == RoundMode.TowardNegative:
round_up = is_negative & (delta > 0)
elif rnd == RoundMode.TiesToAway:
round_up = delta >= 0.5
elif rnd == RoundMode.TiesToEven:
round_up = (delta > 0.5) | ((delta == 0.5) & code_is_odd)
else:
round_up = np.zeros_like(delta, dtype=bool)

isignificand = np.where(round_up, isignificand + 1, isignificand)
match rnd:
case RoundMode.TowardZero:
should_round_away = np.zeros_like(delta, dtype=bool)
case RoundMode.TowardPositive:
should_round_away = ~is_negative & (delta > 0)
case RoundMode.TowardNegative:
should_round_away = is_negative & (delta > 0)
case RoundMode.TiesToAway:
should_round_away = delta >= 0.5
case RoundMode.TiesToEven:
should_round_away = (delta > 0.5) | ((delta == 0.5) & code_is_odd)
case RoundMode.Stochastic:
assert srbits is not None
## RTNE delta to srbits
d = delta * 2.0**srnumbits
floord = np.floor(d).astype(np.int64)
dd = d - floord
drnd = floord + (dd > 0.5) + ((dd == 0.5) & _isodd(floord))

should_round_away = drnd > srbits
case RoundMode.StochasticOdd:
assert srbits is not None
## RTNO delta to srbits
d = delta * 2.0**srnumbits
floord = np.floor(d).astype(np.int64)
dd = d - floord
drnd = floord + (dd > 0.5) + ((dd == 0.5) & ~_isodd(floord))

should_round_away = drnd > srbits
case RoundMode.StochasticFast:
assert srbits is not None
should_round_away = delta > (2 * srbits + 1) * 2.0 ** -(1 + srnumbits)
case RoundMode.StochasticFastest:
assert srbits is not None
should_round_away = delta > srbits * 2.0**-srnumbits

isignificand = np.where(should_round_away, isignificand + 1, isignificand)

result = np.where(finite_nonzero, np.ldexp(isignificand, expval), absv)

Expand All @@ -90,14 +120,15 @@ def round_ndarray(
if sat:
result = np.where(result > amax, amax, result)
else:
if rnd == RoundMode.TowardNegative:
put_amax_at = (result > amax) & ~is_negative
elif rnd == RoundMode.TowardPositive:
put_amax_at = (result > amax) & is_negative
elif rnd == RoundMode.TowardZero:
put_amax_at = result > amax
else:
put_amax_at = np.zeros_like(result, dtype=bool)
match rnd:
case RoundMode.TowardNegative:
put_amax_at = (result > amax) & ~is_negative
case RoundMode.TowardPositive:
put_amax_at = (result > amax) & is_negative
case RoundMode.TowardZero:
put_amax_at = result > amax
case _:
put_amax_at = np.zeros_like(result, dtype=bool)

result = np.where(finite_nonzero & put_amax_at, amax, result)

Expand Down
15 changes: 15 additions & 0 deletions src/gfloat/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ class RoundMode(Enum):
TiesToEven = 4 #: Round to nearest, ties to even
TiesToAway = 5 #: Round to nearest, ties away from zero
Stochastic = 6 #: Stochastic rounding
StochasticFast = 7 #: Stochastic rounding - faster, but biased, see [Note 1].
StochasticFastest = 8 #: Stochastic rounding - incorrect, see [Note 1].
StochasticOdd = 9 #: Stochastic rounding, RTNO before comparison


# [Note 1]:
# StochasticFast implements a stochastic rounding scheme that is unbiased in
# infinite precision, but biased when the quantity to be rounded is computed to
# a finite precision.
#
# StochasticFastest implements a stochastic rounding scheme that is biased
# (the rounded value is on average farther from zero than the true value).
#
# With a lot of SRbits (say 8 or more), these biases are negligible, and there
# may be some efficiency advantage in using StochasticFast or StochasticFastest.


class FloatClass(Enum):
Expand Down
89 changes: 71 additions & 18 deletions test/test_round.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pytest

from gfloat import RoundMode, decode_float, round_float, round_ndarray
from gfloat import RoundMode, decode_float, decode_ndarray, round_float, round_ndarray
from gfloat.formats import *


Expand Down Expand Up @@ -428,7 +428,7 @@ def get_vals() -> Iterator[Tuple[float, float]]:
]


def _linterp(a: float, b: float, t: float) -> float:
def _linterp(a, b, t): # type: ignore[no-untyped-def]
return a * (1 - t) + b * t


Expand Down Expand Up @@ -494,13 +494,16 @@ def test_round_roundtrip(round_float: Callable, fi: FormatInfo) -> None:
"v, srnumbits, expected_up",
(
(259, 3, 0.0 / 8),
(259, 5, 1.0 / 32),
(259, 5, 2.0 / 32),
(277, 3, 3.0 / 8),
(288, 3, 0.5),
(311, 3, 7.0 / 8),
),
)
def test_stochastic_rounding(v: float, srnumbits: int, expected_up: float) -> None:
@pytest.mark.parametrize("impl", ("scalar", "array"))
def test_stochastic_rounding(
impl: bool, v: float, srnumbits: int, expected_up: float
) -> None:
fi = format_info_ocp_e5m2

v0 = round_float(fi, v, RoundMode.TowardNegative)
Expand All @@ -510,23 +513,73 @@ def test_stochastic_rounding(v: float, srnumbits: int, expected_up: float) -> No
expected_up_count = expected_up * n

srbits = np.random.randint(0, 2**srnumbits, size=(n,))
count_v1 = 0
for k in range(n):
r = round_float(
fi,
v,
RoundMode.Stochastic,
sat=False,
srbits=srbits[k],
srnumbits=srnumbits,
)
if r == v1:
count_v1 += 1
else:
assert r == v0
if impl == "scalar":
count_v1 = 0
for k in range(n):
r = round_float(
fi,
v,
RoundMode.Stochastic,
sat=False,
srbits=srbits[k],
srnumbits=srnumbits,
)
if r == v1:
count_v1 += 1
else:
assert r == v0
else:
vs = np.full(n, v)
rs = round_ndarray(fi, vs, RoundMode.Stochastic, False, srbits, srnumbits)
assert np.all((rs == v0) | (rs == v1))
count_v1 = np.sum(rs == v1)

print(f"SRBits={srnumbits}, observed = {count_v1}, expected = {expected_up_count} ")
# e.g. if expected is 1250/10000, want to be within 0.5,1.5
# this is loose, but should still catch logic errors
atol = n * 2.0 ** (-1 - srnumbits)
np.testing.assert_allclose(count_v1, expected_up_count, atol=atol)


@pytest.mark.parametrize(
    "rnd",
    (RoundMode.Stochastic, RoundMode.StochasticFast, RoundMode.StochasticFastest),
)
def test_stochastic_rounding_scalar_eq_array(rnd: RoundMode) -> None:
    """
    Check that `round_ndarray` and `round_float` agree under stochastic rounding.

    Inputs are blends (via `_linterp`) between the decoded values of code
    points ``k`` and ``k + 1`` of the `format_info_p3109(3)` format, keeping
    only pairs where both decoded values are finite.  Every `srbits` value
    representable in `srnumbits` bits is tried for each blend factor.

    Args:
        rnd (RoundMode): The stochastic rounding mode under test.
    """
    fi = format_info_p3109(3)

    codes = np.arange(255)
    lower = decode_ndarray(fi, codes)
    upper = decode_ndarray(fi, codes + 1)

    # Discard pairs where either endpoint is not finite
    both_finite = np.isfinite(lower) & np.isfinite(upper)
    lower = lower[both_finite]
    upper = upper[both_finite]

    srnumbits = 3
    blend_factors = (0, 0.3, 0.5, 0.6, 0.9, 1.25)

    for srbits in range(2**srnumbits):
        for alpha in blend_factors:
            vals = _linterp(lower, upper, alpha)
            assert np.isfinite(vals).all()

            # Array implementation: one call over all inputs
            val_array = round_ndarray(
                fi,
                vals,
                rnd,
                sat=False,
                srbits=np.asarray(srbits),
                srnumbits=srnumbits,
            )

            # Scalar implementation: one call per input, same srbits
            val_scalar = []
            for x in vals:
                rounded = round_float(
                    fi,
                    x,
                    rnd,
                    sat=False,
                    srbits=srbits,
                    srnumbits=srnumbits,
                )
                val_scalar.append(rounded)

            # The endpoint-membership check is only applied for blend factors below 1
            if alpha < 1.0:
                at_lower = val_array == lower
                at_upper = val_array == upper
                assert (at_lower | at_upper).all()

            np.testing.assert_equal(val_array, val_scalar)

0 comments on commit 0257255

Please sign in to comment.