Skip to content

Commit

Permalink
add floating point support to cider (#2335)
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi authored Nov 6, 2024
1 parent 1e9e3b8 commit 42880f1
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 59 deletions.
24 changes: 24 additions & 0 deletions interp/src/flatten/flat_ir/cell_prototype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,30 @@ impl CellPrototype {
c_type: ConstantType::Primitive,
}
}
"std_float_const" => {
get_params![params;
value: "VALUE",
width: "WIDTH",
rep: "REP"
];

debug_assert_eq!(
rep, 0,
"Only supported floating point representation is IEEE."
);
debug_assert!(
width == 32 || width == 64,
"Only 32 and 64 bit floats are supported."
);

// we can treat floating point constants like any other constant since the
// frontend already converts the number to bits for us
Self::Constant {
value,
width: width.try_into().unwrap(),
c_type: ConstantType::Primitive,
}
}
n @ ("std_add" | "std_sadd") => {
get_params![params; width: "WIDTH"];

Expand Down
6 changes: 6 additions & 0 deletions interp/src/serialization/data_dump.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,18 @@ pub enum FormatInfo {
int_width: u32,
frac_width: u32,
},
IEEFloat {
signed: bool,
width: u32,
},
}

impl FormatInfo {
pub fn signed(&self) -> bool {
match self {
FormatInfo::Bitnum { signed, .. } => *signed,
FormatInfo::Fixed { signed, .. } => *signed,
FormatInfo::IEEFloat { signed, .. } => *signed,
}
}

Expand All @@ -73,6 +78,7 @@ impl FormatInfo {
frac_width,
..
} => *int_width + *frac_width,
FormatInfo::IEEFloat { width, .. } => *width,
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions interp/tests/runt.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,16 @@ fud2 --from calyx --to jq \
"""
timeout = 60

[[tests]]
name = "correctness ieee754-float"
paths = ["../../tests/correctness/ieee754-float/*.futil"]
cmd = """
fud2 --from calyx --to dat \
--through cider \
-s sim.data={}.data \
-s calyx.args="--log off" \
{} | jq --sort-keys
"""

[[tests]]
name = "correctness ref cells"
Expand Down
139 changes: 82 additions & 57 deletions tools/cider-data-converter/src/converter.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use super::json_data::*;
use interp::serialization::*;
use itertools::Itertools;
use num_bigint::{BigInt, BigUint, ToBigInt};
use num_rational::BigRational;
use num_traits::{sign::Signed, Num, ToPrimitive};
use serde_json::Number;
use std::{collections::HashMap, iter::repeat, str::FromStr};

use super::json_data::*;
use interp::serialization::*;

fn msb(width: u32) -> u8 {
let rem = width % 8;
1u8 << (if rem != 0 { rem - 1 } else { 7 }) // shift to the right by between 0 and 7
Expand Down Expand Up @@ -165,66 +164,75 @@ fn unroll_float(
val: f64,
format: &interp::serialization::FormatInfo,
round_float: bool,
) -> impl Iterator<Item = u8> {
if let &interp::serialization::FormatInfo::Fixed {
signed,
int_width,
frac_width,
} = format
{
let rational = float_to_rational(val);

let frac_part = rational.fract().abs();
let frac_log = log2_exact(&frac_part.denom().to_biguint().unwrap());

let number = if frac_log.is_none() && round_float {
let w = BigInt::from(1) << frac_width;
let new = (val * w.to_f64().unwrap()).round();
new.to_bigint().unwrap()
} else if frac_log.is_none() {
panic!("Number {val} cannot be represented as a fixed-point number. If you want to approximate the number, set the `round_float` flag to true.");
} else {
let int_part = rational.to_integer();

let frac_log = frac_log.unwrap_or_else(|| panic!("unable to round the given value to a value representable with {frac_width} fractional bits"));
if frac_log > frac_width {
panic!("cannot represent value with {frac_width} fractional bits, requires at least {frac_log} bits");
}
) -> Vec<u8> {
match *format {
interp::serialization::FormatInfo::Fixed {
signed,
int_width,
frac_width,
} =>
{
let rational = float_to_rational(val);

let frac_part = rational.fract().abs();
let frac_log = log2_exact(&frac_part.denom().to_biguint().unwrap());

let number = if frac_log.is_none() && round_float {
let w = BigInt::from(1) << frac_width;
let new = (val * w.to_f64().unwrap()).round();
new.to_bigint().unwrap()
} else if frac_log.is_none() {
panic!("Number {val} cannot be represented as a fixed-point number. If you want to approximate the number, set the `round_float` flag to true.");
} else {
let int_part = rational.to_integer();

let mut int_log =
log2_round_down(&int_part.abs().to_biguint().unwrap());
if (BigInt::from(1) << int_log) <= int_part.abs() {
int_log += 1;
}
if signed {
int_log += 1;
}
let frac_log = frac_log.unwrap_or_else(|| panic!("unable to round the given value to a value representable with {frac_width} fractional bits"));
if frac_log > frac_width {
panic!("cannot represent value with {frac_width} fractional bits, requires at least {frac_log} bits");
}

if int_log > int_width {
let signed_str = if signed { "signed " } else { "" };
let mut int_log =
log2_round_down(&int_part.abs().to_biguint().unwrap());
if (BigInt::from(1) << int_log) <= int_part.abs() {
int_log += 1;
}
if signed {
int_log += 1;
}

panic!("cannot represent {signed_str}value of {val} with {int_width} integer bits, requires at least {int_log} bits");
}
if int_log > int_width {
let signed_str = if signed { "signed " } else { "" };

rational.numer() << (frac_width - frac_log)
};
panic!("cannot represent {signed_str}value of {val} with {int_width} integer bits, requires at least {int_log} bits");
}

let bit_count = number.bits() + if signed { 1 } else { 0 };
rational.numer() << (frac_width - frac_log)
};

if bit_count > (frac_width + int_width) as u64 {
let difference = bit_count - frac_width as u64;
panic!("The approximation of the number {val} cannot be represented with {frac_width} fractional bits and {int_width} integer bits. Requires at least {difference} integer bits.");
}
let bit_count = number.bits() + if signed { 1 } else { 0 };

sign_extend_vec(
number.to_signed_bytes_le(),
frac_width + int_width,
signed,
)
.into_iter()
.take((frac_width + int_width).div_ceil(8) as usize)
} else {
panic!("Called unroll_float on a non-fixed point type");
if bit_count > (frac_width + int_width) as u64 {
let difference = bit_count - frac_width as u64;
panic!("The approximation of the number {val} cannot be represented with {frac_width} fractional bits and {int_width} integer bits. Requires at least {difference} integer bits.");
}

sign_extend_vec(
number.to_signed_bytes_le(),
frac_width + int_width,
signed,
)
.into_iter()
.take((frac_width + int_width).div_ceil(8) as usize)
.collect::<Vec<_>>()
}
interp::serialization::FormatInfo::IEEFloat { width, .. } => {
match width {
32 => Vec::from((val as f32).to_le_bytes().as_slice()),
64 => Vec::from(val.to_le_bytes().as_slice()),
_ => unreachable!("Unsupported width {width}. Only 32 and 64 bit floats are supported.")
}
}
_ => panic!("Called unroll_float on a non-fixed point type"),
}
}

Expand Down Expand Up @@ -281,6 +289,24 @@ fn format_data(declaration: &MemoryDeclaration, data: &[u8]) -> ParseVec {

Number::from_f64(float).unwrap()
}
interp::serialization::FormatInfo::IEEFloat {
width,
..
} => {
let value = match width {
32 => {
debug_assert_eq!(chunk.len(), 4);
format!("{}", f32::from_le_bytes(chunk.try_into().unwrap()))
}
64 => {
debug_assert_eq!(chunk.len(), 8);
format!("{}", f64::from_le_bytes(chunk.try_into().unwrap()))
}
_ => unreachable!("Unsupported width {width}. Only 32 and 64 bit floats are supported.")
};
// we need to inject the string directly in order to maintain the correct rounding
Number::from_string_unchecked(value)
}
}
});
// sanity check
Expand Down Expand Up @@ -392,7 +418,6 @@ mod tests {
};

let result = unroll_float(float, &format, true);
let result = result.collect_vec();
BigInt::from_signed_bytes_le(&result);
let parsed_res =
parse_bytes_fixed(&result, int_width, frac_width, signed);
Expand Down
16 changes: 14 additions & 2 deletions tools/cider-data-converter/src/json_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ use serde::{self, Deserialize, Serialize};
use serde_json::Number;
use thiserror::Error;

#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
#[derive(Debug, Serialize, Deserialize, Clone, Copy, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum NumericType {
Bitnum,
#[serde(alias = "fixed_point")]
Fixed,
#[serde(alias = "ieee754_float")]
IEEE754Float,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
Expand Down Expand Up @@ -46,6 +48,10 @@ impl FormatInfo {
|| self.width.is_some() && self.int_width.is_some()
}

pub fn is_floating_point(&self) -> bool {
self.numeric_type == NumericType::IEEE754Float
}

pub fn int_width(&self) -> Option<u32> {
if self.int_width.is_some() {
self.int_width
Expand Down Expand Up @@ -99,6 +105,12 @@ impl FormatInfo {
frac_width,
}
}
NumericType::IEEE754Float => {
interp::serialization::FormatInfo::IEEFloat {
signed: self.is_signed,
width: self.width.unwrap(),
}
}
}
}
}
Expand Down Expand Up @@ -227,7 +239,7 @@ impl ParseVec {
}

pub fn parse(&self, format: &FormatInfo) -> Result<DataVec, ParseError> {
if format.is_fixedpt() {
if format.is_fixedpt() || format.is_floating_point() {
match self {
ParseVec::D1(v) => {
let parsed: Vec<_> = v
Expand Down

0 comments on commit 42880f1

Please sign in to comment.