Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Numeric Casting Operations #361

Merged
merged 18 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,22 @@ impl<'i> CoreParser<'i> {
pub fn parse_numb_sym(&mut self) -> Result<Numb, String> {
self.consume("[")?;

// numeric casts
if let Some(cast) = match () {
_ if self.try_consume("to_u24") => Some(hvm::TY_U24),
_ if self.try_consume("to_i24") => Some(hvm::TY_I24),
_ if self.try_consume("to_f24") => Some(hvm::TY_F24),
_ => None
} {
// Casts can't be partially applied, so nothing should follow.
self.consume("]")?;

return Ok(Numb(hvm::Numb::new_sym(cast).0));
}

// Parses the symbol
let op = hvm::Numb::new_sym(match () {
// numeric operations
_ if self.try_consume("+") => hvm::OP_ADD,
_ if self.try_consume("-") => hvm::OP_SUB,
_ if self.try_consume(":-") => hvm::FP_SUB,
Expand Down Expand Up @@ -224,6 +238,11 @@ impl Numb {
let numb = hvm::Numb(self.0);
match numb.get_typ() {
hvm::TY_SYM => match numb.get_sym() as hvm::Tag {
// casts
hvm::TY_U24 => "[to_u24]".to_string(),
hvm::TY_I24 => "[to_i24]".to_string(),
hvm::TY_F24 => "[to_f24]".to_string(),
// operations
hvm::OP_ADD => "[+]".to_string(),
hvm::OP_SUB => "[-]".to_string(),
hvm::FP_SUB => "[:-]".to_string(),
Expand Down Expand Up @@ -264,7 +283,7 @@ impl Numb {
} else if val.is_nan() {
format!("+NaN")
} else {
format!("{:?}", val)
enricozb marked this conversation as resolved.
Show resolved Hide resolved
format!("{:.7e}", val)
}
}
_ => {
Expand Down
69 changes: 67 additions & 2 deletions src/hvm.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ typedef uint16_t u16;
typedef int32_t i32;
typedef uint32_t u32;
typedef uint64_t u64;
typedef float f32;
typedef double f64;

typedef _Atomic(u8) a8;
typedef _Atomic(u16) a16;
Expand Down Expand Up @@ -278,10 +280,15 @@ static inline void swap(Port *a, Port *b) {
Port x = *a; *a = *b; *b = x;
}

u32 min(u32 a, u32 b) {
inline u32 min(u32 a, u32 b) {
return (a < b) ? a : b;
}

inline f32 clamp(f32 x, f32 min, f32 max) {
const f32 t = x < min ? min : x;
return (t > max) ? max : t;
}

// A simple spin-wait barrier using atomic operations
a64 a_reached = 0; // number of threads that reached the current barrier
a64 a_barrier = 0; // number of barriers passed during this program
Expand Down Expand Up @@ -429,18 +436,76 @@ static inline Tag get_typ(Numb word) {
return word & 0x1F;
}

static inline bool is_num(Numb word) {
return get_typ(word) >= TY_U24 && get_typ(word) <= TY_F24;
}

static inline bool is_cast(Numb word) {
return get_typ(word) == TY_SYM && get_sym(word) >= TY_U24 && get_sym(word) <= TY_F24;
}

// Partial application
static inline Numb partial(Numb a, Numb b) {
return (b & ~0x1F) | get_sym(a);
}

// Cast a number to another type.
// The semantics are meant to spiritually resemble rust's numeric casts:
// - i24 <-> u24: is just reinterpretation of bits
// - f24 -> i24,
// f24 -> u24: casts to the "closest" integer representing this float,
// saturating if out of range and 0 if NaN
// - i24 -> f24,
// u24 -> f24: casts to the "closest" float representing this integer.
static inline Numb cast(Numb a, Numb b) {
if (get_sym(a) == TY_U24 && get_typ(b) == TY_U24) return b;
if (get_sym(a) == TY_U24 && get_typ(b) == TY_I24) {
// reinterpret bits
i32 val = get_i24(b);
return new_u24(*(u32*) &val);
}
if (get_sym(a) == TY_U24 && get_typ(b) == TY_F24) {
f32 val = get_f24(b);
if (isnan(val)) {
return new_u24(0);
}
return new_u24((u32) clamp(val, 0.0, 16777215));
}

if (get_sym(a) == TY_I24 && get_typ(b) == TY_U24) {
// reinterpret bits
u32 val = get_u24(b);
return new_i24(*(i32*) &val);
}
if (get_sym(a) == TY_I24 && get_typ(b) == TY_I24) return b;
if (get_sym(a) == TY_I24 && get_typ(b) == TY_F24) {
f32 val = get_f24(b);
if (isnan(val)) {
return new_i24(0);
}
return new_i24((i32) clamp(val, -8388608.0, 8388607.0));
}

if (get_sym(a) == TY_F24 && get_typ(b) == TY_U24) return new_f24((f32) get_u24(b));
if (get_sym(a) == TY_F24 && get_typ(b) == TY_I24) return new_f24((f32) get_i24(b));
if (get_sym(a) == TY_F24 && get_typ(b) == TY_F24) return b;

return new_u24(0);
}

// Operate function
static inline Numb operate(Numb a, Numb b) {
Tag at = get_typ(a);
Tag bt = get_typ(b);
if (at == TY_SYM && bt == TY_SYM) {
return new_u24(0);
}
if (is_cast(a) && is_num(b)) {
return cast(a, b);
}
if (is_cast(b) && is_num(a)) {
return cast(b, a);
}
if (at == TY_SYM && bt != TY_SYM) {
return partial(a, b);
}
Expand Down Expand Up @@ -1776,7 +1841,7 @@ void pretty_print_numb(Numb word) {
} else if (isnan(get_f24(word))) {
printf("+NaN");
} else {
printf("%f", get_f24(word));
printf("%.7e", get_f24(word));
}
break;
}
Expand Down
65 changes: 64 additions & 1 deletion src/hvm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@ __global__ void print_heatmap(GNet* gnet, u32 turn);
// Utils
// -----

__device__ __host__ f32 clamp(f32 x, f32 min, f32 max) {
const f32 t = x < min ? min : x;
return (t > max) ? max : t;
}

// TODO: write a time64() function that returns the time as fast as possible as a u64
static inline u64 time64() {
struct timespec ts;
Expand Down Expand Up @@ -540,6 +545,58 @@ __device__ __host__ inline Tag get_typ(Numb word) {
return word & 0x1F;
}

__device__ __host__ inline bool is_num(Numb word) {
return get_typ(word) >= TY_U24 && get_typ(word) <= TY_F24;
}

__device__ __host__ inline bool is_cast(Numb word) {
return get_typ(word) == TY_SYM && get_sym(word) >= TY_U24 && get_sym(word) <= TY_F24;
}

// Cast a number to another type.
// The semantics are meant to spiritually resemble rust's numeric casts:
// - i24 <-> u24: is just reinterpretation of bits
// - f24 -> i24,
// f24 -> u24: casts to the "closest" integer representing this float,
// saturating if out of range and 0 if NaN
// - i24 -> f24,
// u24 -> f24: casts to the "closest" float representing this integer.
__device__ __host__ inline Numb cast(Numb a, Numb b) {
if (get_sym(a) == TY_U24 && get_typ(b) == TY_U24) return b;
if (get_sym(a) == TY_U24 && get_typ(b) == TY_I24) {
// reinterpret bits
i32 val = get_i24(b);
return new_u24(*(u32*) &val);
}
if (get_sym(a) == TY_U24 && get_typ(b) == TY_F24) {
f32 val = get_f24(b);
if (isnan(val)) {
return new_u24(0);
}
return new_u24((u32) clamp(val, 0.0, 16777215));
}

if (get_sym(a) == TY_I24 && get_typ(b) == TY_U24) {
// reinterpret bits
u32 val = get_u24(b);
return new_i24(*(i32*) &val);
}
if (get_sym(a) == TY_I24 && get_typ(b) == TY_I24) return b;
if (get_sym(a) == TY_I24 && get_typ(b) == TY_F24) {
f32 val = get_f24(b);
if (isnan(val)) {
return new_i24(0);
}
return new_i24((i32) clamp(val, -8388608.0, 8388607.0));
}

if (get_sym(a) == TY_F24 && get_typ(b) == TY_U24) return new_f24((f32) get_u24(b));
if (get_sym(a) == TY_F24 && get_typ(b) == TY_I24) return new_f24((f32) get_i24(b));
if (get_sym(a) == TY_F24 && get_typ(b) == TY_F24) return b;

return new_u24(0);
}

// Partial application
__device__ __host__ inline Numb partial(Numb a, Numb b) {
return (b & ~0x1F) | get_sym(a);
Expand All @@ -552,6 +609,12 @@ __device__ __host__ inline Numb operate(Numb a, Numb b) {
if (at == TY_SYM && bt == TY_SYM) {
return new_u24(0);
}
if (is_cast(a) && is_num(b)) {
return cast(a, b);
}
if (is_cast(b) && is_num(a)) {
return cast(b, a);
}
if (at == TY_SYM && bt != TY_SYM) {
return partial(a, b);
}
Expand Down Expand Up @@ -2226,7 +2289,7 @@ __device__ void pretty_print_numb(Numb word) {
} else if (isnan(get_f24(word))) {
printf("+NaN");
} else {
printf("%f", get_f24(word));
printf("%.7e", get_f24(word));
}
break;
}
Expand Down
45 changes: 41 additions & 4 deletions src/hvm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,16 +266,47 @@ impl Numb {
}

// Gets the numeric type.

pub fn get_typ(&self) -> Tag {
return (self.0 & 0x1F) as Tag;
(self.0 & 0x1F) as Tag
}

pub fn is_num(&self) -> bool {
self.get_typ() >= TY_U24 && self.get_typ() <= TY_F24
}

// Flip flag.
pub fn is_cast(&self) -> bool {
self.get_typ() == TY_SYM && self.get_sym() >= TY_U24 && self.get_sym() <= TY_F24
}

// Partial application.
pub fn partial(a: Self, b: Self) -> Self {
return Numb((b.0 & !0x1F) | a.get_sym() as u32);
Numb((b.0 & !0x1F) | a.get_sym() as u32)
}

// Cast a number to another type.
// The semantics are meant to spiritually resemble rust's numeric casts:
// - i24 <-> u24: is just reinterpretation of bits
// - f24 -> i24,
// f24 -> u24: casts to the "closest" integer representing this float,
// saturating if out of range and 0 if NaN
// - i24 -> f24,
// u24 -> f24: casts to the "closest" float representing this integer.
pub fn cast(a: Self, b: Self) -> Self {
match (a.get_sym(), b.get_typ()) {
(TY_U24, TY_U24) => b,
(TY_U24, TY_I24) => Self::new_u24(b.get_i24() as u32),
(TY_U24, TY_F24) => Self::new_u24(b.get_f24().clamp(0.0, 16777215.0) as u32),
enricozb marked this conversation as resolved.
Show resolved Hide resolved

(TY_I24, TY_U24) => Self::new_i24(b.get_u24() as i32),
(TY_I24, TY_I24) => b,
(TY_I24, TY_F24) => Self::new_i24(b.get_f24().clamp(-8388608.0, 8388607.0) as i32),
enricozb marked this conversation as resolved.
Show resolved Hide resolved

(TY_F24, TY_U24) => Self::new_f24(b.get_u24() as f32),
(TY_F24, TY_I24) => Self::new_f24(b.get_i24() as f32),
(TY_F24, TY_F24) => b,
// invalid cast
(_, _) => Self::new_u24(0),
}
}

pub fn operate(a: Self, b: Self) -> Self {
Expand All @@ -285,6 +316,12 @@ impl Numb {
if at == TY_SYM && bt == TY_SYM {
return Numb::new_u24(0);
}
if a.is_cast() && b.is_num() {
return Numb::cast(a, b);
}
if b.is_cast() && a.is_num() {
return Numb::cast(b, a);
}
if at == TY_SYM && bt != TY_SYM {
return Numb::partial(a, b);
}
Expand Down
40 changes: 40 additions & 0 deletions tests/programs/numeric-casts.hvm
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
@main = x & @tu0 ~ (* x)

// casting to u24
@tu0 = (* {n x}) & @tu1 ~ (* x) & 0 ~ $([to_u24] n) // 0
@tu1 = (* {n x}) & @tu2 ~ (* x) & 1234 ~ $([to_u24] n) // 1234
@tu2 = (* {n x}) & @tu3 ~ (* x) & +4321 ~ $([to_u24] n) // 4321
@tu3 = (* {n x}) & @tu4 ~ (* x) & -5678 ~ $([to_u24] n) // 16771538 (reinterprets bits)
@tu4 = (* {n x}) & @tu5 ~ (* x) & 2.8 ~ $([to_u24] n) // 2 (rounds to zero)
@tu5 = (* {n x}) & @tu6 ~ (* x) & -12.5 ~ $([to_u24] n) // 0 (saturates)
@tu6 = (* {n x}) & @tu7 ~ (* x) & 16777216.0 ~ $([to_u24] n) // 16777215 (saturates)
@tu7 = (* {n x}) & @tu8 ~ (* x) & +inf ~ $([to_u24] n) // 16777215 (saturates)
@tu8 = (* {n x}) & @tu9 ~ (* x) & -inf ~ $([to_u24] n) // 0 (saturates)
@tu9 = (* {n x}) & @ti0 ~ (* x) & +NaN ~ $([to_u24] n) // 0

// casting to i24
@ti0 = (* {n x}) & @ti1 ~ (* x) & 0 ~ $([to_i24] n) // +0
@ti1 = (* {n x}) & @ti2 ~ (* x) & 1234 ~ $([to_i24] n) // +1234
@ti2 = (* {n x}) & @ti3 ~ (* x) & +4321 ~ $([to_i24] n) // +4321
@ti3 = (* {n x}) & @ti4 ~ (* x) & -5678 ~ $([to_i24] n) // -5678
@ti4 = (* {n x}) & @ti5 ~ (* x) & 2.8 ~ $([to_i24] n) // +2 (rounds to zero)
@ti5 = (* {n x}) & @ti6 ~ (* x) & -12.7 ~ $([to_i24] n) // -12 (rounds to zero)
@ti6 = (* {n x}) & @ti7 ~ (* x) & 8388610.0 ~ $([to_i24] n) // +8388607 (saturates)
@ti7 = (* {n x}) & @ti8 ~ (* x) & -8388610.0 ~ $([to_i24] n) // -8388608 (saturates)
@ti8 = (* {n x}) & @ti9 ~ (* x) & +inf ~ $([to_i24] n) // +8388607 (saturates)
@ti9 = (* {n x}) & @ti10 ~ (* x) & -inf ~ $([to_i24] n) // -8388608 (saturates)
@ti10 = (* {n x}) & @tf0 ~ (* x) & +NaN ~ $([to_i24] n) // +0

// casting to f24
@tf0 = (* {n x}) & @tf1 ~ (* x) & +NaN ~ $([to_f24] n) // +NaN
@tf1 = (* {n x}) & @tf2 ~ (* x) & +inf ~ $([to_f24] n) // +inf
@tf2 = (* {n x}) & @tf3 ~ (* x) & -inf ~ $([to_f24] n) // -inf
@tf3 = (* {n x}) & @tf4 ~ (* x) & 2.15 ~ $([to_f24] n) // 2.15
@tf4 = (* {n x}) & @tf5 ~ (* x) & -2.15 ~ $([to_f24] n) // -2.15
@tf5 = (* {n x}) & @tf6 ~ (* x) & 0.15 ~ $([to_f24] n) // 0.15
@tf6 = (* {n x}) & @tf7 ~ (* x) & -1234 ~ $([to_f24] n) // -1234.0
@tf7 = (* {n x}) & @tf8 ~ (* x) & +1234 ~ $([to_f24] n) // +1234.0
@tf8 = (* {n x}) & @tf9 ~ (* x) & 123456 ~ $([to_f24] n) // 123456.0
@tf9 = (* {n x}) & @t ~ (* x) & 16775982 ~ $([to_f24] n) // 16775936.0

@t = *
2 changes: 1 addition & 1 deletion tests/snapshots/[email protected]
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ expression: rust_output
input_file: tests/programs/empty.hvm
---
exit status: 101
thread 'main' panicked at src/ast.rs:508:41:
thread 'main' panicked at src/ast.rs:527:41:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

off-topic but this diff is pretty annoying; we should figure out a better way to test this

missing `@main` definition
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
2 changes: 1 addition & 1 deletion tests/snapshots/[email protected]
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ source: tests/run.rs
expression: rust_output
input_file: tests/programs/f24.hvm
---
Result: {+inf {-inf {+NaN {2.5 {-1.5 {1.1499939 {0.25 {0.5 {0 {1 {1 {0 {0 {0 {0 {+NaN {+inf {-inf {1.019989 {0.1000061 {0.1000061 {-0.1000061 {-0.1000061 *}}}}}}}}}}}}}}}}}}}}}}}
Result: {+inf {-inf {+NaN {2.5000000e0 {-1.5000000e0 {1.1499939e0 {2.5000000e-1 {5.0000000e-1 {0 {1 {1 {0 {0 {0 {0 {+NaN {+inf {-inf {1.0199890e0 {1.0000610e-1 {1.0000610e-1 {-1.0000610e-1 {-1.0000610e-1 *}}}}}}}}}}}}}}}}}}}}}}}
6 changes: 6 additions & 0 deletions tests/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
source: tests/run.rs
expression: rust_output
input_file: tests/programs/numeric-casts.hvm
---
Result: {0 {1234 {4321 {16771538 {2 {0 {16777215 {16777215 {0 {0 {+0 {+1234 {+4321 {-5678 {+2 {-12 {+8388607 {-8388608 {+8388607 {-8388608 {+0 {+NaN {+inf {-inf {2.1500244e0 {-2.1500244e0 {1.5000153e-1 {-1.2340000e3 {1.2340000e3 {1.2345600e5 {1.6775936e7 *}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}
Loading