diff --git a/doc/README.md b/doc/README.md index 2413aa04c..06534248e 100644 --- a/doc/README.md +++ b/doc/README.md @@ -31,9 +31,14 @@ Some in-development items will have opened issues, as well. Feel free to create - Sort - [Bitonic sort](./components/sort.md#bitonic-sort) - Arithmetic - - [Prefix Trees](./components/parallel_prefix_operations.md) + - [Prefix Trees](./components/parallel_prefix_operations.md) Several efficient components that leverage a variety of parallel prefix trees such as Ripple, Kogge-Stone, Sklansky, and Brent-Kung tree types. + - [Priority Encoder](./components/parallel_prefix_operations.md) + - [Or-scan](./components/parallel_prefix_operations.md) + - [Incrementer](./components/parallel_prefix_operations.md) + - [Decrementer](./components/parallel_prefix_operations.md) - [Adders](./components/adder.md) - [Sign Magnitude Adder](./components/adder.md#ripple-carry-adder) + - [Parallel Prefix Adder](./components/parallel_prefix_operations.md) - Subtractors - [One's Complement Adder Subtractor](./components/adder.md#ones-complement-adder-subtractor) - Multipliers @@ -47,11 +52,13 @@ Some in-development items will have opened issues, as well. Feel free to create - Square root - Inverse square root - Floating point + - [Floating-Point Value Types](./components/floating_point.md) - Double (64-bit) - Float (32-bit) - BFloat16 (16-bit) - BFloat8 (8-bit) - BFloat4 (4-bit) + - [Simple Floating-Point Adder](./components/floating_point.md#floatingpointadder) - Fixed point - Binary-Coded Decimal (BCD) - [Rotate](./components/rotate.md) diff --git a/doc/components/floating_point.md b/doc/components/floating_point.md new file mode 100644 index 000000000..8f525226c --- /dev/null +++ b/doc/components/floating_point.md @@ -0,0 +1,37 @@ +# Floating-Point Components + +Floating-point operations require meticulous precision, and have standards like [IEEE-754]() which govern them. To support floating-point components, we have created a parallel to `Logic`/`LogicValue` which are part of [ROHD](). Here, `FloatingPoint` is the `Logic` wire in a component that carries `FloatingPointValue` literal values. An important distinction is that these classes are parameterized to create arbitrary size floating-point values. + +## FloatingPointValue + +The `FloatingPointValue` class comprises the sign, exponent, and mantissa `LogicValue`s that represent a floating-point number. `FloatingPointValue`s can be converted to and from Dart native `Double`s, as well as constructed from integer and string representations of their fields. They can be operated on (+, -, *, /) and compared. + +A `FloatingPointValue` has a mantissa in $[0,2)$ with + +$$0 <= exponent <= maxExponent$$ + +A normal `isNormal` `FloatingPointValue` has: + +$$minExponent <= exponent <= maxExponent$$ + + And a mantissa in the range of $[1,2)$. Subnormal numbers are represented with a zero exponent and leading zeros in the mantissa capture the negative exponent value. + +The various IEEE constants representing corner cases of the field of floating-point values for a given size of `FloatingPointValue`: infinities, zeros, limits for normal (e.g. mantissa in the range of $[1,2])$ and sub-normal numbers (zero exponent, and mantissa <1). + +Appropriate string representations, comparison operations, and operators are available. The usefulness of `FloatingPointValue` is in the testing of `FloatingPoint` components, where we can leverage the abstraction of a floating-point value type to drive and compare floating-point values operated upon by floating-point components. + +As 32-bit single precision and 64-bit double-precision floating-point types are most common, we have `FloatingPoint32Value` and `FloatingPoint64Value` subclasses with direct converters from Dart native Double. + +Finally, we have a `FloatingPointValue` random generator for testing purposes, generating valid floating-point types, optionally constrained to normal range (mantissa in $[1, 2)$). + +## FloatingPoint + +The `FloatingPoint` type is a `LogicStructure` which comprises the `Logic` bits for the sign, exponent, and mantissa used in hardware floating-point. These types are provided to simplify and abstract the declaration and manipulation of floating-point types in hardware. This type is parameterized like `FloatingPointValue`, for exponent and mantissa width. + +Again, like `FloatingPointValue`, `FloatingPoint64` and `FloatingPoint32` subclasses are provided as these are the most common floating-point number types. + +## FloatingPointAdder + +A very basic `FloatingPointAdder` component is available which does not perform any rounding. It takes two `FloatingPoint` `LogicStructure`s and adds them, returning a normalized `FloatingPoint` on the output. An option on input is the type of `ParallelPrefixTree` used in the internal addition of the mantissas. + +Currently, the `FloatingPointAdder` is close in accuracy (as it has no rounding) and is not optimized for circuit performance, but only provides the key functionalities of alignment, addition, and normalization. Still, this component is a starting point for more realistic floating-point components that leverage the logical `FloatingPoint` and literal `FloatingPointValue` type abstractions. diff --git a/doc/components/multiplier.md b/doc/components/multiplier.md index e98ab5dac..14d2b8fea 100644 --- a/doc/components/multiplier.md +++ b/doc/components/multiplier.md @@ -109,7 +109,7 @@ Here is an example of use of the `CompressionTreeMultiplier`: ## Compression Tree Multiply Accumulate -A compression tree multiply accumulate is similar to a compress tree +A compression tree multiply-accumulate is similar to a compress tree multiplier, but it inserts an additional addend into the compression tree to allow for accumulation into this third input. diff --git a/doc/components/multiplier_components.md b/doc/components/multiplier_components.md index 2d084b6f8..af1b3e08e 100644 --- a/doc/components/multiplier_components.md +++ b/doc/components/multiplier_components.md @@ -103,7 +103,7 @@ The partial product generator produces a set of addends in shifted position to b An argument to the `PartialProductGenerator` is the `RadixEncoder` to be used. The [`RadixEncoder`] takes a single argument which is the radix (power of 2) to be used. -Instead of using the 1's in the multiplier to select shifted versions of the multiplicand to add in a partial product matrix, radix-encoding will encode multiples of the multiplicand by examining adjacent bits of the multiplier. For radix-4, for example, for a multiplier of size M, instead of M rows of partial products, M/2 rows are formed by selecting from multiples [-2, -1, 0, 1, 2] of the multiplicand. These multiples are computed from an 3 bit slices, overlapped by 1 bit, of the multiplier. Higher radices use wider slices of the multiplier to encode fewer multiples and therefore fewer rows. +Instead of using the 1's in the multiplier to select shifted versions of the multiplicand to add in a partial product matrix, radix-encoding will encode multiples of the multiplicand by examining adjacent bits of the multiplier. For radix-4, for example, for a multiplier of size M, instead of M rows of partial products, M/2 rows are formed by selecting from multiples [-2, -1, 0, 1, 2] of the multiplicand. These multiples are computed from an 3 bit slices, overlapped by 1 bit, of the multiplier. Higher radixes use wider slices of the multiplier to encode fewer multiples and therefore fewer rows. | bit_i | bit_i-1 | bit_i-2 | multiple| |:-----:|:-------:|:-------:|:-------:| @@ -199,3 +199,40 @@ Finally, we produce the product. compressor.exractRow(0), compressor.extractRow(1), BrentKung.new); product <= adder.sum.slice(a.width + b.width - 1, 0); ``` + +## Utility: Aligned Vector Formatting + +We provide an extension on `LogicValue` which permits formatting of binary vectors in an aligned way to help with debugging arithmetic components. + +The `vecString` extension provides a basic string printer with an optional `header` flag for bit numbering. A `prefix` value can be used to specify the name lengths to be used to keep vectors aligned. + +`alignHigh` controls the highest (toward MSB) alignment column of the output whereas `alignLow` controls the lower limit (toward the LSB). + +`sepPos' is optional and allows you to set a marker for a separator in the number. +`sepChar` is the separation character you wish to use (do not use '|' with Markdown formatting.) + +```dart + final ref = FloatingPoint64Value.fromDouble(3.14159); + print(ref.mantissa + .vecString('pi', alignHigh: 55, alignLow: 40, header: true, sepPos: 52)); +``` + +Produces + +```text + 54 53 52* 51 50 49 48 47 46 45 44 43 42 41 40 +pi * 1 0 0 1 0 0 1 0 0 0 0 1 +``` + +The routine also allows for output in Markdown format: + +```dart + print(ref.mantissa.vecString('pi', + alignHigh: 58, alignLow: 40, header: true, sepPos: 52, markDown: true)); +``` + +producing: + +| Name | 54 | 53 | 52* | 51 | 50 | 49 | 48 | 47 | 46 | 45 | 44 | 43 | 42 | 41 | 40 | +|:--:|:--|:--|:--|:--|:--|:--|:--|:--|:--|:--|:--|:--|:--|:--|:---| +|pi|||* | 1 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | diff --git a/lib/src/arithmetic/arithmetic.dart b/lib/src/arithmetic/arithmetic.dart index bf52d17b1..549d843ec 100644 --- a/lib/src/arithmetic/arithmetic.dart +++ b/lib/src/arithmetic/arithmetic.dart @@ -4,6 +4,7 @@ export 'adder.dart'; export 'carry_save_mutiplier.dart'; export 'divider.dart'; +export 'floating_point/floating_point.dart'; export 'multiplier.dart'; export 'multiplier_lib.dart'; export 'ones_complement_adder.dart'; diff --git a/lib/src/arithmetic/arithmetic_utils.dart b/lib/src/arithmetic/arithmetic_utils.dart new file mode 100644 index 000000000..a49cec811 --- /dev/null +++ b/lib/src/arithmetic/arithmetic_utils.dart @@ -0,0 +1,118 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_test.dart +// Tests of Floating Point stuff +// +// 2024 August 30 +// Author: Desmond A Kirkpatrick = alignLow; col--) { + final chars = BigInt.from(col).toString().length + extraSpace; + if (sepPos != null && sepPos == col) { + str + ..write( + markDown ? ' $hdrSep' : ' ' * (length - chars + 1 + extraSpace)) + ..write('$col$sepChar') + ..write(markDown ? ' $hdrSep' : ''); + } else if (sepPos != null && sepPos == col + 1) { + if (sepPos == max(alignHigh ?? width, width)) { + str + ..write(sepChar) + ..write(markDown ? ' $hdrSep' : ' ' * (length - chars - 1)); + } + str.write('${' ' * (length - chars + extraSpace + 0)}$col'); + } else { + str + ..write( + markDown ? ' $hdrSep' : ' ' * (length - chars + 1 + extraSpace)) + ..write('$col'); + } + } + str.write(markDown ? ' $hdrSepEnd\n' : '\n'); + if (markDown) { + str.write(markDown ? '|:--:' : ' ' * prefix); + for (var col = highLimit; col >= alignLow; col--) { + str.write('|:--'); + } + str.write('-|\n'); + } + } + const dataSepStart = '|'; + const dataSep = '| '; + const dataSepEnd = '|'; + final String strPrefix; + strPrefix = markDown + ? '$dataSepStart $name' + : (name.length <= prefix) + ? name.padRight(prefix) + : name.substring(0, prefix); + str + ..write(strPrefix) + ..write((markDown ? dataSep : ' ' * (length + 1)) * + ((alignHigh ?? width) - width)); + for (var col = alignLow; col < minHigh; col++) { + final pos = minHigh - 1 - col + alignLow; + final v = this[pos].bitString; + if (sepPos != null && sepPos == pos) { + str.write( + markDown ? ' $dataSep$v $sepChar' : '${' ' * length}$v$sepChar'); + } else if (sepPos != null && sepPos == pos + 1) { + if (sepPos == minHigh) { + str.write(sepChar); + } + str + ..write(markDown ? ' $dataSep' : ' ' * (length - 1)) + ..write(v); + } else { + str + ..write(markDown ? ' $dataSep' : ' ' * length) + ..write(v); + } + } + if (markDown) { + str.write(' $dataSepEnd'); + } + return str.toString(); + } +} diff --git a/lib/src/arithmetic/evaluate_compressor.dart b/lib/src/arithmetic/evaluate_compressor.dart index 465743472..b1740d382 100644 --- a/lib/src/arithmetic/evaluate_compressor.dart +++ b/lib/src/arithmetic/evaluate_compressor.dart @@ -41,6 +41,7 @@ extension EvaluateLiveColumnCompressor on ColumnCompressor { } rowBits.addAll(List.filled(pp.rowShift[row], LogicValue.zero)); final val = rowBits.swizzle().zeroExtend(width).toBigInt(); + accum += val; if (printOut) { ts.write('\t${rowBits.swizzle().zeroExtend(width).bitString} ($val)'); @@ -52,10 +53,6 @@ extension EvaluateLiveColumnCompressor on ColumnCompressor { } } } - if (printOut) { - // We need this to be able to debug, but git lint flunks print - // print(ts); - } return (accum.toSigned(width), ts); } diff --git a/lib/src/arithmetic/evaluate_partial_product.dart b/lib/src/arithmetic/evaluate_partial_product.dart index e421d13d3..0fd425385 100644 --- a/lib/src/arithmetic/evaluate_partial_product.dart +++ b/lib/src/arithmetic/evaluate_partial_product.dart @@ -66,7 +66,7 @@ extension EvaluateLivePartialProduct on PartialProductGenerator { str.write(' ' * shortPrefix); } } else { - str.write('$rowStr ${'M='} S= : '); + str.write('$rowStr ${'M='} S= : '); } final entry = partialProducts[row].reversed.toList(); final prefixCnt = @@ -104,4 +104,73 @@ extension EvaluateLivePartialProduct on PartialProductGenerator { } return str.toString(); } + + /// Print out the partial product matrix + String markdown() { + final str = StringBuffer(); + + final maxW = maxWidth(); + // print bit position header + str.write('| R | M | S'); + for (var i = maxW - 1; i >= 0; i--) { + str.write('| $i '); + } + str + ..write('| bitvector | value|\n') + ..write('|:--:' * 3); + for (var i = maxW - 1; i >= 0; i--) { + str.write('|:--:'); + } + str.write('|:--: |:--:|\n'); + // Partial product matrix: rows of multiplicand multiples shift by + // rowshift[row] + for (var row = 0; row < rows; row++) { + final rowStr = (row < 10) ? '0$row' : '$row'; + if (row < encoder.rows) { + final encoding = encoder.getEncoding(row); + if (encoding.multiples.value.isValid) { + final first = encoding.multiples.value.firstOne() ?? -1; + final multiple = first + 1; + str.write('|$rowStr| ' + '$multiple| ' + '${encoding.sign.value.toInt()}'); + } else { + str.write('| | |'); + } + } else { + str.write('|$rowStr | |'); + } + final entry = partialProducts[row].reversed.toList(); + str.write('| ' * (maxW - (entry.length + rowShift[row]))); + for (var col = 0; col < entry.length; col++) { + str.write('|${entry[col].value.bitString}'); + } + final suffixCnt = rowShift[row]; + final value = entry.swizzle().value.zeroExtend(maxW) << suffixCnt; + final intValue = value.isValid ? value.toBigInt() : BigInt.from(-1); + str + ..write('| ' * suffixCnt) + ..write('| ${value.bitString}') + ..write('| ${value.isValid ? intValue : ""}' + ' (${value.isValid ? intValue.toSigned(maxW) : ""})|\n'); + } + // Compute and print binary representation from accumulated value + // Later: we will compare with a compression tree result + str.write('||\n'); + + final sum = LogicValue.ofBigInt(evaluate(), maxW); + // print out the sum as a MSB-first bitvector + str.write('|||'); + for (final elem in [for (var i = 0; i < maxW; i++) sum[i]].reversed) { + str.write('|${elem.toInt()} '); + } + final val = evaluate(); + str.write('| ${sum.bitString}| ' + '${val.toUnsigned(maxW)}'); + if (isSignExtended) { + str.write(' ($val)'); + } + str.write('|\n'); + return str.toString(); + } } diff --git a/lib/src/arithmetic/floating_point/floating_point.dart b/lib/src/arithmetic/floating_point/floating_point.dart new file mode 100644 index 000000000..231569572 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point.dart @@ -0,0 +1,6 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +export 'floating_point_adder.dart'; +export 'floating_point_logic.dart'; +export 'floating_point_value.dart'; diff --git a/lib/src/arithmetic/floating_point/floating_point_adder.dart b/lib/src/arithmetic/floating_point/floating_point_adder.dart new file mode 100644 index 000000000..e6e3f1d20 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_adder.dart @@ -0,0 +1,107 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_adder.dart +// A very basic Floating-point adder component. +// +// 2024 August 30 +// Author: Desmond A Kirkpatrick + ( + toSwap.$1.clone()..gets(mux(swap, toSwap.$2, toSwap.$1)), + toSwap.$2.clone()..gets(mux(swap, toSwap.$1, toSwap.$2)) + ); + + /// Add two floating point numbers [a] and [b], returning result in [sum] + FloatingPointAdder(FloatingPoint a, FloatingPoint b, + {ParallelPrefix Function(List, Logic Function(Logic, Logic)) + ppGen = KoggeStone.new, + super.name}) + : exponentWidth = a.exponent.width, + mantissaWidth = a.mantissa.width { + if (b.exponent.width != exponentWidth || + b.mantissa.width != mantissaWidth) { + throw RohdHclException('FloatingPoint widths must match'); + } + a = a.clone()..gets(addInput('a', a, width: a.width)); + b = b.clone()..gets(addInput('b', b, width: b.width)); + addOutput('sum', width: _sum.width) <= _sum; + + // Ensure that the larger number is wired as 'a' + final doSwap = a.exponent.lt(b.exponent) | + (a.exponent.eq(b.exponent) & a.mantissa.lt(b.mantissa)) | + ((a.exponent.eq(b.exponent) & a.mantissa.eq(b.mantissa)) & b.sign); + + (a, b) = _swap(doSwap, (a, b)); + + final aExp = + a.exponent + mux(a.isNormal(), a.zeroExponent(), a.oneExponent()); + final bExp = + b.exponent + mux(b.isNormal(), b.zeroExponent(), b.oneExponent()); + + // Align and add mantissas + final expDiff = aExp - bExp; + // print('${expDiff.value.toInt()} exponent diff'); + final adder = SignMagnitudeAdder( + a.sign, + [a.isNormal(), a.mantissa].swizzle(), + b.sign, + [b.isNormal(), b.mantissa].swizzle() >>> expDiff, + (a, b) => ParallelPrefixAdder(a, b, ppGen: ppGen)); + + final sum = adder.sum.slice(adder.sum.width - 2, 0); + final leadOneE = + ParallelPrefixPriorityEncoder(sum.reversed, ppGen: ppGen).out; + final leadOne = leadOneE.zeroExtend(exponentWidth); + + // Assemble the output FloatingPoint + _sum.sign <= adder.sign; + Combinational([ + If.block([ + Iff(adder.sum[-1] & a.sign.eq(b.sign), [ + _sum.mantissa < (sum >> 1).slice(mantissaWidth - 1, 0), + _sum.exponent < a.exponent + 1 + ]), + ElseIf(a.exponent.gt(leadOne) & sum.or(), [ + _sum.mantissa < (sum << leadOne).slice(mantissaWidth - 1, 0), + _sum.exponent < a.exponent - leadOne + ]), + ElseIf(leadOne.eq(0) & sum.or(), [ + _sum.mantissa < (sum << leadOne).slice(mantissaWidth - 1, 0), + _sum.exponent < a.exponent - leadOne + 1 + ]), + Else([ + // subnormal result + _sum.mantissa < sum.slice(mantissaWidth - 1, 0), + _sum.exponent < _sum.zeroExponent() + ]) + ]) + ]); + // print('final sum: ${_sum.value.bitString}'); + } +} diff --git a/lib/src/arithmetic/floating_point/floating_point_logic.dart b/lib/src/arithmetic/floating_point/floating_point_logic.dart new file mode 100644 index 000000000..fcdca3107 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_logic.dart @@ -0,0 +1,102 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_logic.dart +// Implementation of Floating Point objects +// +// 2024 April 1 +// Authors: +// Max Korbel +// Desmond A Kirkpatrick FloatingPoint( + exponentWidth: exponent.width, + mantissaWidth: mantissa.width, + ); + + /// Return the [FloatingPointValue] + FloatingPointValue get floatingPointValue => FloatingPointValue( + sign: sign.value, exponent: exponent.value, mantissa: mantissa.value); + + /// Return a Logic true if this FloatingPoint contains a normal number, + /// defined as having mantissa in the range [1,2) + Logic isNormal() => exponent.neq(LogicValue.zero.zeroExtend(exponent.width)); + + /// Return the zero exponent representation for this type of FloatingPoint + Logic zeroExponent() => Const(LogicValue.zero).zeroExtend(exponent.width); + + /// Return the one exponent representation for this type of FloatingPoint + Logic oneExponent() => Const(LogicValue.one).zeroExtend(exponent.width); + + @override + void put(dynamic val, {bool fill = false}) { + if (val is FloatingPointValue) { + put(val.value); + } else { + super.put(val, fill: fill); + } + } +} + +/// Single floating point representation +class FloatingPoint32 extends FloatingPoint { + /// Construct a 32-bit (single-precision) floating point number + FloatingPoint32() + : super( + exponentWidth: FloatingPoint32Value.exponentWidth, + mantissaWidth: FloatingPoint32Value.mantissaWidth); +} + +/// Double floating point representation +class FloatingPoint64 extends FloatingPoint { + /// Construct a 64-bit (double-precision) floating point number + FloatingPoint64() + : super( + exponentWidth: FloatingPoint64Value.exponentWidth, + mantissaWidth: FloatingPoint64Value.mantissaWidth); +} + +/// Eight-bit floating point representation for deep learning +class FloatingPoint8 extends FloatingPoint { + /// Calculate mantissa width and sanitize + static int _calculateMantissaWidth(int exponentWidth) { + final mantissaWidth = 7 - exponentWidth; + if (!FloatingPoint8Value.isLegal(exponentWidth, mantissaWidth)) { + throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2'); + } else { + return mantissaWidth; + } + } + + /// Construct an 8-bit floating point number + FloatingPoint8({required super.exponentWidth}) + : super(mantissaWidth: _calculateMantissaWidth(exponentWidth)); +} diff --git a/lib/src/arithmetic/floating_point/floating_point_value.dart b/lib/src/arithmetic/floating_point/floating_point_value.dart new file mode 100644 index 000000000..59f8d74b9 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_value.dart @@ -0,0 +1,910 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_value.dart +// Implementation of Floating-Point value representations. +// +// 2024 April 1 +// Authors: +// Max Korbel +// Desmond A Kirkpatrick { + /// The full floating point value bit storage + final LogicValue value; + + /// The sign of the value: 1 means a negative value + final LogicValue sign; + + /// The exponent of the floating point: this is biased about a midpoint for + /// positive and negative exponents + final LogicValue exponent; + + /// The mantissa of the floating point + final LogicValue mantissa; + + /// Return the exponent value representing the true zero exponent 2^0 = 1 + /// often termed [computeBias] or the offset of the exponent + static int computeBias(int exponentWidth) => + pow(2, exponentWidth - 1).toInt() - 1; + + /// Return the minimum exponent value + static int computeMinExponent(int exponentWidth) => + -pow(2, exponentWidth - 1).toInt() + 2; + + /// Return the maximum exponent value + static int computeMaxExponent(int exponentWidth) => + computeBias(exponentWidth); + + /// Return the bias of this [FloatingPointValue]. + int get bias => _bias; + + /// Return the maximum exponent of this [FloatingPointValue]. + int get maxExponent => _maxExp; + + /// Return the minimum exponent of this [FloatingPointValue]. + int get minExponent => _minExp; + + final int _bias; + final int _maxExp; + final int _minExp; + + /// Factory (static) constructor of a [FloatingPointValue] from + /// sign, mantissa and exponent + factory FloatingPointValue( + {required LogicValue sign, + required LogicValue exponent, + required LogicValue mantissa}) { + if (exponent.width == FloatingPoint32Value.exponentWidth && + mantissa.width == FloatingPoint32Value.mantissaWidth) { + return FloatingPoint32Value( + sign: sign, mantissa: mantissa, exponent: exponent); + } else if (exponent.width == FloatingPoint64Value._exponentWidth && + mantissa.width == FloatingPoint64Value._mantissaWidth) { + return FloatingPoint64Value( + sign: sign, mantissa: mantissa, exponent: exponent); + } else { + return FloatingPointValue.withConstraints( + sign: sign, mantissa: mantissa, exponent: exponent); + } + } + + /// [FloatingPointValue] constructor from a binary string representation of + /// individual bitfields + factory FloatingPointValue.ofBinaryStrings( + String sign, String exponent, String mantissa) { + if (sign.length != 1) { + throw RohdHclException('Sign string must be of length 1'); + } + + return FloatingPointValue( + sign: LogicValue.of(sign), + exponent: LogicValue.of(exponent), + mantissa: LogicValue.of(mantissa)); + } + + /// [FloatingPointValue] constructor from a single binary string representing + /// space-separated bitfields + factory FloatingPointValue.ofSeparatedBinaryStrings(String fp) { + final s = fp.split(' '); + if (s.length != 3) { + throw RohdHclException('FloatingPointValue requires three strings ' + 'to initialize'); + } + return FloatingPointValue.ofBinaryStrings(s[0], s[1], s[2]); + } + + /// [FloatingPointValue] constructor from a radix-encoded string + /// representation and the size of the exponent and mantissa + factory FloatingPointValue.ofString( + String fp, int exponentWidth, int mantissaWidth, + {int radix = 2}) { + final binaryFp = LogicValue.ofBigInt( + BigInt.parse(fp, radix: radix), exponentWidth + mantissaWidth + 1) + .bitString; + + final (sign, exponent, mantissa) = ( + binaryFp.substring(0, 1), + binaryFp.substring(1, 1 + exponentWidth), + binaryFp.substring(1 + exponentWidth, 1 + exponentWidth + mantissaWidth) + ); + return FloatingPointValue.ofBinaryStrings(sign, exponent, mantissa); + } + + /// [FloatingPointValue] constructor from a set of [BigInt]s of the binary + /// representation and the size of the exponent and mantissa + factory FloatingPointValue.ofBigInts(BigInt exponent, BigInt mantissa, + {int exponentWidth = 0, int mantissaWidth = 0, bool sign = false}) { + final (signLv, exponentLv, mantissaLv) = ( + LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), + LogicValue.ofBigInt(exponent, exponentWidth), + LogicValue.ofBigInt(mantissa, mantissaWidth) + ); + + return FloatingPointValue( + sign: signLv, exponent: exponentLv, mantissa: mantissaLv); + } + + /// [FloatingPointValue] constructor from a set of [int]s of the binary + /// representation and the size of the exponent and mantissa + factory FloatingPointValue.ofInts(int exponent, int mantissa, + {int exponentWidth = 0, int mantissaWidth = 0, bool sign = false}) { + final (signLv, exponentLv, mantissaLv) = ( + LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), + LogicValue.ofBigInt(BigInt.from(exponent), exponentWidth), + LogicValue.ofBigInt(BigInt.from(mantissa), mantissaWidth) + ); + + return FloatingPointValue( + sign: signLv, exponent: exponentLv, mantissa: mantissaLv); + } + + /// Constructor enabling subclasses. + FloatingPointValue.withConstraints( + {required this.sign, + required this.exponent, + required this.mantissa, + int? mantissaWidth, + int? exponentWidth}) + : value = [sign, exponent, mantissa].swizzle(), + _bias = computeBias(exponent.width), + _minExp = computeMinExponent(exponent.width), + _maxExp = computeMaxExponent(exponent.width) { + if (sign.width != 1) { + throw RohdHclException('FloatingPointValue: sign width must be 1'); + } + if (mantissaWidth != null && mantissa.width != mantissaWidth) { + throw RohdHclException( + 'FloatingPointValue: mantissa width must be $mantissaWidth'); + } + if (exponentWidth != null && exponent.width != exponentWidth) { + throw RohdHclException( + 'FloatingPointValue: exponent width must be $exponentWidth'); + } + } + + /// Construct a [FloatingPointValue] from a Logic word + factory FloatingPointValue.fromLogic( + int exponentWidth, int mantissaWidth, LogicValue val) { + final sign = (val[-1] == LogicValue.one); + final exponent = + val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth).toBigInt(); + final mantissa = val.slice(mantissaWidth - 1, 0).toBigInt(); + final (signLv, exponentLv, mantissaLv) = ( + LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), + LogicValue.ofBigInt(exponent, exponentWidth), + LogicValue.ofBigInt(mantissa, mantissaWidth) + ); + return FloatingPointValue( + sign: signLv, exponent: exponentLv, mantissa: mantissaLv); + } + + /// Return the [FloatingPointValue] representing the constant specified + factory FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants constantFloatingPoint, + int exponentWidth, + int mantissaWidth) { + switch (constantFloatingPoint) { + /// smallest possible number + case FloatingPointConstants.negativeInfinity: + return FloatingPointValue.ofBinaryStrings( + '1', '1' * exponentWidth, '0' * mantissaWidth); + + /// -0.0 + case FloatingPointConstants.negativeZero: + return FloatingPointValue.ofBinaryStrings( + '1', '0' * exponentWidth, '0' * mantissaWidth); + + /// 0.0 + case FloatingPointConstants.positiveZero: + return FloatingPointValue.ofBinaryStrings( + '0', '0' * exponentWidth, '0' * mantissaWidth); + + /// Smallest possible number, most exponent negative, LSB set in mantissa + case FloatingPointConstants.smallestPositiveSubnormal: + return FloatingPointValue.ofBinaryStrings( + '0', '0' * exponentWidth, '${'0' * (mantissaWidth - 1)}1'); + + /// Largest possible subnormal, most negative exponent, mantissa all 1s + case FloatingPointConstants.largestPositiveSubnormal: + return FloatingPointValue.ofBinaryStrings( + '0', '0' * exponentWidth, '1' * mantissaWidth); + + /// Smallest possible positive number, most negative exponent, mantissa 0 + case FloatingPointConstants.smallestPositiveNormal: + return FloatingPointValue.ofBinaryStrings( + '0', '${'0' * (exponentWidth - 1)}1', '0' * mantissaWidth); + + /// Largest number smaller than one + case FloatingPointConstants.largestLessThanOne: + return FloatingPointValue.ofBinaryStrings( + '0', '0${'1' * (exponentWidth - 2)}0', '1' * mantissaWidth); + + /// The number '1.0' + case FloatingPointConstants.one: + return FloatingPointValue.ofBinaryStrings( + '0', '0${'1' * (exponentWidth - 1)}', '0' * mantissaWidth); + + /// Smallest number greater than one + case FloatingPointConstants.smallestLargerThanOne: + return FloatingPointValue.ofBinaryStrings('0', + '0${'1' * (exponentWidth - 2)}0', '${'0' * (mantissaWidth - 1)}1'); + + /// Largest positive number, most positive exponent, full mantissa + case FloatingPointConstants.largestNormal: + return FloatingPointValue.ofBinaryStrings( + '0', '0' * exponentWidth, '1' * mantissaWidth); + + /// Largest possible number + case FloatingPointConstants.infinity: + return FloatingPointValue.ofBinaryStrings( + '0', '1' * exponentWidth, '0' * mantissaWidth); + } + } + + /// Convert from double using its native binary representation + factory FloatingPointValue.fromDouble(double inDouble, + {required int exponentWidth, + required int mantissaWidth, + FloatingPointRoundingMode roundingMode = + FloatingPointRoundingMode.roundNearestEven}) { + if ((exponentWidth == 8) && (mantissaWidth == 23)) { + return FloatingPoint32Value.fromDouble(inDouble); + } else if ((exponentWidth == 11) && (mantissaWidth == 52)) { + return FloatingPoint64Value.fromDouble(inDouble); + } + + final fp64 = FloatingPoint64Value.fromDouble(inDouble); + final exponent64 = fp64.exponent; + + var expVal = (exponent64.toInt() - fp64.bias) + + FloatingPointValue.computeBias(exponentWidth); + // Handle subnormal + final mantissa64 = [ + if (expVal <= 0) + ([LogicValue.one, fp64.mantissa].swizzle() >>> -expVal).slice(52, 1) + else + fp64.mantissa + ].first; + var mantissa = mantissa64.slice(51, 51 - mantissaWidth + 1); + + if (roundingMode == FloatingPointRoundingMode.roundNearestEven) { + final sticky = mantissa64.slice(51 - (mantissaWidth + 2), 0).or(); + final roundPos = 51 - (mantissaWidth + 2) + 1; + final round = mantissa64[roundPos]; + final guard = mantissa64[roundPos + 1]; + + // RNE Rounding + if (guard == LogicValue.one) { + if ((round == LogicValue.one) | + (sticky == LogicValue.one) | + (mantissa[0] == LogicValue.one)) { + mantissa += 1; + if (mantissa == LogicValue.zero.zeroExtend(mantissa.width)) { + expVal += 1; + } + } + } + } + + final exponent = + LogicValue.ofBigInt(BigInt.from(max(expVal, 0)), exponentWidth); + + return FloatingPointValue( + sign: fp64.sign, exponent: exponent, mantissa: mantissa); + } + + /// Generate a random [FloatingPointValue], supplying random seed [rv]. + /// This generates a valid [FloatingPointValue] anywhere in the range + /// it can represent:a general [FloatingPointValue] has + /// a mantissa in [0,2) with 0 <= exponent <= maxExponent(); + /// If [normal] is true, This routine will only generate mantissas in the + /// range of [1,2) and minExponent() <= exponent <= maxExponent(). + factory FloatingPointValue.random(Random rv, + {required int exponentWidth, + required int mantissaWidth, + bool normal = false}) { + final largestExponent = FloatingPointValue.computeBias(exponentWidth) + + FloatingPointValue.computeMaxExponent(exponentWidth); + final s = rv.nextLogicValue(width: 1).toInt(); + var e = BigInt.one; + do { + e = rv + .nextLogicValue(width: exponentWidth, max: largestExponent) + .toBigInt(); + } while ((e == BigInt.zero) & normal); + final m = rv.nextLogicValue(width: mantissaWidth).toBigInt(); + return FloatingPointValue( + sign: LogicValue.ofInt(s, 1), + exponent: LogicValue.ofBigInt(e, exponentWidth), + mantissa: LogicValue.ofBigInt(m, mantissaWidth)); + } + + /// Convert a floating point number into a [FloatingPointValue] + /// representation. This form performs NO ROUNDING. + factory FloatingPointValue.fromDoubleIter(double inDouble, + {required int exponentWidth, required int mantissaWidth}) { + if ((exponentWidth == 8) && (mantissaWidth == 23)) { + return FloatingPoint32Value.fromDouble(inDouble); + } else if ((exponentWidth == 11) && (mantissaWidth == 52)) { + return FloatingPoint64Value.fromDouble(inDouble); + } + + var doubleVal = inDouble; + if (inDouble.isNaN) { + return FloatingPointValue( + exponent: + LogicValue.ofInt(pow(2, exponentWidth).toInt() - 1, exponentWidth), + mantissa: LogicValue.zero, + sign: LogicValue.zero, + ); + } + LogicValue sign; + if (inDouble < 0.0) { + doubleVal = -doubleVal; + sign = LogicValue.one; + } else { + sign = LogicValue.zero; + } + + // If we are dealing with a really small number we need to scale it up + var scaleToWhole = (doubleVal != 0) ? (-log(doubleVal) / log(2)).ceil() : 0; + + if (doubleVal < 1.0) { + var myCnt = 0; + var myVal = doubleVal; + while (myVal % 1 != 0.0) { + myVal = myVal * 2.0; + myCnt++; + } + if (myCnt < scaleToWhole) { + scaleToWhole = myCnt; + } + } + + // Scale it up to go beyond the mantissa and include the GRS bits + final scale = mantissaWidth + scaleToWhole; + var s = scale; + + var sVal = doubleVal; + if (s > 0) { + while (s > 0) { + sVal *= 2.0; + s = s - 1; + } + } else { + sVal = doubleVal * pow(2.0, scale); + } + + final scaledValue = BigInt.from(sVal); + final fullLength = scaledValue.bitLength; + + var fullValue = LogicValue.ofBigInt(scaledValue, fullLength); + var e = (fullLength > 0) + ? fullLength - mantissaWidth - scaleToWhole + : FloatingPointValue.computeMinExponent(exponentWidth); + + if (e <= -FloatingPointValue.computeBias(exponentWidth)) { + fullValue = fullValue >>> + (scaleToWhole - FloatingPointValue.computeBias(exponentWidth)); + e = -FloatingPointValue.computeBias(exponentWidth); + } else { + // Could be just one away from subnormal + e -= 1; + if (e > -FloatingPointValue.computeBias(exponentWidth)) { + fullValue = fullValue << 1; // Chop the first '1' + } + } + // We reverse so that we fit into a shorter BigInt, we keep the MSB. + // The conversion fills leftward. + // We reverse again after conversion. + final exponent = LogicValue.ofInt( + e + FloatingPointValue.computeBias(exponentWidth), exponentWidth); + final mantissa = + LogicValue.ofBigInt(fullValue.reversed.toBigInt(), mantissaWidth) + .reversed; + + return FloatingPointValue( + exponent: exponent, + mantissa: mantissa, + sign: sign, + ); + } + + @override + int get hashCode => sign.hashCode ^ exponent.hashCode ^ mantissa.hashCode; + + /// Floating point comparison to implement Comparable<> + @override + int compareTo(Object other) { + if (other is! FloatingPointValue) { + throw Exception('Input must be of type FloatingPointValue '); + } + if ((exponent.width != other.exponent.width) | + (mantissa.width != other.mantissa.width)) { + throw Exception('FloatingPointValue widths must match for comparison'); + } + final signCompare = sign.compareTo(other.sign); + if (signCompare != 0) { + return signCompare; + } else { + final expCompare = exponent.compareTo(other.exponent); + if (expCompare != 0) { + return expCompare; + } else { + return mantissa.compareTo(other.mantissa); + } + } + } + + /// Return the bias of this FP format + // int bias() => FloatingPointValue.computeBias(exponent.width); + + @override + bool operator ==(Object other) { + if (other is! FloatingPointValue) { + return false; + } + + if ((exponent.width != other.exponent.width) | + (mantissa.width != other.mantissa.width)) { + return false; + } + + return (sign == other.sign) & + (exponent == other.exponent) & + (mantissa == other.mantissa); + } + + /// Return true if the represented floating point number is considered + /// NaN or 'Not a Number' due to overflow + // TODO(desmonddak): figure out the difference with Infinity + bool isNaN() { + if ((exponent.width == 4) & (mantissa.width == 3)) { + // FP8 E4M3 does not support infinities + final cond1 = (1 + exponent.toInt()) == pow(2, exponent.width).toInt(); + final cond2 = (1 + mantissa.toInt()) == pow(2, mantissa.width).toInt(); + return cond1 & cond2; + } else { + return exponent.toInt() == + computeMaxExponent(exponent.width) + computeBias(exponent.width) + 1; + } + } + + /// Return the value of the floating point number in a Dart [double] type. + double toDouble() { + var doubleVal = double.nan; + if (value.isValid) { + if (exponent.toInt() == 0) { + doubleVal = (sign.toBool() ? -1.0 : 1.0) * + pow(2.0, computeMinExponent(exponent.width)) * + mantissa.toBigInt().toDouble() / + pow(2.0, mantissa.width); + } else if (!isNaN()) { + doubleVal = (sign.toBool() ? -1.0 : 1.0) * + (1.0 + mantissa.toBigInt().toDouble() / pow(2.0, mantissa.width)) * + pow( + 2.0, + exponent.toInt().toSigned(exponent.width) - + computeBias(exponent.width)); + doubleVal = (sign.toBool() ? -1.0 : 1.0) * + (1.0 + mantissa.toBigInt().toDouble() / pow(2.0, mantissa.width)) * + pow(2.0, exponent.toInt() - computeBias(exponent.width)); + } + } + return doubleVal; + } + + /// Return a Logic true if this FloatingPointVa;ie contains a normal number, + /// defined as having mantissa in the range [1,2) + bool isNormal() => exponent != LogicValue.ofInt(0, exponent.width); + + @override + String toString() => '${sign.toString(includeWidth: false)}' + ' ${exponent.toString(includeWidth: false)}' + ' ${mantissa.toString(includeWidth: false)}'; + + // TODO(desmonddak): what about floating point representations >> 64 bits? + FloatingPointValue _performOp( + FloatingPointValue other, double Function(double a, double b) op) { + // make sure multiplicand has the same sizes as this + if (mantissa.width != other.mantissa.width || + exponent.width != other.exponent.width) { + throw RohdHclException('FloatingPointValue: ' + 'multiplicand must have the same mantissa and exponent widths'); + } + + return FloatingPointValue.fromDouble(op(toDouble(), other.toDouble()), + mantissaWidth: mantissa.width, exponentWidth: exponent.width); + } + + /// Multiply operation for [FloatingPointValue] + FloatingPointValue operator *(FloatingPointValue multiplicand) => + _performOp(multiplicand, (a, b) => a * b); + + /// Addition operation for [FloatingPointValue] + FloatingPointValue operator +(FloatingPointValue addend) => + _performOp(addend, (a, b) => a + b); + + /// Divide operation for [FloatingPointValue] + FloatingPointValue operator /(FloatingPointValue divisor) => + _performOp(divisor, (a, b) => a / b); + + /// Subtract operation for [FloatingPointValue] + FloatingPointValue operator -(FloatingPointValue subend) => + _performOp(subend, (a, b) => a - b); + + /// Negate operation for [FloatingPointValue] + FloatingPointValue negate() => FloatingPointValue( + sign: sign.isZero ? LogicValue.one : LogicValue.zero, + exponent: exponent, + mantissa: mantissa); + + /// Absolute value operation for [FloatingPointValue] + FloatingPointValue abs() => FloatingPointValue( + sign: LogicValue.zero, exponent: exponent, mantissa: mantissa); +} + +/// A representation of a single precision floating point value +class FloatingPoint32Value extends FloatingPointValue { + /// The exponent width + static const int exponentWidth = 8; + + /// The mantissa width + static const int mantissaWidth = 23; + + /// Constructor for a single precision floating point value + FloatingPoint32Value( + {required super.sign, required super.exponent, required super.mantissa}) + : super.withConstraints( + mantissaWidth: mantissaWidth, exponentWidth: exponentWidth); + + /// Return the [FloatingPoint32Value] representing the constant specified + factory FloatingPoint32Value.getFloatingPointConstant( + FloatingPointConstants constantFloatingPoint) => + FloatingPointValue.getFloatingPointConstant( + constantFloatingPoint, exponentWidth, mantissaWidth) + as FloatingPoint32Value; + + /// [FloatingPoint32Value] constructor from string representation of + /// individual bitfields + factory FloatingPoint32Value.ofStrings( + String sign, String exponent, String mantissa) => + FloatingPoint32Value( + sign: LogicValue.of(sign), + exponent: LogicValue.of(exponent), + mantissa: LogicValue.of(mantissa)); + + /// [FloatingPoint32Value] constructor from a single string representing + /// space-separated bitfields + factory FloatingPoint32Value.ofString(String fp) { + final s = fp.split(' '); + assert(s.length == 3, 'Wrong FloatingPointValue string length ${s.length}'); + return FloatingPoint32Value.ofStrings(s[0], s[1], s[2]); + } + + /// [FloatingPoint32Value] constructor from a set of [BigInt]s of the binary + /// representation + factory FloatingPoint32Value.ofBigInts(BigInt exponent, BigInt mantissa, + {bool sign = false}) { + final (signLv, exponentLv, mantissaLv) = ( + LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), + LogicValue.ofBigInt(exponent, exponentWidth), + LogicValue.ofBigInt(mantissa, mantissaWidth) + ); + + return FloatingPoint32Value( + sign: signLv, exponent: exponentLv, mantissa: mantissaLv); + } + + /// [FloatingPoint32Value] constructor from a set of [int]s of the binary + /// representation + factory FloatingPoint32Value.ofInts(int exponent, int mantissa, + {bool sign = false}) { + final (signLv, exponentLv, mantissaLv) = ( + LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), + LogicValue.ofBigInt(BigInt.from(exponent), exponentWidth), + LogicValue.ofBigInt(BigInt.from(mantissa), mantissaWidth) + ); + + return FloatingPoint32Value( + sign: signLv, exponent: exponentLv, mantissa: mantissaLv); + } + + /// Numeric conversion of a [FloatingPoint32Value] from a host double + factory FloatingPoint32Value.fromDouble(double inDouble) { + final byteData = ByteData(4) + ..setFloat32(0, inDouble) + ..buffer.asUint8List(); + final bytes = byteData.buffer.asUint8List(); + final lv = bytes.map((b) => LogicValue.ofInt(b, 32)); + + final accum = lv.reduce((accum, v) => (accum << 8) | v); + + final sign = accum[-1]; + final exponent = + accum.slice(exponentWidth + mantissaWidth - 1, mantissaWidth); + final mantissa = accum.slice(mantissaWidth - 1, 0); + + return FloatingPoint32Value( + sign: sign, exponent: exponent, mantissa: mantissa); + } + + /// Construct a [FloatingPoint32Value] from a Logic word + factory FloatingPoint32Value.fromLogic(LogicValue val) { + final sign = (val[-1] == LogicValue.one); + final exponent = + val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth); + final mantissa = val.slice(mantissaWidth - 1, 0); + final (signLv, exponentLv, mantissaLv) = ( + LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), + exponent, + mantissa + ); + return FloatingPoint32Value( + sign: signLv, exponent: exponentLv, mantissa: mantissaLv); + } +} + +/// A representation of a double precision floating point value +class FloatingPoint64Value extends FloatingPointValue { + static const int _exponentWidth = 11; + static const int _mantissaWidth = 52; + + /// return the exponent width + static int get exponentWidth => _exponentWidth; + + /// return the mantissa width + static int get mantissaWidth => _mantissaWidth; + + /// Constructor for a double precision floating point value + FloatingPoint64Value( + {required super.sign, required super.mantissa, required super.exponent}) + : super.withConstraints( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + + /// Return the [FloatingPoint64Value] representing the constant specified + factory FloatingPoint64Value.getFloatingPointConstant( + FloatingPointConstants constantFloatingPoint) => + FloatingPointValue.getFloatingPointConstant( + constantFloatingPoint, _exponentWidth, _mantissaWidth) + as FloatingPoint64Value; + + /// [FloatingPoint64Value] constructor from string representation of + /// individual bitfields + factory FloatingPoint64Value.ofStrings( + String sign, String exponent, String mantissa) => + FloatingPoint64Value( + sign: LogicValue.of(sign), + exponent: LogicValue.of(exponent), + mantissa: LogicValue.of(mantissa)); + + /// [FloatingPoint64Value] constructor from a single string representing + /// space-separated bitfields + factory FloatingPoint64Value.ofString(String fp) { + final s = fp.split(' '); + assert(s.length == 3, 'Wrong FloatingPointValue string length ${s.length}'); + return FloatingPoint64Value.ofStrings(s[0], s[1], s[2]); + } + + /// [FloatingPoint64Value] constructor from a set of [BigInt]s of the binary + /// representation + factory FloatingPoint64Value.ofBigInts(BigInt exponent, BigInt mantissa, + {bool sign = false}) => + FloatingPointValue.ofBigInts(exponent, mantissa, + sign: sign, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth) as FloatingPoint64Value; + + /// [FloatingPoint64Value] constructor from a set of [int]s of the binary + /// representation + factory FloatingPoint64Value.ofInts(int exponent, int mantissa, + {bool sign = false}) => + FloatingPointValue.ofInts(exponent, mantissa, + sign: sign, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth) as FloatingPoint64Value; + + /// Numeric conversion of a [FloatingPoint64Value] from a host double + factory FloatingPoint64Value.fromDouble(double inDouble) { + final byteData = ByteData(8) + ..setFloat64(0, inDouble) + ..buffer.asUint8List(); + final bytes = byteData.buffer.asUint8List(); + final lv = bytes.map((b) => LogicValue.ofInt(b, 64)); + + final accum = lv.reduce((accum, v) => (accum << 8) | v); + + final sign = accum[-1]; + final exponent = + accum.slice(_exponentWidth + _mantissaWidth - 1, _mantissaWidth); + final mantissa = accum.slice(_mantissaWidth - 1, 0); + + return FloatingPoint64Value( + sign: sign, mantissa: mantissa, exponent: exponent); + } + + /// Construct a [FloatingPoint32Value] from a Logic word + factory FloatingPoint64Value.fromLogic(LogicValue val) { + final sign = (val[-1] == LogicValue.one); + final exponent = + val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth).toBigInt(); + final mantissa = val.slice(mantissaWidth - 1, 0).toBigInt(); + final (signLv, exponentLv, mantissaLv) = ( + LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), + LogicValue.ofBigInt(exponent, exponentWidth), + LogicValue.ofBigInt(mantissa, mantissaWidth) + ); + return FloatingPoint64Value( + sign: signLv, exponent: exponentLv, mantissa: mantissaLv); + } +} + +/// A representation of a 8-bit floating point value as defined in +/// [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433). +class FloatingPoint8Value extends FloatingPointValue { + /// The exponent width + late final int exponentWidth; + + /// The mantissa width + late final int mantissaWidth; + + static double get _e4m3max => 448.toDouble(); + static double get _e5m2max => 57344.toDouble(); + static double get _e4m3min => pow(2, -9).toDouble(); + static double get _e5m2min => pow(2, -16).toDouble(); + + /// Return if the exponent and mantissa widths match E4M3 or E5M2 + static bool isLegal(int exponentWidth, int mantissaWidth) { + if (((exponentWidth == 4) & (mantissaWidth == 3)) | + ((exponentWidth == 5) & (mantissaWidth == 2))) { + return true; + } else { + return false; + } + } + + /// Constructor for a double precision floating point value + FloatingPoint8Value( + {required super.sign, required super.mantissa, required super.exponent}) + : super.withConstraints() { + exponentWidth = exponent.width; + mantissaWidth = mantissa.width; + if (!isLegal(exponentWidth, mantissaWidth)) { + throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2'); + } + } + + /// [FloatingPoint8Value] constructor from string representation of + /// individual bitfields + factory FloatingPoint8Value.ofStrings( + String sign, String exponent, String mantissa) => + FloatingPoint8Value( + sign: LogicValue.of(sign), + exponent: LogicValue.of(exponent), + mantissa: LogicValue.of(mantissa)); + + /// [FloatingPoint8Value] constructor from a single string representing + /// space-separated bitfields + factory FloatingPoint8Value.ofString(String fp) { + final s = fp.split(' '); + assert(s.length == 3, 'Wrong FloatingPointValue string length ${s.length}'); + return FloatingPoint8Value.ofStrings(s[0], s[1], s[2]); + } + + /// Construct a [FloatingPoint8Value] from a Logic word + factory FloatingPoint8Value.fromLogic(LogicValue val, int exponentWidth) { + if (val.width != 8) { + throw RohdHclException('Width must be 8'); + } + + final mantissaWidth = 7 - exponentWidth; + if (!isLegal(exponentWidth, mantissaWidth)) { + throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2'); + } + + final sign = (val[-1] == LogicValue.one); + final exponent = + val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth).toBigInt(); + final mantissa = val.slice(mantissaWidth - 1, 0).toBigInt(); + final (signLv, exponentLv, mantissaLv) = ( + LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), + LogicValue.ofBigInt(exponent, exponentWidth), + LogicValue.ofBigInt(mantissa, mantissaWidth) + ); + return FloatingPoint8Value( + sign: signLv, exponent: exponentLv, mantissa: mantissaLv); + } + + /// Numeric conversion of a [FloatingPoint8Value] from a host double + factory FloatingPoint8Value.fromDouble(double inDouble, + {required int exponentWidth}) { + final mantissaWidth = 7 - exponentWidth; + if (!isLegal(exponentWidth, mantissaWidth)) { + throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2'); + } + if (exponentWidth == 4) { + if ((inDouble > _e4m3max) | (inDouble < _e4m3min)) { + throw RohdHclException('Number exceeds E4M3 range'); + } + } else if (exponentWidth == 5) { + if ((inDouble > _e5m2max) | (inDouble < _e5m2min)) { + throw RohdHclException('Number exceeds E5M2 range'); + } + } + final fpv = FloatingPointValue.fromDouble(inDouble, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + return FloatingPoint8Value( + sign: fpv.sign, exponent: fpv.exponent, mantissa: fpv.mantissa); + } +} diff --git a/lib/src/arithmetic/multiplier_encoder.dart b/lib/src/arithmetic/multiplier_encoder.dart index cfe7727ef..2d5b0753a 100644 --- a/lib/src/arithmetic/multiplier_encoder.dart +++ b/lib/src/arithmetic/multiplier_encoder.dart @@ -131,64 +131,3 @@ class MultiplierEncoder { return _encoder.encode(multiplierSlice.first); } } - -/// A class accessing the multiples of the multiplicand at a position -class MultiplicandSelector { - /// radix of the selector - int radix; - - /// The bit shift of the selector (typically overlaps 1) - int shift; - - /// New width of partial products generated from the multiplicand - int get width => multiplicand.width + shift - 1; - - /// Access the multiplicand - Logic multiplicand = Logic(); - - /// Place to store multiples of the multiplicand - late LogicArray multiples; - - /// Generate required multiples of multiplicand - MultiplicandSelector(this.radix, this.multiplicand, {required bool signed}) - : shift = log2Ceil(radix) { - if (radix > 16) { - throw RohdHclException('Radices beyond 16 are not yet supported'); - } - final width = multiplicand.width + shift; - final numMultiples = radix ~/ 2; - multiples = LogicArray([numMultiples], width); - final extendedMultiplicand = signed - ? multiplicand.signExtend(width) - : multiplicand.zeroExtend(width); - - for (var pos = 0; pos < numMultiples; pos++) { - final ratio = pos + 1; - multiples.elements[pos] <= - switch (ratio) { - 1 => extendedMultiplicand, - 2 => extendedMultiplicand << 1, - 3 => (extendedMultiplicand << 2) - extendedMultiplicand, - 4 => extendedMultiplicand << 2, - 5 => (extendedMultiplicand << 2) + extendedMultiplicand, - 6 => (extendedMultiplicand << 3) - (extendedMultiplicand << 1), - 7 => (extendedMultiplicand << 3) - extendedMultiplicand, - 8 => extendedMultiplicand << 3, - _ => throw RohdHclException('Radix is beyond 16') - }; - } - } - - /// Retrieve the multiples of the multiplicand at current bit position - Logic getMultiples(int col) => [ - for (var i = 0; i < multiples.elements.length; i++) - multiples.elements[i][col] - ].swizzle().reversed; - - Logic _select(Logic multiples, RadixEncode encode) => - (encode.multiples & multiples).or() ^ encode.sign; - - /// Select the partial product term from the multiples using a RadixEncode - Logic select(int col, RadixEncode encode) => - _select(getMultiples(col), encode); -} diff --git a/lib/src/arithmetic/multiplier_lib.dart b/lib/src/arithmetic/multiplier_lib.dart index d37063194..eb6c8bcd2 100644 --- a/lib/src/arithmetic/multiplier_lib.dart +++ b/lib/src/arithmetic/multiplier_lib.dart @@ -11,5 +11,6 @@ // export './addend_compressor.dart'; +export './multiplicand_selector.dart'; export './multiplier_encoder.dart'; export './partial_product_generator.dart'; diff --git a/lib/src/arithmetic/partial_product_generator.dart b/lib/src/arithmetic/partial_product_generator.dart index 2f4efeedc..bbe305b61 100644 --- a/lib/src/arithmetic/partial_product_generator.dart +++ b/lib/src/arithmetic/partial_product_generator.dart @@ -112,7 +112,7 @@ class PartialProductGeneratorNoSignExtension extends PartialProductGenerator { void signExtend() {} } -/// A Partial Product Generator using Brute Sign Extension +/// A Partial Product Generator using Compact Rectangular Extension class PartialProductGeneratorCompactRectSignExtension extends PartialProductGenerator { /// Construct a compact rect sign extending Partial Product Generator diff --git a/test/arithmetic/floating_point/floating_point_adder_test.dart b/test/arithmetic/floating_point/floating_point_adder_test.dart new file mode 100644 index 000000000..3a9f69076 --- /dev/null +++ b/test/arithmetic/floating_point/floating_point_adder_test.dart @@ -0,0 +1,319 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_test.dart +// Tests of Floating Point stuff +// +// 2024 April 1 +// Authors: +// Max Korbel +// Desmond A Kirkpatrick +// Desmond A Kirkpatrick