diff --git a/doc/README.md b/doc/README.md
index 2413aa04c..06534248e 100644
--- a/doc/README.md
+++ b/doc/README.md
@@ -31,9 +31,14 @@ Some in-development items will have opened issues, as well. Feel free to create
- Sort
- [Bitonic sort](./components/sort.md#bitonic-sort)
- Arithmetic
- - [Prefix Trees](./components/parallel_prefix_operations.md)
+ - [Prefix Trees](./components/parallel_prefix_operations.md): several efficient components that leverage a variety of parallel prefix trees, such as Ripple, Kogge-Stone, Sklansky, and Brent-Kung tree types.
+ - [Priority Encoder](./components/parallel_prefix_operations.md)
+ - [Or-scan](./components/parallel_prefix_operations.md)
+ - [Incrementer](./components/parallel_prefix_operations.md)
+ - [Decrementer](./components/parallel_prefix_operations.md)
- [Adders](./components/adder.md)
- [Sign Magnitude Adder](./components/adder.md#ripple-carry-adder)
+ - [Parallel Prefix Adder](./components/parallel_prefix_operations.md)
- Subtractors
- [One's Complement Adder Subtractor](./components/adder.md#ones-complement-adder-subtractor)
- Multipliers
@@ -47,11 +52,13 @@ Some in-development items will have opened issues, as well. Feel free to create
- Square root
- Inverse square root
- Floating point
+ - [Floating-Point Value Types](./components/floating_point.md)
- Double (64-bit)
- Float (32-bit)
- BFloat16 (16-bit)
- BFloat8 (8-bit)
- BFloat4 (4-bit)
+ - [Simple Floating-Point Adder](./components/floating_point.md#floatingpointadder)
- Fixed point
- Binary-Coded Decimal (BCD)
- [Rotate](./components/rotate.md)
diff --git a/doc/components/floating_point.md b/doc/components/floating_point.md
new file mode 100644
index 000000000..8f525226c
--- /dev/null
+++ b/doc/components/floating_point.md
@@ -0,0 +1,37 @@
+# Floating-Point Components
+
+Floating-point operations require meticulous precision and are governed by standards like [IEEE-754](). To support floating-point components, we have created a parallel to the `Logic`/`LogicValue` pair from [ROHD](). Here, `FloatingPoint` is the `Logic` wire in a component that carries `FloatingPointValue` literal values. An important distinction is that these classes are parameterized, so floating-point values of arbitrary size can be created.
+
+## FloatingPointValue
+
+The `FloatingPointValue` class comprises the sign, exponent, and mantissa `LogicValue`s that represent a floating-point number. `FloatingPointValue`s can be converted to and from Dart native `double`s, as well as constructed from integer and string representations of their fields. They can be operated on (+, -, *, /) and compared.
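To make the three fields concrete, here is a minimal sketch in plain Dart that splits a native `double` into its sign, biased exponent, and mantissa bits. The `decodeDouble` helper is hypothetical and is not part of this library; it only illustrates the 1/11/52-bit layout that a 64-bit value uses.

```dart
import 'dart:typed_data';

/// Hypothetical helper, not part of ROHD-HCL: split a Dart double into its
/// IEEE-754 binary64 sign, biased exponent, and mantissa fields.
({int sign, int exponent, int mantissa}) decodeDouble(double x) {
  final bytes = ByteData(8)..setFloat64(0, x); // big-endian by default
  final hi = bytes.getUint32(0); // sign, exponent, top 20 mantissa bits
  final lo = bytes.getUint32(4); // low 32 mantissa bits
  return (
    sign: hi >> 31,
    exponent: (hi >> 20) & 0x7ff,
    mantissa: ((hi & 0xfffff) << 32) | lo,
  );
}

void main() {
  final one = decodeDouble(1.0);
  print('${one.sign} ${one.exponent} ${one.mantissa}'); // 0 1023 0
}
```

The stored exponent of `1.0` is 1023, which is the bias of an 11-bit exponent field; the mantissa field is zero because the leading 1 of a normal number is implicit.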
+
+A `FloatingPointValue` has a mantissa in $[0,2)$ with
+
+$$0 \le exponent \le maxExponent$$
+
+A normal `FloatingPointValue` (one for which `isNormal` is true) has
+
+$$minExponent \le exponent \le maxExponent$$
+
+and a mantissa in the range $[1,2)$. Subnormal numbers are represented with a zero exponent, and the leading zeros in the mantissa capture the negative exponent value.
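The exponent limits follow directly from the exponent width. The sketch below, in plain Dart, mirrors the bias and exponent-range formulas used by the value class and evaluates a field triple as a real number. The function names here are hypothetical stand-ins, and infinities and NaNs (all-ones exponent) are deliberately not handled.

```dart
import 'dart:math';

/// Bias of an exponent field that is [w] bits wide: 2^(w-1) - 1.
int bias(int w) => (1 << (w - 1)) - 1;

/// Smallest and largest unbiased exponents of a normal number.
int minExponent(int w) => 1 - bias(w);
int maxExponent(int w) => bias(w);

/// Hypothetical helper: real value of (sign, stored exponent, mantissa field).
/// Does not handle the all-ones exponent (infinity and NaN).
double fieldsToDouble(
    int sign, int exponent, int mantissa, int expWidth, int manWidth) {
  final fraction = mantissa / (1 << manWidth); // in [0, 1)
  final isNormal = exponent != 0;
  // Normal: implicit leading 1. Subnormal: leading 0 and minimum exponent.
  final significand = isNormal ? 1.0 + fraction : fraction;
  final unbiased = isNormal ? exponent - bias(expWidth) : minExponent(expWidth);
  final magnitude = significand * pow(2.0, unbiased);
  return sign == 1 ? -magnitude : magnitude;
}

void main() {
  print('${bias(8)} ${minExponent(8)} ${maxExponent(8)}'); // 127 -126 127
  print(fieldsToDouble(0, 7, 4, 4, 3)); // 1.5
}
```

With a 4-bit exponent and 3-bit mantissa, the smallest positive subnormal (stored exponent 0, mantissa 1) evaluates to $2^{-9}$, which shows how leading zeros in the mantissa extend the range below `minExponent`.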
+
+Constants are provided for the IEEE corner cases of the field of floating-point values for a given size of `FloatingPointValue`: infinities, zeros, and the limits of the normal (mantissa in the range $[1,2)$) and subnormal (zero exponent, mantissa $<1$) ranges.
+
+Appropriate string representations, comparison operations, and operators are available. The usefulness of `FloatingPointValue` is in the testing of `FloatingPoint` components, where we can leverage the abstraction of a floating-point value type to drive and compare floating-point values operated upon by floating-point components.
+
+As 32-bit single-precision and 64-bit double-precision floating-point types are the most common, we provide `FloatingPoint32Value` and `FloatingPoint64Value` subclasses with direct converters from Dart native `double`.
+
+Finally, we have a `FloatingPointValue` random generator for testing purposes, generating valid floating-point types, optionally constrained to normal range (mantissa in $[1, 2)$).
+
+## FloatingPoint
+
+The `FloatingPoint` type is a `LogicStructure` which comprises the `Logic` bits for the sign, exponent, and mantissa used in hardware floating-point. These types are provided to simplify and abstract the declaration and manipulation of floating-point types in hardware. This type is parameterized like `FloatingPointValue`, for exponent and mantissa width.
+
+Again, like `FloatingPointValue`, `FloatingPoint64` and `FloatingPoint32` subclasses are provided as these are the most common floating-point number types.
+
+## FloatingPointAdder
+
+A very basic `FloatingPointAdder` component is available which does not perform any rounding. It takes two `FloatingPoint` `LogicStructure`s and adds them, returning a normalized `FloatingPoint` on the output. An option on input is the type of `ParallelPrefixTree` used in the internal addition of the mantissas.
+
+Currently, the `FloatingPointAdder` is only approximately accurate (it performs no rounding) and is not optimized for circuit performance; it provides only the key functionalities of alignment, addition, and normalization. Still, this component is a starting point for more realistic floating-point components that leverage the logical `FloatingPoint` and literal `FloatingPointValue` type abstractions.
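The three steps named above (alignment, addition, normalization) can be sketched as a software model in plain Dart. This is a toy under stated assumptions, not the hardware component: it handles only two positive normal inputs, works on integer field values, and truncates instead of rounding. The name `addNoRound` is hypothetical.

```dart
/// Toy model (hypothetical, not the library's adder): add two positive normal
/// numbers given as (stored exponent, mantissa field), truncating low bits.
(int, int) addNoRound(
    int expA, int manA, int expB, int manB, int mantissaWidth) {
  // Make A the larger operand so the shift amount is never negative.
  if (expA < expB || (expA == expB && manA < manB)) {
    (expA, manA, expB, manB) = (expB, manB, expA, manA);
  }
  final hidden = 1 << mantissaWidth; // implicit leading 1 of a normal number
  final sigA = hidden | manA;
  // Alignment: shift the smaller significand right, dropping low bits.
  final sigB = (hidden | manB) >> (expA - expB);
  // Addition.
  var sum = sigA + sigB;
  var exp = expA;
  // Normalization: a carry out means the result needs one right shift.
  if (sum >= (hidden << 1)) {
    sum >>= 1;
    exp += 1;
  }
  return (exp, sum & (hidden - 1)); // strip the implicit leading 1
}

void main() {
  // 1.5 + 1.5 with a 3-bit mantissa and stored exponent 7.
  final (e, m) = addNoRound(7, 4, 7, 4, 3);
  print('$e $m'); // 8 4
}
```

The result has stored exponent 8 and mantissa field `100`, that is $1.5 \times 2^1 = 3.0$. Subtraction of unlike signs additionally needs the leading-one detection that the component performs with a priority encoder, which this sketch leaves out.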
diff --git a/doc/components/multiplier.md b/doc/components/multiplier.md
index e98ab5dac..14d2b8fea 100644
--- a/doc/components/multiplier.md
+++ b/doc/components/multiplier.md
@@ -109,7 +109,7 @@ Here is an example of use of the `CompressionTreeMultiplier`:
## Compression Tree Multiply Accumulate
-A compression tree multiply accumulate is similar to a compress tree
+A compression tree multiply-accumulate is similar to a compression tree
multiplier, but it inserts an additional addend into the compression
tree to allow for accumulation into this third input.
diff --git a/doc/components/multiplier_components.md b/doc/components/multiplier_components.md
index 2d084b6f8..af1b3e08e 100644
--- a/doc/components/multiplier_components.md
+++ b/doc/components/multiplier_components.md
@@ -103,7 +103,7 @@ The partial product generator produces a set of addends in shifted position to b
An argument to the `PartialProductGenerator` is the `RadixEncoder` to be used. The [`RadixEncoder`] takes a single argument which is the radix (power of 2) to be used.
-Instead of using the 1's in the multiplier to select shifted versions of the multiplicand to add in a partial product matrix, radix-encoding will encode multiples of the multiplicand by examining adjacent bits of the multiplier. For radix-4, for example, for a multiplier of size M, instead of M rows of partial products, M/2 rows are formed by selecting from multiples [-2, -1, 0, 1, 2] of the multiplicand. These multiples are computed from an 3 bit slices, overlapped by 1 bit, of the multiplier. Higher radices use wider slices of the multiplier to encode fewer multiples and therefore fewer rows.
+Instead of using the 1's in the multiplier to select shifted versions of the multiplicand to add in a partial product matrix, radix-encoding will encode multiples of the multiplicand by examining adjacent bits of the multiplier. For radix-4, for example, for a multiplier of size M, instead of M rows of partial products, M/2 rows are formed by selecting from multiples [-2, -1, 0, 1, 2] of the multiplicand. These multiples are computed from 3-bit slices, overlapped by 1 bit, of the multiplier. Higher radixes use wider slices of the multiplier to encode fewer multiples and therefore fewer rows.
| bit_i | bit_i-1 | bit_i-2 | multiple|
|:-----:|:-------:|:-------:|:-------:|
@@ -199,3 +199,40 @@ Finally, we produce the product.
compressor.exractRow(0), compressor.extractRow(1), BrentKung.new);
product <= adder.sum.slice(a.width + b.width - 1, 0);
```
+
+## Utility: Aligned Vector Formatting
+
+We provide an extension on `LogicValue` which permits formatting of binary vectors in an aligned way to help with debugging arithmetic components.
+
+The `vecString` extension provides a basic string printer with an optional `header` flag for bit numbering. A `prefix` value can be used to specify the name lengths to be used to keep vectors aligned.
+
+`alignHigh` controls the highest (toward MSB) alignment column of the output whereas `alignLow` controls the lower limit (toward the LSB).
+
+`sepPos` is optional and allows you to set a marker for a separator in the number.
+`sepChar` is the separation character you wish to use (do not use '|' with Markdown formatting).
+
+```dart
+ final ref = FloatingPoint64Value.fromDouble(3.14159);
+ print(ref.mantissa
+ .vecString('pi', alignHigh: 55, alignLow: 40, header: true, sepPos: 52));
+```
+
+Produces
+
+```text
+ 54 53 52* 51 50 49 48 47 46 45 44 43 42 41 40
+pi * 1 0 0 1 0 0 1 0 0 0 0 1
+```
+
+The routine also allows for output in Markdown format:
+
+```dart
+ print(ref.mantissa.vecString('pi',
+ alignHigh: 58, alignLow: 40, header: true, sepPos: 52, markDown: true));
+```
+
+producing:
+
+| Name | 54 | 53 | 52* | 51 | 50 | 49 | 48 | 47 | 46 | 45 | 44 | 43 | 42 | 41 | 40 |
+|:--:|:--|:--|:--|:--|:--|:--|:--|:--|:--|:--|:--|:--|:--|:--|:---|
+|pi|||* | 1 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 |
diff --git a/lib/src/arithmetic/arithmetic.dart b/lib/src/arithmetic/arithmetic.dart
index bf52d17b1..549d843ec 100644
--- a/lib/src/arithmetic/arithmetic.dart
+++ b/lib/src/arithmetic/arithmetic.dart
@@ -4,6 +4,7 @@
export 'adder.dart';
export 'carry_save_mutiplier.dart';
export 'divider.dart';
+export 'floating_point/floating_point.dart';
export 'multiplier.dart';
export 'multiplier_lib.dart';
export 'ones_complement_adder.dart';
diff --git a/lib/src/arithmetic/arithmetic_utils.dart b/lib/src/arithmetic/arithmetic_utils.dart
new file mode 100644
index 000000000..a49cec811
--- /dev/null
+++ b/lib/src/arithmetic/arithmetic_utils.dart
@@ -0,0 +1,118 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// arithmetic_utils.dart
+// Utilities for formatting and debugging arithmetic components
+//
+// 2024 August 30
+// Author: Desmond A Kirkpatrick
+// ...
+ for (var col = highLimit; col >= alignLow; col--) {
+ final chars = BigInt.from(col).toString().length + extraSpace;
+ if (sepPos != null && sepPos == col) {
+ str
+ ..write(
+ markDown ? ' $hdrSep' : ' ' * (length - chars + 1 + extraSpace))
+ ..write('$col$sepChar')
+ ..write(markDown ? ' $hdrSep' : '');
+ } else if (sepPos != null && sepPos == col + 1) {
+ if (sepPos == max(alignHigh ?? width, width)) {
+ str
+ ..write(sepChar)
+ ..write(markDown ? ' $hdrSep' : ' ' * (length - chars - 1));
+ }
+ str.write('${' ' * (length - chars + extraSpace + 0)}$col');
+ } else {
+ str
+ ..write(
+ markDown ? ' $hdrSep' : ' ' * (length - chars + 1 + extraSpace))
+ ..write('$col');
+ }
+ }
+ str.write(markDown ? ' $hdrSepEnd\n' : '\n');
+ if (markDown) {
+ str.write('|:--:');
+ for (var col = highLimit; col >= alignLow; col--) {
+ str.write('|:--');
+ }
+ str.write('-|\n');
+ }
+ }
+ const dataSepStart = '|';
+ const dataSep = '| ';
+ const dataSepEnd = '|';
+ final String strPrefix;
+ strPrefix = markDown
+ ? '$dataSepStart $name'
+ : (name.length <= prefix)
+ ? name.padRight(prefix)
+ : name.substring(0, prefix);
+ str
+ ..write(strPrefix)
+ ..write((markDown ? dataSep : ' ' * (length + 1)) *
+ ((alignHigh ?? width) - width));
+ for (var col = alignLow; col < minHigh; col++) {
+ final pos = minHigh - 1 - col + alignLow;
+ final v = this[pos].bitString;
+ if (sepPos != null && sepPos == pos) {
+ str.write(
+ markDown ? ' $dataSep$v $sepChar' : '${' ' * length}$v$sepChar');
+ } else if (sepPos != null && sepPos == pos + 1) {
+ if (sepPos == minHigh) {
+ str.write(sepChar);
+ }
+ str
+ ..write(markDown ? ' $dataSep' : ' ' * (length - 1))
+ ..write(v);
+ } else {
+ str
+ ..write(markDown ? ' $dataSep' : ' ' * length)
+ ..write(v);
+ }
+ }
+ if (markDown) {
+ str.write(' $dataSepEnd');
+ }
+ return str.toString();
+ }
+}
diff --git a/lib/src/arithmetic/evaluate_compressor.dart b/lib/src/arithmetic/evaluate_compressor.dart
index 465743472..b1740d382 100644
--- a/lib/src/arithmetic/evaluate_compressor.dart
+++ b/lib/src/arithmetic/evaluate_compressor.dart
@@ -41,6 +41,7 @@ extension EvaluateLiveColumnCompressor on ColumnCompressor {
}
rowBits.addAll(List.filled(pp.rowShift[row], LogicValue.zero));
final val = rowBits.swizzle().zeroExtend(width).toBigInt();
+
accum += val;
if (printOut) {
ts.write('\t${rowBits.swizzle().zeroExtend(width).bitString} ($val)');
@@ -52,10 +53,6 @@ extension EvaluateLiveColumnCompressor on ColumnCompressor {
}
}
}
- if (printOut) {
- // We need this to be able to debug, but git lint flunks print
- // print(ts);
- }
return (accum.toSigned(width), ts);
}
diff --git a/lib/src/arithmetic/evaluate_partial_product.dart b/lib/src/arithmetic/evaluate_partial_product.dart
index e421d13d3..0fd425385 100644
--- a/lib/src/arithmetic/evaluate_partial_product.dart
+++ b/lib/src/arithmetic/evaluate_partial_product.dart
@@ -66,7 +66,7 @@ extension EvaluateLivePartialProduct on PartialProductGenerator {
str.write(' ' * shortPrefix);
}
} else {
- str.write('$rowStr ${'M='} S= : ');
+ str.write('$rowStr ${'M='} S= : ');
}
final entry = partialProducts[row].reversed.toList();
final prefixCnt =
@@ -104,4 +104,73 @@ extension EvaluateLivePartialProduct on PartialProductGenerator {
}
return str.toString();
}
+
+ /// Return the partial product matrix formatted as a Markdown table
+ String markdown() {
+ final str = StringBuffer();
+
+ final maxW = maxWidth();
+ // print bit position header
+ str.write('| R | M | S');
+ for (var i = maxW - 1; i >= 0; i--) {
+ str.write('| $i ');
+ }
+ str
+ ..write('| bitvector | value|\n')
+ ..write('|:--:' * 3);
+ for (var i = maxW - 1; i >= 0; i--) {
+ str.write('|:--:');
+ }
+ str.write('|:--: |:--:|\n');
+ // Partial product matrix: rows of multiplicand multiples shifted by
+ // rowShift[row]
+ for (var row = 0; row < rows; row++) {
+ final rowStr = (row < 10) ? '0$row' : '$row';
+ if (row < encoder.rows) {
+ final encoding = encoder.getEncoding(row);
+ if (encoding.multiples.value.isValid) {
+ final first = encoding.multiples.value.firstOne() ?? -1;
+ final multiple = first + 1;
+ str.write('|$rowStr| '
+ '$multiple| '
+ '${encoding.sign.value.toInt()}');
+ } else {
+ str.write('| | |');
+ }
+ } else {
+ str.write('|$rowStr | |');
+ }
+ final entry = partialProducts[row].reversed.toList();
+ str.write('| ' * (maxW - (entry.length + rowShift[row])));
+ for (var col = 0; col < entry.length; col++) {
+ str.write('|${entry[col].value.bitString}');
+ }
+ final suffixCnt = rowShift[row];
+ final value = entry.swizzle().value.zeroExtend(maxW) << suffixCnt;
+ final intValue = value.isValid ? value.toBigInt() : BigInt.from(-1);
+ str
+ ..write('| ' * suffixCnt)
+ ..write('| ${value.bitString}')
+ ..write('| ${value.isValid ? intValue : ""}'
+ ' (${value.isValid ? intValue.toSigned(maxW) : ""})|\n');
+ }
+ // Compute and print binary representation from accumulated value
+ // Later: we will compare with a compression tree result
+ str.write('||\n');
+
+ final sum = LogicValue.ofBigInt(evaluate(), maxW);
+ // print out the sum as a MSB-first bitvector
+ str.write('|||');
+ for (final elem in [for (var i = 0; i < maxW; i++) sum[i]].reversed) {
+ str.write('|${elem.toInt()} ');
+ }
+ final val = evaluate();
+ str.write('| ${sum.bitString}| '
+ '${val.toUnsigned(maxW)}');
+ if (isSignExtended) {
+ str.write(' ($val)');
+ }
+ str.write('|\n');
+ return str.toString();
+ }
}
diff --git a/lib/src/arithmetic/floating_point/floating_point.dart b/lib/src/arithmetic/floating_point/floating_point.dart
new file mode 100644
index 000000000..231569572
--- /dev/null
+++ b/lib/src/arithmetic/floating_point/floating_point.dart
@@ -0,0 +1,6 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: BSD-3-Clause
+
+export 'floating_point_adder.dart';
+export 'floating_point_logic.dart';
+export 'floating_point_value.dart';
diff --git a/lib/src/arithmetic/floating_point/floating_point_adder.dart b/lib/src/arithmetic/floating_point/floating_point_adder.dart
new file mode 100644
index 000000000..e6e3f1d20
--- /dev/null
+++ b/lib/src/arithmetic/floating_point/floating_point_adder.dart
@@ -0,0 +1,107 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// floating_point_adder.dart
+// A very basic Floating-point adder component.
+//
+// 2024 August 30
+// Author: Desmond A Kirkpatrick
+ (
+ toSwap.$1.clone()..gets(mux(swap, toSwap.$2, toSwap.$1)),
+ toSwap.$2.clone()..gets(mux(swap, toSwap.$1, toSwap.$2))
+ );
+
+ /// Add two floating point numbers [a] and [b], returning result in [sum]
+ FloatingPointAdder(FloatingPoint a, FloatingPoint b,
+ {ParallelPrefix Function(List<Logic>, Logic Function(Logic, Logic))
+ ppGen = KoggeStone.new,
+ super.name})
+ : exponentWidth = a.exponent.width,
+ mantissaWidth = a.mantissa.width {
+ if (b.exponent.width != exponentWidth ||
+ b.mantissa.width != mantissaWidth) {
+ throw RohdHclException('FloatingPoint widths must match');
+ }
+ a = a.clone()..gets(addInput('a', a, width: a.width));
+ b = b.clone()..gets(addInput('b', b, width: b.width));
+ addOutput('sum', width: _sum.width) <= _sum;
+
+ // Ensure that the larger number is wired as 'a'
+ final doSwap = a.exponent.lt(b.exponent) |
+ (a.exponent.eq(b.exponent) & a.mantissa.lt(b.mantissa)) |
+ ((a.exponent.eq(b.exponent) & a.mantissa.eq(b.mantissa)) & b.sign);
+
+ (a, b) = _swap(doSwap, (a, b));
+
+ final aExp =
+ a.exponent + mux(a.isNormal(), a.zeroExponent(), a.oneExponent());
+ final bExp =
+ b.exponent + mux(b.isNormal(), b.zeroExponent(), b.oneExponent());
+
+ // Align and add mantissas
+ final expDiff = aExp - bExp;
+ // print('${expDiff.value.toInt()} exponent diff');
+ final adder = SignMagnitudeAdder(
+ a.sign,
+ [a.isNormal(), a.mantissa].swizzle(),
+ b.sign,
+ [b.isNormal(), b.mantissa].swizzle() >>> expDiff,
+ (a, b) => ParallelPrefixAdder(a, b, ppGen: ppGen));
+
+ final sum = adder.sum.slice(adder.sum.width - 2, 0);
+ final leadOneE =
+ ParallelPrefixPriorityEncoder(sum.reversed, ppGen: ppGen).out;
+ final leadOne = leadOneE.zeroExtend(exponentWidth);
+
+ // Assemble the output FloatingPoint
+ _sum.sign <= adder.sign;
+ Combinational([
+ If.block([
+ Iff(adder.sum[-1] & a.sign.eq(b.sign), [
+ _sum.mantissa < (sum >> 1).slice(mantissaWidth - 1, 0),
+ _sum.exponent < a.exponent + 1
+ ]),
+ ElseIf(a.exponent.gt(leadOne) & sum.or(), [
+ _sum.mantissa < (sum << leadOne).slice(mantissaWidth - 1, 0),
+ _sum.exponent < a.exponent - leadOne
+ ]),
+ ElseIf(leadOne.eq(0) & sum.or(), [
+ _sum.mantissa < (sum << leadOne).slice(mantissaWidth - 1, 0),
+ _sum.exponent < a.exponent - leadOne + 1
+ ]),
+ Else([
+ // subnormal result
+ _sum.mantissa < sum.slice(mantissaWidth - 1, 0),
+ _sum.exponent < _sum.zeroExponent()
+ ])
+ ])
+ ]);
+ // print('final sum: ${_sum.value.bitString}');
+ }
+}
diff --git a/lib/src/arithmetic/floating_point/floating_point_logic.dart b/lib/src/arithmetic/floating_point/floating_point_logic.dart
new file mode 100644
index 000000000..fcdca3107
--- /dev/null
+++ b/lib/src/arithmetic/floating_point/floating_point_logic.dart
@@ -0,0 +1,102 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// floating_point_logic.dart
+// Implementation of Floating Point objects
+//
+// 2024 April 1
+// Authors:
+// Max Korbel
+// Desmond A Kirkpatrick
+// ...
+ FloatingPoint(
+ exponentWidth: exponent.width,
+ mantissaWidth: mantissa.width,
+ );
+
+ /// Return the [FloatingPointValue]
+ FloatingPointValue get floatingPointValue => FloatingPointValue(
+ sign: sign.value, exponent: exponent.value, mantissa: mantissa.value);
+
+ /// Return a Logic true if this FloatingPoint contains a normal number,
+ /// defined as having mantissa in the range [1,2)
+ Logic isNormal() => exponent.neq(LogicValue.zero.zeroExtend(exponent.width));
+
+ /// Return the zero exponent representation for this type of FloatingPoint
+ Logic zeroExponent() => Const(LogicValue.zero).zeroExtend(exponent.width);
+
+ /// Return the one exponent representation for this type of FloatingPoint
+ Logic oneExponent() => Const(LogicValue.one).zeroExtend(exponent.width);
+
+ @override
+ void put(dynamic val, {bool fill = false}) {
+ if (val is FloatingPointValue) {
+ put(val.value);
+ } else {
+ super.put(val, fill: fill);
+ }
+ }
+}
+
+/// Single floating point representation
+class FloatingPoint32 extends FloatingPoint {
+ /// Construct a 32-bit (single-precision) floating point number
+ FloatingPoint32()
+ : super(
+ exponentWidth: FloatingPoint32Value.exponentWidth,
+ mantissaWidth: FloatingPoint32Value.mantissaWidth);
+}
+
+/// Double floating point representation
+class FloatingPoint64 extends FloatingPoint {
+ /// Construct a 64-bit (double-precision) floating point number
+ FloatingPoint64()
+ : super(
+ exponentWidth: FloatingPoint64Value.exponentWidth,
+ mantissaWidth: FloatingPoint64Value.mantissaWidth);
+}
+
+/// Eight-bit floating point representation for deep learning
+class FloatingPoint8 extends FloatingPoint {
+ /// Calculate mantissa width and sanitize
+ static int _calculateMantissaWidth(int exponentWidth) {
+ final mantissaWidth = 7 - exponentWidth;
+ if (!FloatingPoint8Value.isLegal(exponentWidth, mantissaWidth)) {
+ throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2');
+ } else {
+ return mantissaWidth;
+ }
+ }
+
+ /// Construct an 8-bit floating point number
+ FloatingPoint8({required super.exponentWidth})
+ : super(mantissaWidth: _calculateMantissaWidth(exponentWidth));
+}
diff --git a/lib/src/arithmetic/floating_point/floating_point_value.dart b/lib/src/arithmetic/floating_point/floating_point_value.dart
new file mode 100644
index 000000000..59f8d74b9
--- /dev/null
+++ b/lib/src/arithmetic/floating_point/floating_point_value.dart
@@ -0,0 +1,910 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// floating_point_value.dart
+// Implementation of Floating-Point value representations.
+//
+// 2024 April 1
+// Authors:
+// Max Korbel
+// Desmond A Kirkpatrick
+// ...
+{
+ /// The full floating point value bit storage
+ final LogicValue value;
+
+ /// The sign of the value: 1 means a negative value
+ final LogicValue sign;
+
+ /// The exponent of the floating point: this is biased about a midpoint for
+ /// positive and negative exponents
+ final LogicValue exponent;
+
+ /// The mantissa of the floating point
+ final LogicValue mantissa;
+
+ /// Return the exponent value representing the true zero exponent 2^0 = 1,
+ /// often termed the bias or the offset of the exponent
+ static int computeBias(int exponentWidth) =>
+ pow(2, exponentWidth - 1).toInt() - 1;
+
+ /// Return the minimum exponent value
+ static int computeMinExponent(int exponentWidth) =>
+ -pow(2, exponentWidth - 1).toInt() + 2;
+
+ /// Return the maximum exponent value
+ static int computeMaxExponent(int exponentWidth) =>
+ computeBias(exponentWidth);
+
+ /// Return the bias of this [FloatingPointValue].
+ int get bias => _bias;
+
+ /// Return the maximum exponent of this [FloatingPointValue].
+ int get maxExponent => _maxExp;
+
+ /// Return the minimum exponent of this [FloatingPointValue].
+ int get minExponent => _minExp;
+
+ final int _bias;
+ final int _maxExp;
+ final int _minExp;
+
+ /// Factory (static) constructor of a [FloatingPointValue] from
+ /// sign, mantissa and exponent
+ factory FloatingPointValue(
+ {required LogicValue sign,
+ required LogicValue exponent,
+ required LogicValue mantissa}) {
+ if (exponent.width == FloatingPoint32Value.exponentWidth &&
+ mantissa.width == FloatingPoint32Value.mantissaWidth) {
+ return FloatingPoint32Value(
+ sign: sign, mantissa: mantissa, exponent: exponent);
+ } else if (exponent.width == FloatingPoint64Value._exponentWidth &&
+ mantissa.width == FloatingPoint64Value._mantissaWidth) {
+ return FloatingPoint64Value(
+ sign: sign, mantissa: mantissa, exponent: exponent);
+ } else {
+ return FloatingPointValue.withConstraints(
+ sign: sign, mantissa: mantissa, exponent: exponent);
+ }
+ }
+
+ /// [FloatingPointValue] constructor from a binary string representation of
+ /// individual bitfields
+ factory FloatingPointValue.ofBinaryStrings(
+ String sign, String exponent, String mantissa) {
+ if (sign.length != 1) {
+ throw RohdHclException('Sign string must be of length 1');
+ }
+
+ return FloatingPointValue(
+ sign: LogicValue.of(sign),
+ exponent: LogicValue.of(exponent),
+ mantissa: LogicValue.of(mantissa));
+ }
+
+ /// [FloatingPointValue] constructor from a single binary string representing
+ /// space-separated bitfields
+ factory FloatingPointValue.ofSeparatedBinaryStrings(String fp) {
+ final s = fp.split(' ');
+ if (s.length != 3) {
+ throw RohdHclException('FloatingPointValue requires three strings '
+ 'to initialize');
+ }
+ return FloatingPointValue.ofBinaryStrings(s[0], s[1], s[2]);
+ }
+
+ /// [FloatingPointValue] constructor from a radix-encoded string
+ /// representation and the size of the exponent and mantissa
+ factory FloatingPointValue.ofString(
+ String fp, int exponentWidth, int mantissaWidth,
+ {int radix = 2}) {
+ final binaryFp = LogicValue.ofBigInt(
+ BigInt.parse(fp, radix: radix), exponentWidth + mantissaWidth + 1)
+ .bitString;
+
+ final (sign, exponent, mantissa) = (
+ binaryFp.substring(0, 1),
+ binaryFp.substring(1, 1 + exponentWidth),
+ binaryFp.substring(1 + exponentWidth, 1 + exponentWidth + mantissaWidth)
+ );
+ return FloatingPointValue.ofBinaryStrings(sign, exponent, mantissa);
+ }
+
+ /// [FloatingPointValue] constructor from a set of [BigInt]s of the binary
+ /// representation and the size of the exponent and mantissa
+ factory FloatingPointValue.ofBigInts(BigInt exponent, BigInt mantissa,
+ {int exponentWidth = 0, int mantissaWidth = 0, bool sign = false}) {
+ final (signLv, exponentLv, mantissaLv) = (
+ LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1),
+ LogicValue.ofBigInt(exponent, exponentWidth),
+ LogicValue.ofBigInt(mantissa, mantissaWidth)
+ );
+
+ return FloatingPointValue(
+ sign: signLv, exponent: exponentLv, mantissa: mantissaLv);
+ }
+
+ /// [FloatingPointValue] constructor from a set of [int]s of the binary
+ /// representation and the size of the exponent and mantissa
+ factory FloatingPointValue.ofInts(int exponent, int mantissa,
+ {int exponentWidth = 0, int mantissaWidth = 0, bool sign = false}) {
+ final (signLv, exponentLv, mantissaLv) = (
+ LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1),
+ LogicValue.ofBigInt(BigInt.from(exponent), exponentWidth),
+ LogicValue.ofBigInt(BigInt.from(mantissa), mantissaWidth)
+ );
+
+ return FloatingPointValue(
+ sign: signLv, exponent: exponentLv, mantissa: mantissaLv);
+ }
+
+ /// Constructor enabling subclasses.
+ FloatingPointValue.withConstraints(
+ {required this.sign,
+ required this.exponent,
+ required this.mantissa,
+ int? mantissaWidth,
+ int? exponentWidth})
+ : value = [sign, exponent, mantissa].swizzle(),
+ _bias = computeBias(exponent.width),
+ _minExp = computeMinExponent(exponent.width),
+ _maxExp = computeMaxExponent(exponent.width) {
+ if (sign.width != 1) {
+ throw RohdHclException('FloatingPointValue: sign width must be 1');
+ }
+ if (mantissaWidth != null && mantissa.width != mantissaWidth) {
+ throw RohdHclException(
+ 'FloatingPointValue: mantissa width must be $mantissaWidth');
+ }
+ if (exponentWidth != null && exponent.width != exponentWidth) {
+ throw RohdHclException(
+ 'FloatingPointValue: exponent width must be $exponentWidth');
+ }
+ }
+
+ /// Construct a [FloatingPointValue] from a Logic word
+ factory FloatingPointValue.fromLogic(
+ int exponentWidth, int mantissaWidth, LogicValue val) {
+ final sign = (val[-1] == LogicValue.one);
+ final exponent =
+ val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth).toBigInt();
+ final mantissa = val.slice(mantissaWidth - 1, 0).toBigInt();
+ final (signLv, exponentLv, mantissaLv) = (
+ LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1),
+ LogicValue.ofBigInt(exponent, exponentWidth),
+ LogicValue.ofBigInt(mantissa, mantissaWidth)
+ );
+ return FloatingPointValue(
+ sign: signLv, exponent: exponentLv, mantissa: mantissaLv);
+ }
+
+ /// Return the [FloatingPointValue] representing the constant specified
+ factory FloatingPointValue.getFloatingPointConstant(
+ FloatingPointConstants constantFloatingPoint,
+ int exponentWidth,
+ int mantissaWidth) {
+ switch (constantFloatingPoint) {
+ /// smallest possible number
+ case FloatingPointConstants.negativeInfinity:
+ return FloatingPointValue.ofBinaryStrings(
+ '1', '1' * exponentWidth, '0' * mantissaWidth);
+
+ /// -0.0
+ case FloatingPointConstants.negativeZero:
+ return FloatingPointValue.ofBinaryStrings(
+ '1', '0' * exponentWidth, '0' * mantissaWidth);
+
+ /// 0.0
+ case FloatingPointConstants.positiveZero:
+ return FloatingPointValue.ofBinaryStrings(
+ '0', '0' * exponentWidth, '0' * mantissaWidth);
+
+ /// Smallest positive subnormal, most negative exponent, LSB set in mantissa
+ case FloatingPointConstants.smallestPositiveSubnormal:
+ return FloatingPointValue.ofBinaryStrings(
+ '0', '0' * exponentWidth, '${'0' * (mantissaWidth - 1)}1');
+
+ /// Largest possible subnormal, most negative exponent, mantissa all 1s
+ case FloatingPointConstants.largestPositiveSubnormal:
+ return FloatingPointValue.ofBinaryStrings(
+ '0', '0' * exponentWidth, '1' * mantissaWidth);
+
+ /// Smallest positive normal number, most negative exponent, mantissa 0
+ case FloatingPointConstants.smallestPositiveNormal:
+ return FloatingPointValue.ofBinaryStrings(
+ '0', '${'0' * (exponentWidth - 1)}1', '0' * mantissaWidth);
+
+ /// Largest number smaller than one
+ case FloatingPointConstants.largestLessThanOne:
+ return FloatingPointValue.ofBinaryStrings(
+ '0', '0${'1' * (exponentWidth - 2)}0', '1' * mantissaWidth);
+
+ /// The number '1.0'
+ case FloatingPointConstants.one:
+ return FloatingPointValue.ofBinaryStrings(
+ '0', '0${'1' * (exponentWidth - 1)}', '0' * mantissaWidth);
+
+ /// Smallest number greater than one: same exponent as 1.0, LSB set in
+ /// mantissa
+ case FloatingPointConstants.smallestLargerThanOne:
+ return FloatingPointValue.ofBinaryStrings('0',
+ '0${'1' * (exponentWidth - 1)}', '${'0' * (mantissaWidth - 1)}1');
+
+ /// Largest positive number, most positive exponent (all 1s except the
+ /// LSB, since all 1s encodes infinity), full mantissa
+ case FloatingPointConstants.largestNormal:
+ return FloatingPointValue.ofBinaryStrings(
+ '0', '${'1' * (exponentWidth - 1)}0', '1' * mantissaWidth);
+
+ /// Largest possible number
+ case FloatingPointConstants.infinity:
+ return FloatingPointValue.ofBinaryStrings(
+ '0', '1' * exponentWidth, '0' * mantissaWidth);
+ }
+ }
+
+ /// Convert from double using its native binary representation
+ factory FloatingPointValue.fromDouble(double inDouble,
+ {required int exponentWidth,
+ required int mantissaWidth,
+ FloatingPointRoundingMode roundingMode =
+ FloatingPointRoundingMode.roundNearestEven}) {
+ if ((exponentWidth == 8) && (mantissaWidth == 23)) {
+ return FloatingPoint32Value.fromDouble(inDouble);
+ } else if ((exponentWidth == 11) && (mantissaWidth == 52)) {
+ return FloatingPoint64Value.fromDouble(inDouble);
+ }
+
+ final fp64 = FloatingPoint64Value.fromDouble(inDouble);
+ final exponent64 = fp64.exponent;
+
+ var expVal = (exponent64.toInt() - fp64.bias) +
+ FloatingPointValue.computeBias(exponentWidth);
+ // Handle subnormal
+ final mantissa64 = [
+ if (expVal <= 0)
+ ([LogicValue.one, fp64.mantissa].swizzle() >>> -expVal).slice(52, 1)
+ else
+ fp64.mantissa
+ ].first;
+ var mantissa = mantissa64.slice(51, 51 - mantissaWidth + 1);
+
+ if (roundingMode == FloatingPointRoundingMode.roundNearestEven) {
+ final sticky = mantissa64.slice(51 - (mantissaWidth + 2), 0).or();
+ final roundPos = 51 - (mantissaWidth + 2) + 1;
+ final round = mantissa64[roundPos];
+ final guard = mantissa64[roundPos + 1];
+
+ // RNE Rounding
+ if (guard == LogicValue.one) {
+ if ((round == LogicValue.one) |
+ (sticky == LogicValue.one) |
+ (mantissa[0] == LogicValue.one)) {
+ mantissa += 1;
+ if (mantissa == LogicValue.zero.zeroExtend(mantissa.width)) {
+ expVal += 1;
+ }
+ }
+ }
+ }
+
+ final exponent =
+ LogicValue.ofBigInt(BigInt.from(max(expVal, 0)), exponentWidth);
+
+ return FloatingPointValue(
+ sign: fp64.sign, exponent: exponent, mantissa: mantissa);
+ }
+
+ /// Generate a random [FloatingPointValue], using the random number
+ /// generator [rv]. This generates a valid [FloatingPointValue] anywhere
+ /// in the range it can represent: a general [FloatingPointValue] has
+ /// a mantissa in [0,2) with 0 <= exponent <= maxExponent().
+ /// If [normal] is true, this routine will only generate mantissas in the
+ /// range of [1,2) and minExponent() <= exponent <= maxExponent().
+ factory FloatingPointValue.random(Random rv,
+ {required int exponentWidth,
+ required int mantissaWidth,
+ bool normal = false}) {
+ final largestExponent = FloatingPointValue.computeBias(exponentWidth) +
+ FloatingPointValue.computeMaxExponent(exponentWidth);
+ final s = rv.nextLogicValue(width: 1).toInt();
+ var e = BigInt.one;
+ do {
+ e = rv
+ .nextLogicValue(width: exponentWidth, max: largestExponent)
+ .toBigInt();
+ } while ((e == BigInt.zero) & normal);
+ final m = rv.nextLogicValue(width: mantissaWidth).toBigInt();
+ return FloatingPointValue(
+ sign: LogicValue.ofInt(s, 1),
+ exponent: LogicValue.ofBigInt(e, exponentWidth),
+ mantissa: LogicValue.ofBigInt(m, mantissaWidth));
+ }
+
+ /// Convert a floating point number into a [FloatingPointValue]
+ /// representation. This form performs NO ROUNDING.
+ factory FloatingPointValue.fromDoubleIter(double inDouble,
+ {required int exponentWidth, required int mantissaWidth}) {
+ if ((exponentWidth == 8) && (mantissaWidth == 23)) {
+ return FloatingPoint32Value.fromDouble(inDouble);
+ } else if ((exponentWidth == 11) && (mantissaWidth == 52)) {
+ return FloatingPoint64Value.fromDouble(inDouble);
+ }
+
+ var doubleVal = inDouble;
+ if (inDouble.isNaN) {
+ return FloatingPointValue(
+ exponent:
+ LogicValue.ofInt(pow(2, exponentWidth).toInt() - 1, exponentWidth),
+ mantissa: LogicValue.zero.zeroExtend(mantissaWidth),
+ sign: LogicValue.zero,
+ );
+ }
+ LogicValue sign;
+ if (inDouble < 0.0) {
+ doubleVal = -doubleVal;
+ sign = LogicValue.one;
+ } else {
+ sign = LogicValue.zero;
+ }
+
+ // If we are dealing with a really small number we need to scale it up
+ var scaleToWhole = (doubleVal != 0) ? (-log(doubleVal) / log(2)).ceil() : 0;
+
+ if (doubleVal < 1.0) {
+ var myCnt = 0;
+ var myVal = doubleVal;
+ while (myVal % 1 != 0.0) {
+ myVal = myVal * 2.0;
+ myCnt++;
+ }
+ if (myCnt < scaleToWhole) {
+ scaleToWhole = myCnt;
+ }
+ }
+
+ // Scale it up to go beyond the mantissa and include the GRS bits
+ final scale = mantissaWidth + scaleToWhole;
+ var s = scale;
+
+ var sVal = doubleVal;
+ if (s > 0) {
+ while (s > 0) {
+ sVal *= 2.0;
+ s = s - 1;
+ }
+ } else {
+ sVal = doubleVal * pow(2.0, scale);
+ }
+
+ final scaledValue = BigInt.from(sVal);
+ final fullLength = scaledValue.bitLength;
+
+ var fullValue = LogicValue.ofBigInt(scaledValue, fullLength);
+ var e = (fullLength > 0)
+ ? fullLength - mantissaWidth - scaleToWhole
+ : FloatingPointValue.computeMinExponent(exponentWidth);
+
+ if (e <= -FloatingPointValue.computeBias(exponentWidth)) {
+ fullValue = fullValue >>>
+ (scaleToWhole - FloatingPointValue.computeBias(exponentWidth));
+ e = -FloatingPointValue.computeBias(exponentWidth);
+ } else {
+ // Could be just one away from subnormal
+ e -= 1;
+ if (e > -FloatingPointValue.computeBias(exponentWidth)) {
+ fullValue = fullValue << 1; // Chop the first '1'
+ }
+ }
+ // We reverse so that we fit into a shorter BigInt, we keep the MSB.
+ // The conversion fills leftward.
+ // We reverse again after conversion.
+ final exponent = LogicValue.ofInt(
+ e + FloatingPointValue.computeBias(exponentWidth), exponentWidth);
+ final mantissa =
+ LogicValue.ofBigInt(fullValue.reversed.toBigInt(), mantissaWidth)
+ .reversed;
+
+ return FloatingPointValue(
+ exponent: exponent,
+ mantissa: mantissa,
+ sign: sign,
+ );
+ }
+
+ @override
+ int get hashCode => sign.hashCode ^ exponent.hashCode ^ mantissa.hashCode;
+
+ /// Floating point comparison to implement Comparable<>
+ @override
+ int compareTo(Object other) {
+ if (other is! FloatingPointValue) {
+ throw Exception('Input must be of type FloatingPointValue');
+ }
+ if ((exponent.width != other.exponent.width) |
+ (mantissa.width != other.mantissa.width)) {
+ throw Exception('FloatingPointValue widths must match for comparison');
+ }
+ final signCompare = sign.compareTo(other.sign);
+ if (signCompare != 0) {
+ return signCompare;
+ } else {
+ final expCompare = exponent.compareTo(other.exponent);
+ if (expCompare != 0) {
+ return expCompare;
+ } else {
+ return mantissa.compareTo(other.mantissa);
+ }
+ }
+ }
+
+ // Return the bias of this FP format
+ // int bias() => FloatingPointValue.computeBias(exponent.width);
+
+ @override
+ bool operator ==(Object other) {
+ if (other is! FloatingPointValue) {
+ return false;
+ }
+
+ if ((exponent.width != other.exponent.width) |
+ (mantissa.width != other.mantissa.width)) {
+ return false;
+ }
+
+ return (sign == other.sign) &
+ (exponent == other.exponent) &
+ (mantissa == other.mantissa);
+ }
+
+ /// Return true if the represented floating point number is considered
+ /// NaN or 'Not a Number' due to overflow
+ // TODO(desmonddak): figure out the difference with Infinity
+ bool isNaN() {
+ if ((exponent.width == 4) & (mantissa.width == 3)) {
+ // FP8 E4M3 does not support infinities
+ final cond1 = (1 + exponent.toInt()) == pow(2, exponent.width).toInt();
+ final cond2 = (1 + mantissa.toInt()) == pow(2, mantissa.width).toInt();
+ return cond1 & cond2;
+ } else {
+ return exponent.toInt() ==
+ computeMaxExponent(exponent.width) + computeBias(exponent.width) + 1;
+ }
+ }
+
+ /// Return the value of the floating point number in a Dart [double] type.
+ double toDouble() {
+ var doubleVal = double.nan;
+ if (value.isValid) {
+ if (exponent.toInt() == 0) {
+ doubleVal = (sign.toBool() ? -1.0 : 1.0) *
+ pow(2.0, computeMinExponent(exponent.width)) *
+ mantissa.toBigInt().toDouble() /
+ pow(2.0, mantissa.width);
+ } else if (!isNaN()) {
+ doubleVal = (sign.toBool() ? -1.0 : 1.0) *
+ (1.0 + mantissa.toBigInt().toDouble() / pow(2.0, mantissa.width)) *
+ pow(2.0, exponent.toInt() - computeBias(exponent.width));
+ }
+ }
+ return doubleVal;
+ }
+
+ /// Return true if this [FloatingPointValue] contains a normal number,
+ /// defined as having mantissa in the range [1,2)
+ bool isNormal() => exponent != LogicValue.ofInt(0, exponent.width);
+
+ @override
+ String toString() => '${sign.toString(includeWidth: false)}'
+ ' ${exponent.toString(includeWidth: false)}'
+ ' ${mantissa.toString(includeWidth: false)}';
+
+ // TODO(desmonddak): what about floating point representations >> 64 bits?
+ FloatingPointValue _performOp(
+ FloatingPointValue other, double Function(double a, double b) op) {
+ // make sure the other operand has the same widths as this
+ if (mantissa.width != other.mantissa.width ||
+ exponent.width != other.exponent.width) {
+ throw RohdHclException('FloatingPointValue: '
+ 'operands must have the same mantissa and exponent widths');
+ }
+
+ return FloatingPointValue.fromDouble(op(toDouble(), other.toDouble()),
+ mantissaWidth: mantissa.width, exponentWidth: exponent.width);
+ }
+
+ /// Multiply operation for [FloatingPointValue]
+ FloatingPointValue operator *(FloatingPointValue multiplicand) =>
+ _performOp(multiplicand, (a, b) => a * b);
+
+ /// Addition operation for [FloatingPointValue]
+ FloatingPointValue operator +(FloatingPointValue addend) =>
+ _performOp(addend, (a, b) => a + b);
+
+ /// Divide operation for [FloatingPointValue]
+ FloatingPointValue operator /(FloatingPointValue divisor) =>
+ _performOp(divisor, (a, b) => a / b);
+
+ /// Subtract operation for [FloatingPointValue]
+ FloatingPointValue operator -(FloatingPointValue subtrahend) =>
+ _performOp(subtrahend, (a, b) => a - b);
+
+ /// Negate operation for [FloatingPointValue]
+ FloatingPointValue negate() => FloatingPointValue(
+ sign: sign.isZero ? LogicValue.one : LogicValue.zero,
+ exponent: exponent,
+ mantissa: mantissa);
+
+ /// Absolute value operation for [FloatingPointValue]
+ FloatingPointValue abs() => FloatingPointValue(
+ sign: LogicValue.zero, exponent: exponent, mantissa: mantissa);
+}
+
+/// A representation of a single precision floating point value
+class FloatingPoint32Value extends FloatingPointValue {
+ /// The exponent width
+ static const int exponentWidth = 8;
+
+ /// The mantissa width
+ static const int mantissaWidth = 23;
+
+ /// Constructor for a single precision floating point value
+ FloatingPoint32Value(
+ {required super.sign, required super.exponent, required super.mantissa})
+ : super.withConstraints(
+ mantissaWidth: mantissaWidth, exponentWidth: exponentWidth);
+
+ /// Return the [FloatingPoint32Value] representing the constant specified
+ factory FloatingPoint32Value.getFloatingPointConstant(
+ FloatingPointConstants constantFloatingPoint) =>
+ FloatingPointValue.getFloatingPointConstant(
+ constantFloatingPoint, exponentWidth, mantissaWidth)
+ as FloatingPoint32Value;
+
+ /// [FloatingPoint32Value] constructor from string representation of
+ /// individual bitfields
+ factory FloatingPoint32Value.ofStrings(
+ String sign, String exponent, String mantissa) =>
+ FloatingPoint32Value(
+ sign: LogicValue.of(sign),
+ exponent: LogicValue.of(exponent),
+ mantissa: LogicValue.of(mantissa));
+
+ /// [FloatingPoint32Value] constructor from a single string representing
+ /// space-separated bitfields
+ factory FloatingPoint32Value.ofString(String fp) {
+ final s = fp.split(' ');
+ assert(s.length == 3, 'Wrong FloatingPointValue string length ${s.length}');
+ return FloatingPoint32Value.ofStrings(s[0], s[1], s[2]);
+ }
+
+ /// [FloatingPoint32Value] constructor from a set of [BigInt]s of the binary
+ /// representation
+ factory FloatingPoint32Value.ofBigInts(BigInt exponent, BigInt mantissa,
+ {bool sign = false}) {
+ final (signLv, exponentLv, mantissaLv) = (
+ LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1),
+ LogicValue.ofBigInt(exponent, exponentWidth),
+ LogicValue.ofBigInt(mantissa, mantissaWidth)
+ );
+
+ return FloatingPoint32Value(
+ sign: signLv, exponent: exponentLv, mantissa: mantissaLv);
+ }
+
+ /// [FloatingPoint32Value] constructor from a set of [int]s of the binary
+ /// representation
+ factory FloatingPoint32Value.ofInts(int exponent, int mantissa,
+ {bool sign = false}) {
+ final (signLv, exponentLv, mantissaLv) = (
+ LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1),
+ LogicValue.ofBigInt(BigInt.from(exponent), exponentWidth),
+ LogicValue.ofBigInt(BigInt.from(mantissa), mantissaWidth)
+ );
+
+ return FloatingPoint32Value(
+ sign: signLv, exponent: exponentLv, mantissa: mantissaLv);
+ }
+
+ /// Numeric conversion of a [FloatingPoint32Value] from a host double
+ factory FloatingPoint32Value.fromDouble(double inDouble) {
+ final byteData = ByteData(4)..setFloat32(0, inDouble);
+ final bytes = byteData.buffer.asUint8List();
+ final lv = bytes.map((b) => LogicValue.ofInt(b, 32));
+
+ final accum = lv.reduce((accum, v) => (accum << 8) | v);
+
+ final sign = accum[-1];
+ final exponent =
+ accum.slice(exponentWidth + mantissaWidth - 1, mantissaWidth);
+ final mantissa = accum.slice(mantissaWidth - 1, 0);
+
+ return FloatingPoint32Value(
+ sign: sign, exponent: exponent, mantissa: mantissa);
+ }
+
+ /// Construct a [FloatingPoint32Value] from a Logic word
+ factory FloatingPoint32Value.fromLogic(LogicValue val) {
+ final sign = (val[-1] == LogicValue.one);
+ final exponent =
+ val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth);
+ final mantissa = val.slice(mantissaWidth - 1, 0);
+ final (signLv, exponentLv, mantissaLv) = (
+ LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1),
+ exponent,
+ mantissa
+ );
+ return FloatingPoint32Value(
+ sign: signLv, exponent: exponentLv, mantissa: mantissaLv);
+ }
+}
+
+/// A representation of a double precision floating point value
+class FloatingPoint64Value extends FloatingPointValue {
+ static const int _exponentWidth = 11;
+ static const int _mantissaWidth = 52;
+
+ /// return the exponent width
+ static int get exponentWidth => _exponentWidth;
+
+ /// return the mantissa width
+ static int get mantissaWidth => _mantissaWidth;
+
+ /// Constructor for a double precision floating point value
+ FloatingPoint64Value(
+ {required super.sign, required super.mantissa, required super.exponent})
+ : super.withConstraints(
+ exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+ /// Return the [FloatingPoint64Value] representing the constant specified
+ factory FloatingPoint64Value.getFloatingPointConstant(
+ FloatingPointConstants constantFloatingPoint) =>
+ FloatingPointValue.getFloatingPointConstant(
+ constantFloatingPoint, _exponentWidth, _mantissaWidth)
+ as FloatingPoint64Value;
+
+ /// [FloatingPoint64Value] constructor from string representation of
+ /// individual bitfields
+ factory FloatingPoint64Value.ofStrings(
+ String sign, String exponent, String mantissa) =>
+ FloatingPoint64Value(
+ sign: LogicValue.of(sign),
+ exponent: LogicValue.of(exponent),
+ mantissa: LogicValue.of(mantissa));
+
+ /// [FloatingPoint64Value] constructor from a single string representing
+ /// space-separated bitfields
+ factory FloatingPoint64Value.ofString(String fp) {
+ final s = fp.split(' ');
+ assert(s.length == 3, 'Wrong FloatingPointValue string length ${s.length}');
+ return FloatingPoint64Value.ofStrings(s[0], s[1], s[2]);
+ }
+
+ /// [FloatingPoint64Value] constructor from a set of [BigInt]s of the binary
+ /// representation
+ factory FloatingPoint64Value.ofBigInts(BigInt exponent, BigInt mantissa,
+ {bool sign = false}) =>
+ FloatingPointValue.ofBigInts(exponent, mantissa,
+ sign: sign,
+ exponentWidth: exponentWidth,
+ mantissaWidth: mantissaWidth) as FloatingPoint64Value;
+
+ /// [FloatingPoint64Value] constructor from a set of [int]s of the binary
+ /// representation
+ factory FloatingPoint64Value.ofInts(int exponent, int mantissa,
+ {bool sign = false}) =>
+ FloatingPointValue.ofInts(exponent, mantissa,
+ sign: sign,
+ exponentWidth: exponentWidth,
+ mantissaWidth: mantissaWidth) as FloatingPoint64Value;
+
+ /// Numeric conversion of a [FloatingPoint64Value] from a host double
+ factory FloatingPoint64Value.fromDouble(double inDouble) {
+ final byteData = ByteData(8)..setFloat64(0, inDouble);
+ final bytes = byteData.buffer.asUint8List();
+ final lv = bytes.map((b) => LogicValue.ofInt(b, 64));
+
+ final accum = lv.reduce((accum, v) => (accum << 8) | v);
+
+ final sign = accum[-1];
+ final exponent =
+ accum.slice(_exponentWidth + _mantissaWidth - 1, _mantissaWidth);
+ final mantissa = accum.slice(_mantissaWidth - 1, 0);
+
+ return FloatingPoint64Value(
+ sign: sign, mantissa: mantissa, exponent: exponent);
+ }
+
+ /// Construct a [FloatingPoint64Value] from a Logic word
+ factory FloatingPoint64Value.fromLogic(LogicValue val) {
+ final sign = (val[-1] == LogicValue.one);
+ final exponent =
+ val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth).toBigInt();
+ final mantissa = val.slice(mantissaWidth - 1, 0).toBigInt();
+ final (signLv, exponentLv, mantissaLv) = (
+ LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1),
+ LogicValue.ofBigInt(exponent, exponentWidth),
+ LogicValue.ofBigInt(mantissa, mantissaWidth)
+ );
+ return FloatingPoint64Value(
+ sign: signLv, exponent: exponentLv, mantissa: mantissaLv);
+ }
+}
+
+/// A representation of an 8-bit floating point value as defined in
+/// [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433).
+class FloatingPoint8Value extends FloatingPointValue {
+ /// The exponent width
+ late final int exponentWidth;
+
+ /// The mantissa width
+ late final int mantissaWidth;
+
+ static double get _e4m3max => 448.toDouble();
+ static double get _e5m2max => 57344.toDouble();
+ static double get _e4m3min => pow(2, -9).toDouble();
+ static double get _e5m2min => pow(2, -16).toDouble();
+
+ /// Return if the exponent and mantissa widths match E4M3 or E5M2
+ static bool isLegal(int exponentWidth, int mantissaWidth) {
+ if (((exponentWidth == 4) & (mantissaWidth == 3)) |
+ ((exponentWidth == 5) & (mantissaWidth == 2))) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ /// Constructor for an 8-bit floating point value (E4M3 or E5M2)
+ FloatingPoint8Value(
+ {required super.sign, required super.mantissa, required super.exponent})
+ : super.withConstraints() {
+ exponentWidth = exponent.width;
+ mantissaWidth = mantissa.width;
+ if (!isLegal(exponentWidth, mantissaWidth)) {
+ throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2');
+ }
+ }
+
+ /// [FloatingPoint8Value] constructor from string representation of
+ /// individual bitfields
+ factory FloatingPoint8Value.ofStrings(
+ String sign, String exponent, String mantissa) =>
+ FloatingPoint8Value(
+ sign: LogicValue.of(sign),
+ exponent: LogicValue.of(exponent),
+ mantissa: LogicValue.of(mantissa));
+
+ /// [FloatingPoint8Value] constructor from a single string representing
+ /// space-separated bitfields
+ factory FloatingPoint8Value.ofString(String fp) {
+ final s = fp.split(' ');
+ assert(s.length == 3, 'Wrong FloatingPointValue string length ${s.length}');
+ return FloatingPoint8Value.ofStrings(s[0], s[1], s[2]);
+ }
+
+ /// Construct a [FloatingPoint8Value] from a Logic word
+ factory FloatingPoint8Value.fromLogic(LogicValue val, int exponentWidth) {
+ if (val.width != 8) {
+ throw RohdHclException('Width must be 8');
+ }
+
+ final mantissaWidth = 7 - exponentWidth;
+ if (!isLegal(exponentWidth, mantissaWidth)) {
+ throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2');
+ }
+
+ final sign = (val[-1] == LogicValue.one);
+ final exponent =
+ val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth).toBigInt();
+ final mantissa = val.slice(mantissaWidth - 1, 0).toBigInt();
+ final (signLv, exponentLv, mantissaLv) = (
+ LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1),
+ LogicValue.ofBigInt(exponent, exponentWidth),
+ LogicValue.ofBigInt(mantissa, mantissaWidth)
+ );
+ return FloatingPoint8Value(
+ sign: signLv, exponent: exponentLv, mantissa: mantissaLv);
+ }
+
+ /// Numeric conversion of a [FloatingPoint8Value] from a host double
+ factory FloatingPoint8Value.fromDouble(double inDouble,
+ {required int exponentWidth}) {
+ final mantissaWidth = 7 - exponentWidth;
+ if (!isLegal(exponentWidth, mantissaWidth)) {
+ throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2');
+ }
+ if (exponentWidth == 4) {
+ if ((inDouble.abs() > _e4m3max) | (inDouble.abs() < _e4m3min)) {
+ throw RohdHclException('Number exceeds E4M3 range');
+ }
+ } else if (exponentWidth == 5) {
+ if ((inDouble.abs() > _e5m2max) | (inDouble.abs() < _e5m2min)) {
+ throw RohdHclException('Number exceeds E5M2 range');
+ }
+ }
+ final fpv = FloatingPointValue.fromDouble(inDouble,
+ exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+ return FloatingPoint8Value(
+ sign: fpv.sign, exponent: fpv.exponent, mantissa: fpv.mantissa);
+ }
+}
diff --git a/lib/src/arithmetic/multiplier_encoder.dart b/lib/src/arithmetic/multiplier_encoder.dart
index cfe7727ef..2d5b0753a 100644
--- a/lib/src/arithmetic/multiplier_encoder.dart
+++ b/lib/src/arithmetic/multiplier_encoder.dart
@@ -131,64 +131,3 @@ class MultiplierEncoder {
return _encoder.encode(multiplierSlice.first);
}
}
-
-/// A class accessing the multiples of the multiplicand at a position
-class MultiplicandSelector {
- /// radix of the selector
- int radix;
-
- /// The bit shift of the selector (typically overlaps 1)
- int shift;
-
- /// New width of partial products generated from the multiplicand
- int get width => multiplicand.width + shift - 1;
-
- /// Access the multiplicand
- Logic multiplicand = Logic();
-
- /// Place to store multiples of the multiplicand
- late LogicArray multiples;
-
- /// Generate required multiples of multiplicand
- MultiplicandSelector(this.radix, this.multiplicand, {required bool signed})
- : shift = log2Ceil(radix) {
- if (radix > 16) {
- throw RohdHclException('Radices beyond 16 are not yet supported');
- }
- final width = multiplicand.width + shift;
- final numMultiples = radix ~/ 2;
- multiples = LogicArray([numMultiples], width);
- final extendedMultiplicand = signed
- ? multiplicand.signExtend(width)
- : multiplicand.zeroExtend(width);
-
- for (var pos = 0; pos < numMultiples; pos++) {
- final ratio = pos + 1;
- multiples.elements[pos] <=
- switch (ratio) {
- 1 => extendedMultiplicand,
- 2 => extendedMultiplicand << 1,
- 3 => (extendedMultiplicand << 2) - extendedMultiplicand,
- 4 => extendedMultiplicand << 2,
- 5 => (extendedMultiplicand << 2) + extendedMultiplicand,
- 6 => (extendedMultiplicand << 3) - (extendedMultiplicand << 1),
- 7 => (extendedMultiplicand << 3) - extendedMultiplicand,
- 8 => extendedMultiplicand << 3,
- _ => throw RohdHclException('Radix is beyond 16')
- };
- }
- }
-
- /// Retrieve the multiples of the multiplicand at current bit position
- Logic getMultiples(int col) => [
- for (var i = 0; i < multiples.elements.length; i++)
- multiples.elements[i][col]
- ].swizzle().reversed;
-
- Logic _select(Logic multiples, RadixEncode encode) =>
- (encode.multiples & multiples).or() ^ encode.sign;
-
- /// Select the partial product term from the multiples using a RadixEncode
- Logic select(int col, RadixEncode encode) =>
- _select(getMultiples(col), encode);
-}
diff --git a/lib/src/arithmetic/multiplier_lib.dart b/lib/src/arithmetic/multiplier_lib.dart
index d37063194..eb6c8bcd2 100644
--- a/lib/src/arithmetic/multiplier_lib.dart
+++ b/lib/src/arithmetic/multiplier_lib.dart
@@ -11,5 +11,6 @@
//
export './addend_compressor.dart';
+export './multiplicand_selector.dart';
export './multiplier_encoder.dart';
export './partial_product_generator.dart';
diff --git a/lib/src/arithmetic/partial_product_generator.dart b/lib/src/arithmetic/partial_product_generator.dart
index 2f4efeedc..bbe305b61 100644
--- a/lib/src/arithmetic/partial_product_generator.dart
+++ b/lib/src/arithmetic/partial_product_generator.dart
@@ -112,7 +112,7 @@ class PartialProductGeneratorNoSignExtension extends PartialProductGenerator {
void signExtend() {}
}
-/// A Partial Product Generator using Brute Sign Extension
+/// A Partial Product Generator using Compact Rectangular Extension
class PartialProductGeneratorCompactRectSignExtension
extends PartialProductGenerator {
/// Construct a compact rect sign extending Partial Product Generator
diff --git a/test/arithmetic/floating_point/floating_point_adder_test.dart b/test/arithmetic/floating_point/floating_point_adder_test.dart
new file mode 100644
index 000000000..3a9f69076
--- /dev/null
+++ b/test/arithmetic/floating_point/floating_point_adder_test.dart
@@ -0,0 +1,319 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// floating_point_adder_test.dart
+// Tests of the floating-point adder
+//
+// 2024 April 1
+// Authors:
+// Max Korbel
+// Desmond A Kirkpatrick