diff --git a/src/ast.rs b/src/ast.rs index a8afd8bd..2d6c0370 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -43,8 +43,22 @@ impl<'i> CoreParser<'i> { pub fn parse_numb_sym(&mut self) -> Result { self.consume("[")?; + // numeric casts + if let Some(cast) = match () { + _ if self.try_consume("u24") => Some(hvm::TY_U24), + _ if self.try_consume("i24") => Some(hvm::TY_I24), + _ if self.try_consume("f24") => Some(hvm::TY_F24), + _ => None + } { + // Casts can't be partially applied, so nothing should follow. + self.consume("]")?; + + return Ok(Numb(hvm::Numb::new_sym(cast).0)); + } + // Parses the symbol let op = hvm::Numb::new_sym(match () { + // numeric operations _ if self.try_consume("+") => hvm::OP_ADD, _ if self.try_consume("-") => hvm::OP_SUB, _ if self.try_consume(":-") => hvm::FP_SUB, @@ -224,6 +238,11 @@ impl Numb { let numb = hvm::Numb(self.0); match numb.get_typ() { hvm::TY_SYM => match numb.get_sym() as hvm::Tag { + // casts + hvm::TY_U24 => "[u24]".to_string(), + hvm::TY_I24 => "[i24]".to_string(), + hvm::TY_F24 => "[f24]".to_string(), + // operations hvm::OP_ADD => "[+]".to_string(), hvm::OP_SUB => "[-]".to_string(), hvm::FP_SUB => "[:-]".to_string(), diff --git a/src/hvm.c b/src/hvm.c index 28b7fafa..fa9caa28 100644 --- a/src/hvm.c +++ b/src/hvm.c @@ -21,6 +21,8 @@ typedef uint16_t u16; typedef int32_t i32; typedef uint32_t u32; typedef uint64_t u64; +typedef float f32; +typedef double f64; typedef _Atomic(u8) a8; typedef _Atomic(u16) a16; @@ -75,6 +77,10 @@ typedef u32 Numb; // Numb ::= 29-bit (rounded up to u32) #define SWIT 0x7 // Numbers +static const f32 U24_MAX = (f32) (1 << 24) - 1; +static const f32 U24_MIN = 0.0; +static const f32 I24_MAX = (f32) (1 << 23) - 1; +static const f32 I24_MIN = (f32) (i32) ((-1u) << 23); #define TY_SYM 0x00 #define TY_U24 0x01 #define TY_I24 0x02 @@ -278,10 +284,15 @@ static inline void swap(Port *a, Port *b) { Port x = *a; *a = *b; *b = x; } -u32 min(u32 a, u32 b) { +inline u32 min(u32 a, u32 b) { return (a < b) ? a : b; } +inline f32 clamp(f32 x, f32 min, f32 max) { + const f32 t = x < min ? min : x; + return (t > max) ? max : t; +} + // A simple spin-wait barrier using atomic operations a64 a_reached = 0; // number of threads that reached the current barrier a64 a_barrier = 0; // number of barriers passed during this program @@ -429,11 +440,63 @@ static inline Tag get_typ(Numb word) { return word & 0x1F; } +static inline bool is_num(Numb word) { + return get_typ(word) >= TY_U24 && get_typ(word) <= TY_F24; +} + +static inline bool is_cast(Numb word) { + return get_typ(word) == TY_SYM && get_sym(word) >= TY_U24 && get_sym(word) <= TY_F24; +} + // Partial application static inline Numb partial(Numb a, Numb b) { return (b & ~0x1F) | get_sym(a); } +// Cast a number to another type. +// The semantics are meant to spiritually resemble rust's numeric casts: +// - i24 <-> u24: is just reinterpretation of bits +// - f24 -> i24, +// f24 -> u24: casts to the "closest" integer representing this float, +// saturating if out of range and 0 if NaN +// - i24 -> f24, +// u24 -> f24: casts to the "closest" float representing this integer. +static inline Numb cast(Numb a, Numb b) { + if (get_sym(a) == TY_U24 && get_typ(b) == TY_U24) return b; + if (get_sym(a) == TY_U24 && get_typ(b) == TY_I24) { + // reinterpret bits + i32 val = get_i24(b); + return new_u24(*(u32*) &val); + } + if (get_sym(a) == TY_U24 && get_typ(b) == TY_F24) { + f32 val = get_f24(b); + if (isnan(val)) { + return new_u24(0); + } + return new_u24((u32) clamp(val, U24_MIN, U24_MAX)); + } + + if (get_sym(a) == TY_I24 && get_typ(b) == TY_U24) { + // reinterpret bits + u32 val = get_u24(b); + return new_i24(*(i32*) &val); + } + if (get_sym(a) == TY_I24 && get_typ(b) == TY_I24) return b; + if (get_sym(a) == TY_I24 && get_typ(b) == TY_F24) { + f32 val = get_f24(b); + if (isnan(val)) { + return new_i24(0); + } + return new_i24((i32) clamp(val, I24_MIN, I24_MAX)); + } + + if (get_sym(a) == TY_F24 && get_typ(b) == TY_U24) return new_f24((f32) get_u24(b)); + if (get_sym(a) == TY_F24 && get_typ(b) == TY_I24) return new_f24((f32) get_i24(b)); + if (get_sym(a) == TY_F24 && get_typ(b) == TY_F24) return b; + + return new_u24(0); +} + // Operate function static inline Numb operate(Numb a, Numb b) { Tag at = get_typ(a); @@ -441,6 +504,12 @@ static inline Numb operate(Numb a, Numb b) { if (at == TY_SYM && bt == TY_SYM) { return new_u24(0); } + if (is_cast(a) && is_num(b)) { + return cast(a, b); + } + if (is_cast(b) && is_num(a)) { + return cast(b, a); + } if (at == TY_SYM && bt != TY_SYM) { return partial(a, b); } @@ -1735,6 +1804,11 @@ void pretty_print_numb(Numb word) { switch (get_typ(word)) { case TY_SYM: { switch (get_sym(word)) { + // types + case TY_U24: printf("[u24]"); break; + case TY_I24: printf("[i24]"); break; + case TY_F24: printf("[f24]"); break; + // operations case OP_ADD: printf("[+]"); break; case OP_SUB: printf("[-]"); break; case FP_SUB: printf("[:-]"); break; @@ -1776,7 +1850,7 @@ void pretty_print_numb(Numb word) { } else if (isnan(get_f24(word))) { printf("+NaN"); } else { - printf("%f", get_f24(word)); + printf("%.7e", get_f24(word)); } break; } diff --git a/src/hvm.cu b/src/hvm.cu index ed8cc2ed..b824e7be 100644 --- a/src/hvm.cu +++ b/src/hvm.cu @@ -261,6 +261,11 @@ __global__ void print_heatmap(GNet* gnet, u32 turn); // Utils // ----- +__device__ __host__ f32 clamp(f32 x, f32 min, f32 max) { + const f32 t = x < min ? min : x; + return (t > max) ? max : t; +} + // TODO: write a time64() function that returns the time as fast as possible as a u64 static inline u64 time64() { struct timespec ts; @@ -540,6 +545,58 @@ __device__ __host__ inline Tag get_typ(Numb word) { return word & 0x1F; } +__device__ __host__ inline bool is_num(Numb word) { + return get_typ(word) >= TY_U24 && get_typ(word) <= TY_F24; +} + +__device__ __host__ inline bool is_cast(Numb word) { + return get_typ(word) == TY_SYM && get_sym(word) >= TY_U24 && get_sym(word) <= TY_F24; +} + +// Cast a number to another type. +// The semantics are meant to spiritually resemble rust's numeric casts: +// - i24 <-> u24: is just reinterpretation of bits +// - f24 -> i24, +// f24 -> u24: casts to the "closest" integer representing this float, +// saturating if out of range and 0 if NaN +// - i24 -> f24, +// u24 -> f24: casts to the "closest" float representing this integer. +__device__ __host__ inline Numb cast(Numb a, Numb b) { + if (get_sym(a) == TY_U24 && get_typ(b) == TY_U24) return b; + if (get_sym(a) == TY_U24 && get_typ(b) == TY_I24) { + // reinterpret bits + i32 val = get_i24(b); + return new_u24(*(u32*) &val); + } + if (get_sym(a) == TY_U24 && get_typ(b) == TY_F24) { + f32 val = get_f24(b); + if (isnan(val)) { + return new_u24(0); + } + return new_u24((u32) clamp(val, 0.0, 16777215)); + } + + if (get_sym(a) == TY_I24 && get_typ(b) == TY_U24) { + // reinterpret bits + u32 val = get_u24(b); + return new_i24(*(i32*) &val); + } + if (get_sym(a) == TY_I24 && get_typ(b) == TY_I24) return b; + if (get_sym(a) == TY_I24 && get_typ(b) == TY_F24) { + f32 val = get_f24(b); + if (isnan(val)) { + return new_i24(0); + } + return new_i24((i32) clamp(val, -8388608.0, 8388607.0)); + } + + if (get_sym(a) == TY_F24 && get_typ(b) == TY_U24) return new_f24((f32) get_u24(b)); + if (get_sym(a) == TY_F24 && get_typ(b) == TY_I24) return new_f24((f32) get_i24(b)); + if (get_sym(a) == TY_F24 && get_typ(b) == TY_F24) return b; + + return new_u24(0); +} + // Partial application __device__ __host__ inline Numb partial(Numb a, Numb b) { return (b & ~0x1F) | get_sym(a); @@ -552,6 +609,12 @@ __device__ __host__ inline Numb operate(Numb a, Numb b) { if (at == TY_SYM && bt == TY_SYM) { return new_u24(0); } + if (is_cast(a) && is_num(b)) { + return cast(a, b); + } + if (is_cast(b) && is_num(a)) { + return cast(b, a); + } if (at == TY_SYM && bt != TY_SYM) { return partial(a, b); } @@ -2185,6 +2248,11 @@ __device__ void pretty_print_numb(Numb word) { switch (get_typ(word)) { case TY_SYM: { switch (get_sym(word)) { + // types + case TY_U24: printf("[u24]"); break; + case TY_I24: printf("[i24]"); break; + case TY_F24: printf("[f24]"); break; + // operations case OP_ADD: printf("[+]"); break; case OP_SUB: printf("[-]"); break; case FP_SUB: printf("[:-]"); break; @@ -2226,7 +2294,7 @@ __device__ void pretty_print_numb(Numb word) { } else if (isnan(get_f24(word))) { printf("+NaN"); } else { - printf("%f", get_f24(word)); + printf("%.7e", get_f24(word)); } break; } diff --git a/src/hvm.rs b/src/hvm.rs index 5d8cd9f2..d7007aad 100644 --- a/src/hvm.rs +++ b/src/hvm.rs @@ -25,6 +25,10 @@ pub struct APair(pub AtomicU64); // Number pub struct Numb(pub Val); +const U24_MAX : u32 = (1 << 24) - 1; +const U24_MIN : u32 = 0; +const I24_MAX : i32 = (1 << 23) - 1; +const I24_MIN : i32 = (-1) << 23; // Tags pub const VAR : Tag = 0x0; // variable @@ -266,16 +270,47 @@ impl Numb { } // Gets the numeric type. - pub fn get_typ(&self) -> Tag { - return (self.0 & 0x1F) as Tag; + (self.0 & 0x1F) as Tag + } + + pub fn is_num(&self) -> bool { + self.get_typ() >= TY_U24 && self.get_typ() <= TY_F24 } - // Flip flag. + pub fn is_cast(&self) -> bool { + self.get_typ() == TY_SYM && self.get_sym() >= TY_U24 && self.get_sym() <= TY_F24 + } // Partial application. pub fn partial(a: Self, b: Self) -> Self { - return Numb((b.0 & !0x1F) | a.get_sym() as u32); + Numb((b.0 & !0x1F) | a.get_sym() as u32) + } + + // Cast a number to another type. + // The semantics are meant to spiritually resemble rust's numeric casts: + // - i24 <-> u24: is just reinterpretation of bits + // - f24 -> i24, + // f24 -> u24: casts to the "closest" integer representing this float, + // saturating if out of range and 0 if NaN + // - i24 -> f24, + // u24 -> f24: casts to the "closest" float representing this integer. + pub fn cast(a: Self, b: Self) -> Self { + match (a.get_sym(), b.get_typ()) { + (TY_U24, TY_U24) => b, + (TY_U24, TY_I24) => Self::new_u24(b.get_i24() as u32), + (TY_U24, TY_F24) => Self::new_u24((b.get_f24() as u32).clamp(U24_MIN, U24_MAX)), + + (TY_I24, TY_U24) => Self::new_i24(b.get_u24() as i32), + (TY_I24, TY_I24) => b, + (TY_I24, TY_F24) => Self::new_i24((b.get_f24() as i32).clamp(I24_MIN, I24_MAX)), + + (TY_F24, TY_U24) => Self::new_f24(b.get_u24() as f32), + (TY_F24, TY_I24) => Self::new_f24(b.get_i24() as f32), + (TY_F24, TY_F24) => b, + // invalid cast + (_, _) => Self::new_u24(0), + } } pub fn operate(a: Self, b: Self) -> Self { @@ -285,6 +320,12 @@ impl Numb { if at == TY_SYM && bt == TY_SYM { return Numb::new_u24(0); } + if a.is_cast() && b.is_num() { + return Numb::cast(a, b); + } + if b.is_cast() && a.is_num() { + return Numb::cast(b, a); + } if at == TY_SYM && bt != TY_SYM { return Numb::partial(a, b); } diff --git a/tests/programs/numeric-casts.hvm b/tests/programs/numeric-casts.hvm new file mode 100644 index 00000000..16df5e65 --- /dev/null +++ b/tests/programs/numeric-casts.hvm @@ -0,0 +1,45 @@ +@main = x & @tu0 ~ (* x) + +// casting to u24 +@tu0 = (* {n x}) & @tu1 ~ (* x) & 0 ~ $([u24] n) // 0 +@tu1 = (* {n x}) & @tu2 ~ (* x) & 1234 ~ $([u24] n) // 1234 +@tu2 = (* {n x}) & @tu3 ~ (* x) & +4321 ~ $([u24] n) // 4321 +@tu3 = (* {n x}) & @tu4 ~ (* x) & -5678 ~ $([u24] n) // 16771538 (reinterprets bits) +@tu4 = (* {n x}) & @tu5 ~ (* x) & 2.8 ~ $([u24] n) // 2 (rounds to zero) +@tu5 = (* {n x}) & @tu6 ~ (* x) & -12.5 ~ $([u24] n) // 0 (saturates) +@tu6 = (* {n x}) & @tu7 ~ (* x) & 16777216.0 ~ $([u24] n) // 16777215 (saturates) +@tu7 = (* {n x}) & @tu8 ~ (* x) & +inf ~ $([u24] n) // 16777215 (saturates) +@tu8 = (* {n x}) & @tu9 ~ (* x) & -inf ~ $([u24] n) // 0 (saturates) +@tu9 = (* {n x}) & @ti0 ~ (* x) & +NaN ~ $([u24] n) // 0 + +// casting to i24 +@ti0 = (* {n x}) & @ti1 ~ (* x) & 0 ~ $([i24] n) // +0 +@ti1 = (* {n x}) & @ti2 ~ (* x) & 1234 ~ $([i24] n) // +1234 +@ti2 = (* {n x}) & @ti3 ~ (* x) & +4321 ~ $([i24] n) // +4321 +@ti3 = (* {n x}) & @ti4 ~ (* x) & -5678 ~ $([i24] n) // -5678 +@ti4 = (* {n x}) & @ti5 ~ (* x) & 2.8 ~ $([i24] n) // +2 (rounds to zero) +@ti5 = (* {n x}) & @ti6 ~ (* x) & -12.7 ~ $([i24] n) // -12 (rounds to zero) +@ti6 = (* {n x}) & @ti7 ~ (* x) & 8388610.0 ~ $([i24] n) // +8388607 (saturates) +@ti7 = (* {n x}) & @ti8 ~ (* x) & -8388610.0 ~ $([i24] n) // -8388608 (saturates) +@ti8 = (* {n x}) & @ti9 ~ (* x) & +inf ~ $([i24] n) // +8388607 (saturates) +@ti9 = (* {n x}) & @ti10 ~ (* x) & -inf ~ $([i24] n) // -8388608 (saturates) +@ti10 = (* {n x}) & @tf0 ~ (* x) & +NaN ~ $([i24] n) // +0 + +// casting to f24 +@tf0 = (* {n x}) & @tf1 ~ (* x) & +NaN ~ $([f24] n) // +NaN +@tf1 = (* {n x}) & @tf2 ~ (* x) & +inf ~ $([f24] n) // +inf +@tf2 = (* {n x}) & @tf3 ~ (* x) & -inf ~ $([f24] n) // -inf +@tf3 = (* {n x}) & @tf4 ~ (* x) & 2.15 ~ $([f24] n) // 2.15 +@tf4 = (* {n x}) & @tf5 ~ (* x) & -2.15 ~ $([f24] n) // -2.15 +@tf5 = (* {n x}) & @tf6 ~ (* x) & 0.15 ~ $([f24] n) // 0.15 +@tf6 = (* {n x}) & @tf7 ~ (* x) & -1234 ~ $([f24] n) // -1234.0 +@tf7 = (* {n x}) & @tf8 ~ (* x) & +1234 ~ $([f24] n) // +1234.0 +@tf8 = (* {n x}) & @tf9 ~ (* x) & 123456 ~ $([f24] n) // 123456.0 +@tf9 = (* {n x}) & @tp0 ~ (* x) & 16775982 ~ $([f24] n) // 16775936.0 + +// printing +@tp0 = (* {n x}) & @tp1 ~ (* x) & n ~ [u24] // [u24] +@tp1 = (* {n x}) & @tp2 ~ (* x) & n ~ [i24] // [i24] +@tp2 = (* {n x}) & @t ~ (* x) & n ~ [f24] // [f24] + +@t = * diff --git a/tests/snapshots/run__file@empty.hvm.snap b/tests/snapshots/run__file@empty.hvm.snap index b5b4735e..953bfae9 100644 --- a/tests/snapshots/run__file@empty.hvm.snap +++ b/tests/snapshots/run__file@empty.hvm.snap @@ -4,6 +4,6 @@ expression: rust_output input_file: tests/programs/empty.hvm --- exit status: 101 -thread 'main' panicked at src/ast.rs:508:41: +thread 'main' panicked at src/ast.rs:527:41: missing `@main` definition note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace diff --git a/tests/snapshots/run__file@numeric-casts.hvm.snap b/tests/snapshots/run__file@numeric-casts.hvm.snap new file mode 100644 index 00000000..cc1dd365 --- /dev/null +++ b/tests/snapshots/run__file@numeric-casts.hvm.snap @@ -0,0 +1,6 @@ +--- +source: tests/run.rs +expression: rust_output +input_file: tests/programs/numeric-casts.hvm +--- +Result: {0 {1234 {4321 {16771538 {2 {0 {16777215 {16777215 {0 {0 {+0 {+1234 {+4321 {-5678 {+2 {-12 {+8388607 {-8388608 {+8388607 {-8388608 {+0 {+NaN {+inf {-inf {2.1500244 {-2.1500244 {0.15000153 {-1234.0 {1234.0 {123456.0 {16775936.0 {[u24] {[i24] {[f24] *}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}