From 42880f16703fe5d85ad856dce9d8cbd4b83f40a9 Mon Sep 17 00:00:00 2001 From: Kevin Laeufer Date: Wed, 6 Nov 2024 09:44:05 -0500 Subject: [PATCH] add floating point support to cider (#2335) --- interp/src/flatten/flat_ir/cell_prototype.rs | 24 ++++ interp/src/serialization/data_dump.rs | 6 + interp/tests/runt.toml | 10 ++ tools/cider-data-converter/src/converter.rs | 139 +++++++++++-------- tools/cider-data-converter/src/json_data.rs | 16 ++- 5 files changed, 136 insertions(+), 59 deletions(-) diff --git a/interp/src/flatten/flat_ir/cell_prototype.rs b/interp/src/flatten/flat_ir/cell_prototype.rs index d7feea753..05acd4c3d 100644 --- a/interp/src/flatten/flat_ir/cell_prototype.rs +++ b/interp/src/flatten/flat_ir/cell_prototype.rs @@ -403,6 +403,30 @@ impl CellPrototype { c_type: ConstantType::Primitive, } } + "std_float_const" => { + get_params![params; + value: "VALUE", + width: "WIDTH", + rep: "REP" + ]; + + debug_assert_eq!( + rep, 0, + "Only supported floating point representation is IEEE." + ); + debug_assert!( + width == 32 || width == 64, + "Only 32 and 64 bit floats are supported." + ); + + // we can treat floating point constants like any other constant since the + // frontend already converts the number to bits for us + Self::Constant { + value, + width: width.try_into().unwrap(), + c_type: ConstantType::Primitive, + } + } n @ ("std_add" | "std_sadd") => { get_params![params; width: "WIDTH"]; diff --git a/interp/src/serialization/data_dump.rs b/interp/src/serialization/data_dump.rs index 5a95ffedf..0f386d75f 100644 --- a/interp/src/serialization/data_dump.rs +++ b/interp/src/serialization/data_dump.rs @@ -55,6 +55,10 @@ pub enum FormatInfo { int_width: u32, frac_width: u32, }, + IEEFloat { + signed: bool, + width: u32, + }, } impl FormatInfo { @@ -62,6 +66,7 @@ impl FormatInfo { match self { FormatInfo::Bitnum { signed, .. } => *signed, FormatInfo::Fixed { signed, .. } => *signed, + FormatInfo::IEEFloat { signed, .. } => *signed, } } @@ -73,6 +78,7 @@ impl FormatInfo { frac_width, .. } => *int_width + *frac_width, + FormatInfo::IEEFloat { width, .. } => *width, } } } diff --git a/interp/tests/runt.toml b/interp/tests/runt.toml index 298679eb6..f8aa1c397 100644 --- a/interp/tests/runt.toml +++ b/interp/tests/runt.toml @@ -153,6 +153,16 @@ fud2 --from calyx --to jq \ """ timeout = 60 +[[tests]] +name = "correctness ieee754-float" +paths = ["../../tests/correctness/ieee754-float/*.futil"] +cmd = """ +fud2 --from calyx --to dat \ + --through cider \ + -s sim.data={}.data \ + -s calyx.args="--log off" \ + {} | jq --sort-keys +""" [[tests]] name = "correctness ref cells" diff --git a/tools/cider-data-converter/src/converter.rs b/tools/cider-data-converter/src/converter.rs index 945b00fd0..5250c1825 100644 --- a/tools/cider-data-converter/src/converter.rs +++ b/tools/cider-data-converter/src/converter.rs @@ -1,3 +1,5 @@ +use super::json_data::*; +use interp::serialization::*; use itertools::Itertools; use num_bigint::{BigInt, BigUint, ToBigInt}; use num_rational::BigRational; @@ -5,9 +7,6 @@ use num_traits::{sign::Signed, Num, ToPrimitive}; use serde_json::Number; use std::{collections::HashMap, iter::repeat, str::FromStr}; -use super::json_data::*; -use interp::serialization::*; - fn msb(width: u32) -> u8 { let rem = width % 8; 1u8 << (if rem != 0 { rem - 1 } else { 7 }) // shift to the right by between 0 and 7 @@ -165,66 +164,75 @@ fn unroll_float( val: f64, format: &interp::serialization::FormatInfo, round_float: bool, -) -> impl Iterator { - if let &interp::serialization::FormatInfo::Fixed { - signed, - int_width, - frac_width, - } = format - { - let rational = float_to_rational(val); - - let frac_part = rational.fract().abs(); - let frac_log = log2_exact(&frac_part.denom().to_biguint().unwrap()); - - let number = if frac_log.is_none() && round_float { - let w = BigInt::from(1) << frac_width; - let new = (val * w.to_f64().unwrap()).round(); - new.to_bigint().unwrap() - } else if frac_log.is_none() { - panic!("Number {val} cannot be represented as a fixed-point number. If you want to approximate the number, set the `round_float` flag to true."); - } else { - let int_part = rational.to_integer(); - - let frac_log = frac_log.unwrap_or_else(|| panic!("unable to round the given value to a value representable with {frac_width} fractional bits")); - if frac_log > frac_width { - panic!("cannot represent value with {frac_width} fractional bits, requires at least {frac_log} bits"); - } +) -> Vec { + match *format { + interp::serialization::FormatInfo::Fixed { + signed, + int_width, + frac_width, + } => + { + let rational = float_to_rational(val); + + let frac_part = rational.fract().abs(); + let frac_log = log2_exact(&frac_part.denom().to_biguint().unwrap()); + + let number = if frac_log.is_none() && round_float { + let w = BigInt::from(1) << frac_width; + let new = (val * w.to_f64().unwrap()).round(); + new.to_bigint().unwrap() + } else if frac_log.is_none() { + panic!("Number {val} cannot be represented as a fixed-point number. If you want to approximate the number, set the `round_float` flag to true."); + } else { + let int_part = rational.to_integer(); - let mut int_log = - log2_round_down(&int_part.abs().to_biguint().unwrap()); - if (BigInt::from(1) << int_log) <= int_part.abs() { - int_log += 1; - } - if signed { - int_log += 1; - } + let frac_log = frac_log.unwrap_or_else(|| panic!("unable to round the given value to a value representable with {frac_width} fractional bits")); + if frac_log > frac_width { + panic!("cannot represent value with {frac_width} fractional bits, requires at least {frac_log} bits"); + } - if int_log > int_width { - let signed_str = if signed { "signed " } else { "" }; + let mut int_log = + log2_round_down(&int_part.abs().to_biguint().unwrap()); + if (BigInt::from(1) << int_log) <= int_part.abs() { + int_log += 1; + } + if signed { + int_log += 1; + } - panic!("cannot represent {signed_str}value of {val} with {int_width} integer bits, requires at least {int_log} bits"); - } + if int_log > int_width { + let signed_str = if signed { "signed " } else { "" }; - rational.numer() << (frac_width - frac_log) - }; + panic!("cannot represent {signed_str}value of {val} with {int_width} integer bits, requires at least {int_log} bits"); + } - let bit_count = number.bits() + if signed { 1 } else { 0 }; + rational.numer() << (frac_width - frac_log) + }; - if bit_count > (frac_width + int_width) as u64 { - let difference = bit_count - frac_width as u64; - panic!("The approximation of the number {val} cannot be represented with {frac_width} fractional bits and {int_width} integer bits. Requires at least {difference} integer bits."); - } + let bit_count = number.bits() + if signed { 1 } else { 0 }; - sign_extend_vec( - number.to_signed_bytes_le(), - frac_width + int_width, - signed, - ) - .into_iter() - .take((frac_width + int_width).div_ceil(8) as usize) - } else { - panic!("Called unroll_float on a non-fixed point type"); + if bit_count > (frac_width + int_width) as u64 { + let difference = bit_count - frac_width as u64; + panic!("The approximation of the number {val} cannot be represented with {frac_width} fractional bits and {int_width} integer bits. Requires at least {difference} integer bits."); + } + + sign_extend_vec( + number.to_signed_bytes_le(), + frac_width + int_width, + signed, + ) + .into_iter() + .take((frac_width + int_width).div_ceil(8) as usize) + .collect::>() + } + interp::serialization::FormatInfo::IEEFloat { width, .. } => { + match width { + 32 => Vec::from((val as f32).to_le_bytes().as_slice()), + 64 => Vec::from(val.to_le_bytes().as_slice()), + _ => unreachable!("Unsupported width {width}. Only 32 and 64 bit floats are supported.") + } + } + _ => panic!("Called unroll_float on a non-fixed point type"), } } @@ -281,6 +289,24 @@ fn format_data(declaration: &MemoryDeclaration, data: &[u8]) -> ParseVec { Number::from_f64(float).unwrap() } + interp::serialization::FormatInfo::IEEFloat { + width, + .. + } => { + let value = match width { + 32 => { + debug_assert_eq!(chunk.len(), 4); + format!("{}", f32::from_le_bytes(chunk.try_into().unwrap())) + } + 64 => { + debug_assert_eq!(chunk.len(), 8); + format!("{}", f64::from_le_bytes(chunk.try_into().unwrap())) + } + _ => unreachable!("Unsupported width {width}. Only 32 and 64 bit floats are supported.") + }; + // we need to inject the string directly in order to maintain the correct rounding + Number::from_string_unchecked(value) + } } }); // sanity check @@ -392,7 +418,6 @@ mod tests { }; let result = unroll_float(float, &format, true); - let result = result.collect_vec(); BigInt::from_signed_bytes_le(&result); let parsed_res = parse_bytes_fixed(&result, int_width, frac_width, signed); diff --git a/tools/cider-data-converter/src/json_data.rs b/tools/cider-data-converter/src/json_data.rs index 9ef5fc566..85eeccefd 100644 --- a/tools/cider-data-converter/src/json_data.rs +++ b/tools/cider-data-converter/src/json_data.rs @@ -6,12 +6,14 @@ use serde::{self, Deserialize, Serialize}; use serde_json::Number; use thiserror::Error; -#[derive(Debug, Serialize, Deserialize, Clone, Copy)] +#[derive(Debug, Serialize, Deserialize, Clone, Copy, Eq, PartialEq)] #[serde(rename_all = "lowercase")] pub enum NumericType { Bitnum, #[serde(alias = "fixed_point")] Fixed, + #[serde(alias = "ieee754_float")] + IEEE754Float, } #[derive(Debug, Serialize, Deserialize, Clone)] @@ -46,6 +48,10 @@ impl FormatInfo { || self.width.is_some() && self.int_width.is_some() } + pub fn is_floating_point(&self) -> bool { + self.numeric_type == NumericType::IEEE754Float + } + pub fn int_width(&self) -> Option { if self.int_width.is_some() { self.int_width @@ -99,6 +105,12 @@ impl FormatInfo { frac_width, } } + NumericType::IEEE754Float => { + interp::serialization::FormatInfo::IEEFloat { + signed: self.is_signed, + width: self.width.unwrap(), + } + } } } } @@ -227,7 +239,7 @@ impl ParseVec { } pub fn parse(&self, format: &FormatInfo) -> Result { - if format.is_fixedpt() { + if format.is_fixedpt() || format.is_floating_point() { match self { ParseVec::D1(v) => { let parsed: Vec<_> = v