diff --git a/crates/proof-of-sql/benches/bench_append_rows.rs b/crates/proof-of-sql/benches/bench_append_rows.rs index 023c04cff..81855cd05 100644 --- a/crates/proof-of-sql/benches/bench_append_rows.rs +++ b/crates/proof-of-sql/benches/bench_append_rows.rs @@ -14,7 +14,7 @@ use proof_of_sql::{ database::{ owned_table_utility::{ bigint, boolean, decimal75, int, int128, owned_table, scalar, smallint, - timestamptz, varchar, + timestamptz, tinyint, varchar, }, OwnedTable, }, @@ -86,6 +86,7 @@ pub fn generate_random_owned_table( "scalar", "varchar", "decimal75", + "tinyint", "smallint", "int", "timestamptz", @@ -118,6 +119,7 @@ pub fn generate_random_owned_table( 2, vec![generate_random_u64_array(); num_rows], )), + "tinyint" => columns.push(tinyint(identifier.deref(), vec![rng.gen::(); num_rows])), "smallint" => columns.push(smallint( identifier.deref(), vec![rng.gen::(); num_rows], diff --git a/crates/proof-of-sql/src/base/commitment/column_bounds.rs b/crates/proof-of-sql/src/base/commitment/column_bounds.rs index c4fcdbabe..ce6328610 100644 --- a/crates/proof-of-sql/src/base/commitment/column_bounds.rs +++ b/crates/proof-of-sql/src/base/commitment/column_bounds.rs @@ -203,6 +203,8 @@ pub struct ColumnBoundsMismatch { pub enum ColumnBounds { /// Column does not have order. NoOrder, + /// The bounds of a `TinyInt` column. + TinyInt(Bounds), /// The bounds of a `SmallInt` column. SmallInt(Bounds), /// The bounds of an Int column. @@ -222,6 +224,7 @@ impl ColumnBounds { #[must_use] pub fn from_column(column: &CommittableColumn) -> ColumnBounds { match column { + CommittableColumn::TinyInt(ints) => ColumnBounds::TinyInt(Bounds::from_iter(*ints)), CommittableColumn::SmallInt(ints) => ColumnBounds::SmallInt(Bounds::from_iter(*ints)), CommittableColumn::Int(ints) => ColumnBounds::Int(Bounds::from_iter(*ints)), CommittableColumn::BigInt(ints) => ColumnBounds::BigInt(Bounds::from_iter(*ints)), @@ -243,6 +246,9 @@ impl ColumnBounds { pub fn try_union(self, other: Self) -> Result { match (self, other) { (ColumnBounds::NoOrder, ColumnBounds::NoOrder) => Ok(ColumnBounds::NoOrder), + (ColumnBounds::TinyInt(bounds_a), ColumnBounds::TinyInt(bounds_b)) => { + Ok(ColumnBounds::TinyInt(bounds_a.union(bounds_b))) + } (ColumnBounds::SmallInt(bounds_a), ColumnBounds::SmallInt(bounds_b)) => { Ok(ColumnBounds::SmallInt(bounds_a.union(bounds_b))) } @@ -272,6 +278,9 @@ impl ColumnBounds { pub fn try_difference(self, other: Self) -> Result { match (self, other) { (ColumnBounds::NoOrder, ColumnBounds::NoOrder) => Ok(self), + (ColumnBounds::TinyInt(bounds_a), ColumnBounds::TinyInt(bounds_b)) => { + Ok(ColumnBounds::TinyInt(bounds_a.difference(bounds_b))) + } (ColumnBounds::SmallInt(bounds_a), ColumnBounds::SmallInt(bounds_b)) => { Ok(ColumnBounds::SmallInt(bounds_a.difference(bounds_b))) } @@ -497,6 +506,14 @@ mod tests { let varchar_column_bounds = ColumnBounds::from_column(&committable_varchar_column); assert_eq!(varchar_column_bounds, ColumnBounds::NoOrder); + let tinyint_column = OwnedColumn::::TinyInt([1, 2, 3, 1, 0].to_vec()); + let committable_tinyint_column = CommittableColumn::from(&tinyint_column); + let tinyint_column_bounds = ColumnBounds::from_column(&committable_tinyint_column); + assert_eq!( + tinyint_column_bounds, + ColumnBounds::TinyInt(Bounds::Sharp(BoundsInner { min: 0, max: 3 })) + ); + let smallint_column = OwnedColumn::::SmallInt([1, 2, 3, 1, 0].to_vec()); let committable_smallint_column = CommittableColumn::from(&smallint_column); let smallint_column_bounds = ColumnBounds::from_column(&committable_smallint_column); @@ -560,6 +577,13 @@ mod tests { let no_order = ColumnBounds::NoOrder; assert_eq!(no_order.try_union(no_order).unwrap(), no_order); + let tinyint_a = ColumnBounds::TinyInt(Bounds::Sharp(BoundsInner { min: 1, max: 3 })); + let tinyint_b = ColumnBounds::TinyInt(Bounds::Sharp(BoundsInner { min: 4, max: 6 })); + assert_eq!( + tinyint_a.try_union(tinyint_b).unwrap(), + ColumnBounds::TinyInt(Bounds::Sharp(BoundsInner { min: 1, max: 6 })) + ); + let smallint_a = ColumnBounds::SmallInt(Bounds::Sharp(BoundsInner { min: 1, max: 3 })); let smallint_b = ColumnBounds::SmallInt(Bounds::Sharp(BoundsInner { min: 4, max: 6 })); assert_eq!( @@ -607,6 +631,7 @@ mod tests { #[test] fn we_cannot_union_mismatched_column_bounds() { let no_order = ColumnBounds::NoOrder; + let tinyint = ColumnBounds::TinyInt(Bounds::Sharp(BoundsInner { min: -3, max: 3 })); let smallint = ColumnBounds::SmallInt(Bounds::Sharp(BoundsInner { min: -5, max: 5 })); let int = ColumnBounds::Int(Bounds::Sharp(BoundsInner { min: -10, max: 10 })); let bigint = ColumnBounds::BigInt(Bounds::Sharp(BoundsInner { min: 1, max: 3 })); @@ -615,6 +640,7 @@ mod tests { let bounds = [ (no_order, "NoOrder"), + (tinyint, "TinyInt"), (smallint, "SmallInt"), (int, "Int"), (bigint, "BigInt"), @@ -639,6 +665,10 @@ mod tests { let no_order = ColumnBounds::NoOrder; assert_eq!(no_order.try_difference(no_order).unwrap(), no_order); + let tinyint_a = ColumnBounds::TinyInt(Bounds::Sharp(BoundsInner { min: 1, max: 3 })); + let tinyint_b = ColumnBounds::TinyInt(Bounds::Empty); + assert_eq!(tinyint_a.try_difference(tinyint_b).unwrap(), tinyint_a); + let smallint_a = ColumnBounds::SmallInt(Bounds::Sharp(BoundsInner { min: 1, max: 3 })); let smallint_b = ColumnBounds::SmallInt(Bounds::Empty); assert_eq!(smallint_a.try_difference(smallint_b).unwrap(), smallint_a); @@ -672,6 +702,7 @@ mod tests { let bigint = ColumnBounds::BigInt(Bounds::Sharp(BoundsInner { min: 1, max: 3 })); let int128 = ColumnBounds::Int128(Bounds::Sharp(BoundsInner { min: 4, max: 6 })); let timestamp = ColumnBounds::TimestampTZ(Bounds::Sharp(BoundsInner { min: 4, max: 6 })); + let tinyint = ColumnBounds::TinyInt(Bounds::Sharp(BoundsInner { min: 1, max: 3 })); let smallint = ColumnBounds::SmallInt(Bounds::Sharp(BoundsInner { min: 1, max: 3 })); assert!(no_order.try_difference(bigint).is_err()); @@ -683,6 +714,9 @@ mod tests { assert!(bigint.try_difference(int128).is_err()); assert!(int128.try_difference(bigint).is_err()); + assert!(tinyint.try_difference(timestamp).is_err()); + assert!(timestamp.try_difference(tinyint).is_err()); + assert!(smallint.try_difference(timestamp).is_err()); assert!(timestamp.try_difference(smallint).is_err()); } diff --git a/crates/proof-of-sql/src/base/commitment/column_commitment_metadata.rs b/crates/proof-of-sql/src/base/commitment/column_commitment_metadata.rs index 5affbc78a..263e76c68 100644 --- a/crates/proof-of-sql/src/base/commitment/column_commitment_metadata.rs +++ b/crates/proof-of-sql/src/base/commitment/column_commitment_metadata.rs @@ -44,7 +44,8 @@ impl ColumnCommitmentMetadata { bounds: ColumnBounds, ) -> Result { match (column_type, bounds) { - (ColumnType::SmallInt, ColumnBounds::SmallInt(_)) + (ColumnType::TinyInt, ColumnBounds::TinyInt(_)) + | (ColumnType::SmallInt, ColumnBounds::SmallInt(_)) | (ColumnType::Int, ColumnBounds::Int(_)) | (ColumnType::BigInt, ColumnBounds::BigInt(_)) | (ColumnType::Int128, ColumnBounds::Int128(_)) @@ -189,6 +190,18 @@ mod tests { #[test] fn we_can_construct_metadata() { + assert_eq!( + ColumnCommitmentMetadata::try_new( + ColumnType::TinyInt, + ColumnBounds::TinyInt(Bounds::Empty) + ) + .unwrap(), + ColumnCommitmentMetadata { + column_type: ColumnType::TinyInt, + bounds: ColumnBounds::TinyInt(Bounds::Empty) + } + ); + assert_eq!( ColumnCommitmentMetadata::try_new( ColumnType::SmallInt, @@ -436,6 +449,17 @@ mod tests { panic!("Bounds constructed from nonempty BigInt column should be ColumnBounds::Int(Bounds::Sharp(_))"); } + let tinyint_column = OwnedColumn::::TinyInt([1, 2, 3, 1, 0].to_vec()); + let committable_tinyint_column = CommittableColumn::from(&tinyint_column); + let tinyint_metadata = ColumnCommitmentMetadata::from_column(&committable_tinyint_column); + assert_eq!(tinyint_metadata.column_type(), &ColumnType::TinyInt); + if let ColumnBounds::TinyInt(Bounds::Sharp(bounds)) = tinyint_metadata.bounds() { + assert_eq!(bounds.min(), &0); + assert_eq!(bounds.max(), &3); + } else { + panic!("Bounds constructed from nonempty BigInt column should be ColumnBounds::TinyInt(Bounds::Sharp(_))"); + } + let smallint_column = OwnedColumn::::SmallInt([1, 2, 3, 1, 0].to_vec()); let committable_smallint_column = CommittableColumn::from(&smallint_column); let smallint_metadata = ColumnCommitmentMetadata::from_column(&committable_smallint_column); @@ -504,6 +528,18 @@ mod tests { ); // Ordered case + let ints = [1, 2, 3, 1, 0]; + let tinyint_column_a = CommittableColumn::TinyInt(&ints[..2]); + let tinyint_metadata_a = ColumnCommitmentMetadata::from_column(&tinyint_column_a); + let tinyint_column_b = CommittableColumn::TinyInt(&ints[2..]); + let tinyint_metadata_b = ColumnCommitmentMetadata::from_column(&tinyint_column_b); + let tinyint_column_c = CommittableColumn::TinyInt(&ints); + let tinyint_metadata_c = ColumnCommitmentMetadata::from_column(&tinyint_column_c); + assert_eq!( + tinyint_metadata_a.try_union(tinyint_metadata_b).unwrap(), + tinyint_metadata_c + ); + let ints = [1, 2, 3, 1, 0]; let smallint_column_a = CommittableColumn::SmallInt(&ints[..2]); let smallint_metadata_a = ColumnCommitmentMetadata::from_column(&smallint_column_a); @@ -650,6 +686,43 @@ mod tests { ); } + #[test] + fn we_can_difference_tinyint_matching_metadata() { + // Ordered case + let tinyints = [1, 2, 3, 1, 0]; + let tinyint_column_a = CommittableColumn::TinyInt(&tinyints[..2]); + let tinyint_metadata_a = ColumnCommitmentMetadata::from_column(&tinyint_column_a); + let tinyint_column_b = CommittableColumn::TinyInt(&tinyints); + let tinyint_metadata_b = ColumnCommitmentMetadata::from_column(&tinyint_column_b); + + let b_difference_a = tinyint_metadata_b + .try_difference(tinyint_metadata_a) + .unwrap(); + assert_eq!(b_difference_a.column_type, ColumnType::TinyInt); + if let ColumnBounds::TinyInt(Bounds::Bounded(bounds)) = b_difference_a.bounds() { + assert_eq!(bounds.min(), &0); + assert_eq!(bounds.max(), &3); + } else { + panic!("difference of overlapping bounds should be Bounded"); + } + + let tinyint_column_empty = CommittableColumn::TinyInt(&[]); + let tinyint_metadata_empty = ColumnCommitmentMetadata::from_column(&tinyint_column_empty); + + assert_eq!( + tinyint_metadata_b + .try_difference(tinyint_metadata_empty) + .unwrap(), + tinyint_metadata_b + ); + assert_eq!( + tinyint_metadata_empty + .try_difference(tinyint_metadata_b) + .unwrap(), + tinyint_metadata_empty + ); + } + #[test] fn we_can_difference_smallint_matching_metadata() { // Ordered case @@ -732,6 +805,10 @@ mod tests { column_type: ColumnType::Scalar, bounds: ColumnBounds::NoOrder, }; + let tinyint_metadata = ColumnCommitmentMetadata { + column_type: ColumnType::TinyInt, + bounds: ColumnBounds::TinyInt(Bounds::Empty), + }; let smallint_metadata = ColumnCommitmentMetadata { column_type: ColumnType::SmallInt, bounds: ColumnBounds::SmallInt(Bounds::Empty), @@ -753,6 +830,18 @@ mod tests { bounds: ColumnBounds::Int128(Bounds::Empty), }; + assert!(tinyint_metadata.try_union(scalar_metadata).is_err()); + assert!(scalar_metadata.try_union(tinyint_metadata).is_err()); + + assert!(tinyint_metadata.try_union(decimal75_metadata).is_err()); + assert!(decimal75_metadata.try_union(tinyint_metadata).is_err()); + + assert!(tinyint_metadata.try_union(varchar_metadata).is_err()); + assert!(varchar_metadata.try_union(tinyint_metadata).is_err()); + + assert!(tinyint_metadata.try_union(boolean_metadata).is_err()); + assert!(boolean_metadata.try_union(tinyint_metadata).is_err()); + assert!(smallint_metadata.try_union(scalar_metadata).is_err()); assert!(scalar_metadata.try_union(smallint_metadata).is_err()); diff --git a/crates/proof-of-sql/src/base/commitment/committable_column.rs b/crates/proof-of-sql/src/base/commitment/committable_column.rs index 79a087789..58f937a09 100644 --- a/crates/proof-of-sql/src/base/commitment/committable_column.rs +++ b/crates/proof-of-sql/src/base/commitment/committable_column.rs @@ -25,6 +25,8 @@ use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; pub enum CommittableColumn<'a> { /// Borrowed Bool column, mapped to `bool`. Boolean(&'a [bool]), + /// Borrowed `TinyInt` column, mapped to `i8`. + TinyInt(&'a [i8]), /// Borrowed `SmallInt` column, mapped to `i16`. SmallInt(&'a [i16]), /// Borrowed `Int` column, mapped to `i32`. @@ -51,6 +53,7 @@ impl<'a> CommittableColumn<'a> { #[must_use] pub fn len(&self) -> usize { match self { + CommittableColumn::TinyInt(col) => col.len(), CommittableColumn::SmallInt(col) => col.len(), CommittableColumn::Int(col) => col.len(), CommittableColumn::BigInt(col) | CommittableColumn::TimestampTZ(_, _, col) => col.len(), @@ -79,6 +82,7 @@ impl<'a> CommittableColumn<'a> { impl<'a> From<&CommittableColumn<'a>> for ColumnType { fn from(value: &CommittableColumn<'a>) -> Self { match value { + CommittableColumn::TinyInt(_) => ColumnType::TinyInt, CommittableColumn::SmallInt(_) => ColumnType::SmallInt, CommittableColumn::Int(_) => ColumnType::Int, CommittableColumn::BigInt(_) => ColumnType::BigInt, @@ -101,6 +105,7 @@ impl<'a, S: Scalar> From<&Column<'a, S>> for CommittableColumn<'a> { fn from(value: &Column<'a, S>) -> Self { match value { Column::Boolean(bools) => CommittableColumn::Boolean(bools), + Column::TinyInt(ints) => CommittableColumn::TinyInt(ints), Column::SmallInt(ints) => CommittableColumn::SmallInt(ints), Column::Int(ints) => CommittableColumn::Int(ints), Column::BigInt(ints) => CommittableColumn::BigInt(ints), @@ -129,6 +134,7 @@ impl<'a, S: Scalar> From<&'a OwnedColumn> for CommittableColumn<'a> { fn from(value: &'a OwnedColumn) -> Self { match value { OwnedColumn::Boolean(bools) => CommittableColumn::Boolean(bools), + OwnedColumn::TinyInt(ints) => (ints as &[_]).into(), OwnedColumn::SmallInt(ints) => (ints as &[_]).into(), OwnedColumn::Int(ints) => (ints as &[_]).into(), OwnedColumn::BigInt(ints) => (ints as &[_]).into(), @@ -162,6 +168,11 @@ impl<'a> From<&'a [u8]> for CommittableColumn<'a> { CommittableColumn::RangeCheckWord(value) } } +impl<'a> From<&'a [i8]> for CommittableColumn<'a> { + fn from(value: &'a [i8]) -> Self { + CommittableColumn::TinyInt(value) + } +} impl<'a> From<&'a [i16]> for CommittableColumn<'a> { fn from(value: &'a [i16]) -> Self { CommittableColumn::SmallInt(value) @@ -199,6 +210,7 @@ impl<'a> From<&'a [bool]> for CommittableColumn<'a> { impl<'a, 'b> From<&'a CommittableColumn<'b>> for Sequence<'a> { fn from(value: &'a CommittableColumn<'b>) -> Self { match value { + CommittableColumn::TinyInt(ints) => Sequence::from(*ints), CommittableColumn::SmallInt(ints) => Sequence::from(*ints), CommittableColumn::Int(ints) => Sequence::from(*ints), CommittableColumn::BigInt(ints) => Sequence::from(*ints), @@ -267,6 +279,26 @@ mod tests { ); } + #[test] + fn we_can_get_type_and_length_of_tinyint_column() { + // empty case + let tinyint_committable_column = CommittableColumn::TinyInt(&[]); + assert_eq!(tinyint_committable_column.len(), 0); + assert!(tinyint_committable_column.is_empty()); + assert_eq!( + tinyint_committable_column.column_type(), + ColumnType::TinyInt + ); + + let tinyint_committable_column = CommittableColumn::TinyInt(&[12, 34, 56]); + assert_eq!(tinyint_committable_column.len(), 3); + assert!(!tinyint_committable_column.is_empty()); + assert_eq!( + tinyint_committable_column.column_type(), + ColumnType::TinyInt + ); + } + #[test] fn we_can_get_type_and_length_of_smallint_column() { // empty case @@ -461,6 +493,21 @@ mod tests { ); } + #[test] + fn we_can_convert_from_borrowing_tinyint_column() { + // empty case + let from_borrowed_column = + CommittableColumn::from(&Column::::TinyInt(&[])); + assert_eq!(from_borrowed_column, CommittableColumn::TinyInt(&[])); + + let from_borrowed_column = + CommittableColumn::from(&Column::::TinyInt(&[12, 34, 56])); + assert_eq!( + from_borrowed_column, + CommittableColumn::TinyInt(&[12, 34, 56]) + ); + } + #[test] fn we_can_convert_from_borrowing_smallint_column() { // empty case @@ -585,6 +632,18 @@ mod tests { assert_eq!(from_owned_column, CommittableColumn::BigInt(&[12, 34, 56])); } + #[test] + fn we_can_convert_from_owned_tinyint_column() { + // empty case + let owned_column = OwnedColumn::::TinyInt(Vec::new()); + let from_owned_column = CommittableColumn::from(&owned_column); + assert_eq!(from_owned_column, CommittableColumn::TinyInt(&[])); + + let owned_column = OwnedColumn::::TinyInt(vec![12, 34, 56]); + let from_owned_column = CommittableColumn::from(&owned_column); + assert_eq!(from_owned_column, CommittableColumn::TinyInt(&[12, 34, 56])); + } + #[test] fn we_can_convert_from_owned_smallint_column() { // empty case @@ -750,6 +809,30 @@ mod tests { assert_eq!(commitment_buffer[0], commitment_buffer[1]); } + #[test] + fn we_can_commit_to_tinyint_column_through_committable_column() { + // empty case + let committable_column = CommittableColumn::TinyInt(&[]); + let sequence = Sequence::from(&committable_column); + let mut commitment_buffer = [CompressedRistretto::default()]; + compute_curve25519_commitments(&mut commitment_buffer, &[sequence], 0); + assert_eq!(commitment_buffer[0], CompressedRistretto::default()); + + // nonempty case + let values = [12, 34, 56]; + let committable_column = CommittableColumn::TinyInt(&values); + + let sequence_actual = Sequence::from(&committable_column); + let sequence_expected = Sequence::from(values.as_slice()); + let mut commitment_buffer = [CompressedRistretto::default(); 2]; + compute_curve25519_commitments( + &mut commitment_buffer, + &[sequence_actual, sequence_expected], + 0, + ); + assert_eq!(commitment_buffer[0], commitment_buffer[1]); + } + #[test] fn we_can_commit_to_smallint_column_through_committable_column() { // empty case diff --git a/crates/proof-of-sql/src/base/commitment/naive_commitment.rs b/crates/proof-of-sql/src/base/commitment/naive_commitment.rs index a2fc78480..68bbc42f1 100644 --- a/crates/proof-of-sql/src/base/commitment/naive_commitment.rs +++ b/crates/proof-of-sql/src/base/commitment/naive_commitment.rs @@ -121,6 +121,9 @@ impl Commitment for NaiveCommitment { CommittableColumn::Boolean(bool_vec) => { bool_vec.iter().map(|b| b.into()).collect() } + CommittableColumn::TinyInt(tiny_int_vec) => { + tiny_int_vec.iter().map(|b| b.into()).collect() + } CommittableColumn::SmallInt(small_int_vec) => { small_int_vec.iter().map(|b| b.into()).collect() } diff --git a/crates/proof-of-sql/src/base/database/arrow_array_to_column_conversion.rs b/crates/proof-of-sql/src/base/database/arrow_array_to_column_conversion.rs index a3e5e1fd2..f7ab9ac0a 100644 --- a/crates/proof-of-sql/src/base/database/arrow_array_to_column_conversion.rs +++ b/crates/proof-of-sql/src/base/database/arrow_array_to_column_conversion.rs @@ -3,7 +3,7 @@ use crate::base::{database::Column, math::decimal::Precision, scalar::Scalar}; use arrow::{ array::{ Array, ArrayRef, BooleanArray, Decimal128Array, Decimal256Array, Int16Array, Int32Array, - Int64Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, + Int64Array, Int8Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }, datatypes::{i256, DataType, TimeUnit as ArrowTimeUnit}, @@ -232,6 +232,15 @@ impl ArrayRefExt for ArrayRef { }) } } + DataType::Int8 => { + if let Some(array) = self.as_any().downcast_ref::() { + Ok(Column::TinyInt(&array.values()[range.start..range.end])) + } else { + Err(ArrowArrayToColumnConversionError::UnsupportedType { + datatype: self.data_type().clone(), + }) + } + } DataType::Int16 => { if let Some(array) = self.as_any().downcast_ref::() { Ok(Column::SmallInt(&array.values()[range.start..range.end])) @@ -728,6 +737,14 @@ mod tests { )); } + #[test] + fn we_can_convert_int8_array_normal_range() { + let alloc = Bump::new(); + let array: ArrayRef = Arc::new(Int8Array::from(vec![1, -3, 42])); + let result = array.to_column::(&alloc, &(1..3), None); + assert_eq!(result.unwrap(), Column::TinyInt(&[-3, 42])); + } + #[test] fn we_can_convert_int16_array_normal_range() { let alloc = Bump::new(); @@ -736,6 +753,14 @@ mod tests { assert_eq!(result.unwrap(), Column::SmallInt(&[-3, 42])); } + #[test] + fn we_can_convert_int8_array_empty_range() { + let alloc = Bump::new(); + let array: ArrayRef = Arc::new(Int8Array::from(vec![1, -3, 42])); + let result = array.to_column::(&alloc, &(1..1), None); + assert_eq!(result.unwrap(), Column::TinyInt(&[])); + } + #[test] fn we_can_convert_int16_array_empty_range() { let alloc = Bump::new(); @@ -744,6 +769,19 @@ mod tests { assert_eq!(result.unwrap(), Column::SmallInt(&[])); } + #[test] + fn we_cannot_convert_int8_array_oob_range() { + let alloc = Bump::new(); + let array: ArrayRef = Arc::new(Int8Array::from(vec![1, -3, 42])); + + let result = array.to_column::(&alloc, &(2..4), None); + + assert_eq!( + result, + Err(ArrowArrayToColumnConversionError::IndexOutOfBounds { len: 3, index: 4 }) + ); + } + #[test] fn we_cannot_convert_int16_array_oob_range() { let alloc = Bump::new(); @@ -757,6 +795,17 @@ mod tests { ); } + #[test] + fn we_can_convert_int8_array_with_nulls() { + let alloc = Bump::new(); + let array: ArrayRef = Arc::new(Int8Array::from(vec![Some(1), None, Some(42)])); + let result = array.to_column::(&alloc, &(0..3), None); + assert!(matches!( + result, + Err(ArrowArrayToColumnConversionError::ArrayContainsNulls) + )); + } + #[test] fn we_can_convert_int16_array_with_nulls() { let alloc = Bump::new(); @@ -812,6 +861,13 @@ mod tests { fn we_cannot_index_on_oob_range() { let alloc = Bump::new(); + let array0: ArrayRef = Arc::new(arrow::array::Int8Array::from(vec![1, -3])); + let result0 = array0.to_column::(&alloc, &(2..3), None); + assert_eq!( + result0, + Err(ArrowArrayToColumnConversionError::IndexOutOfBounds { len: 2, index: 3 }) + ); + let array1: ArrayRef = Arc::new(arrow::array::Int16Array::from(vec![1, -3])); let result1 = array1.to_column::(&alloc, &(2..3), None); assert_eq!( @@ -838,6 +894,13 @@ mod tests { fn we_cannot_index_on_empty_oob_range() { let alloc = Bump::new(); + let array0: ArrayRef = Arc::new(arrow::array::Int8Array::from(vec![1, -3])); + let result0 = array0.to_column::(&alloc, &(5..5), None); + assert_eq!( + result0, + Err(ArrowArrayToColumnConversionError::IndexOutOfBounds { len: 2, index: 5 }) + ); + let array1: ArrayRef = Arc::new(arrow::array::Int16Array::from(vec![1, -3])); let result1 = array1.to_column::(&alloc, &(5..5), None); assert_eq!( @@ -870,6 +933,16 @@ mod tests { assert_eq!(result, Column::Boolean(&[])); } + #[test] + fn we_can_build_an_empty_column_from_an_empty_range_int8() { + let alloc = Bump::new(); + let array: ArrayRef = Arc::new(arrow::array::Int8Array::from(vec![1, -3])); + let result = array + .to_column::(&alloc, &(2..2), None) + .unwrap(); + assert_eq!(result, Column::TinyInt(&[])); + } + #[test] fn we_can_build_an_empty_column_from_an_empty_range_int16() { let alloc = Bump::new(); @@ -958,6 +1031,14 @@ mod tests { #[test] fn we_can_convert_valid_integer_array_refs_into_valid_columns() { let alloc = Bump::new(); + let array: ArrayRef = Arc::new(arrow::array::Int8Array::from(vec![1, -3])); + assert_eq!( + array + .to_column::(&alloc, &(0..2), None) + .unwrap(), + Column::TinyInt(&[1, -3]) + ); + let array: ArrayRef = Arc::new(arrow::array::Int16Array::from(vec![1, -3])); assert_eq!( array @@ -1046,6 +1127,14 @@ mod tests { { let alloc = Bump::new(); + let array: ArrayRef = Arc::new(arrow::array::Int8Array::from(vec![0, 1, 127])); + assert_eq!( + array + .to_column::(&alloc, &(1..3), None) + .unwrap(), + Column::TinyInt(&[1, 127]) + ); + let array: ArrayRef = Arc::new(arrow::array::Int16Array::from(vec![0, 1, 545])); assert_eq!( array diff --git a/crates/proof-of-sql/src/base/database/column.rs b/crates/proof-of-sql/src/base/database/column.rs index efe46ed06..78833c37c 100644 --- a/crates/proof-of-sql/src/base/database/column.rs +++ b/crates/proof-of-sql/src/base/database/column.rs @@ -30,6 +30,8 @@ use serde::{Deserialize, Serialize}; pub enum Column<'a, S: Scalar> { /// Boolean columns Boolean(&'a [bool]), + /// i8 columns + TinyInt(&'a [i8]), /// i16 columns SmallInt(&'a [i16]), /// i32 columns @@ -60,6 +62,7 @@ impl<'a, S: Scalar> Column<'a, S> { pub fn column_type(&self) -> ColumnType { match self { Self::Boolean(_) => ColumnType::Boolean, + Self::TinyInt(_) => ColumnType::TinyInt, Self::SmallInt(_) => ColumnType::SmallInt, Self::Int(_) => ColumnType::Int, Self::BigInt(_) => ColumnType::BigInt, @@ -79,6 +82,7 @@ impl<'a, S: Scalar> Column<'a, S> { pub fn len(&self) -> usize { match self { Self::Boolean(col) => col.len(), + Self::TinyInt(col) => col.len(), Self::SmallInt(col) => col.len(), Self::Int(col) => col.len(), Self::BigInt(col) | Self::TimestampTZ(_, _, col) => col.len(), @@ -106,6 +110,9 @@ impl<'a, S: Scalar> Column<'a, S> { LiteralValue::Boolean(value) => { Column::Boolean(alloc.alloc_slice_fill_copy(length, *value)) } + LiteralValue::TinyInt(value) => { + Column::TinyInt(alloc.alloc_slice_fill_copy(length, *value)) + } LiteralValue::SmallInt(value) => { Column::SmallInt(alloc.alloc_slice_fill_copy(length, *value)) } @@ -138,6 +145,7 @@ impl<'a, S: Scalar> Column<'a, S> { pub fn from_owned_column(owned_column: &'a OwnedColumn, alloc: &'a Bump) -> Self { match owned_column { OwnedColumn::Boolean(col) => Column::Boolean(col.as_slice()), + OwnedColumn::TinyInt(col) => Column::TinyInt(col.as_slice()), OwnedColumn::SmallInt(col) => Column::SmallInt(col.as_slice()), OwnedColumn::Int(col) => Column::Int(col.as_slice()), OwnedColumn::BigInt(col) => Column::BigInt(col.as_slice()), @@ -173,6 +181,7 @@ impl<'a, S: Scalar> Column<'a, S> { pub(crate) fn as_scalar(&self, alloc: &'a Bump) -> &'a [S] { match self { Self::Boolean(col) => alloc.alloc_slice_fill_with(col.len(), |i| S::from(col[i])), + Self::TinyInt(col) => alloc.alloc_slice_fill_with(col.len(), |i| S::from(col[i])), Self::SmallInt(col) => alloc.alloc_slice_fill_with(col.len(), |i| S::from(col[i])), Self::Int(col) => alloc.alloc_slice_fill_with(col.len(), |i| S::from(col[i])), Self::BigInt(col) => alloc.alloc_slice_fill_with(col.len(), |i| S::from(col[i])), @@ -191,6 +200,7 @@ impl<'a, S: Scalar> Column<'a, S> { pub(crate) fn scalar_at(&self, index: usize) -> Option { (index < self.len()).then_some(match self { Self::Boolean(col) => S::from(col[index]), + Self::TinyInt(col) => S::from(col[index]), Self::SmallInt(col) => S::from(col[index]), Self::Int(col) => S::from(col[index]), Self::BigInt(col) | Self::TimestampTZ(_, _, col) => S::from(col[index]), @@ -208,6 +218,7 @@ impl<'a, S: Scalar> Column<'a, S> { Self::Boolean(col) => slice_cast_with(col, |b| S::from(b) * scale_factor), Self::Decimal75(_, _, col) => slice_cast_with(col, |s| *s * scale_factor), Self::VarChar((_, values)) => slice_cast_with(values, |s| *s * scale_factor), + Self::TinyInt(col) => slice_cast_with(col, |i| S::from(i) * scale_factor), Self::SmallInt(col) => slice_cast_with(col, |i| S::from(i) * scale_factor), Self::Int(col) => slice_cast_with(col, |i| S::from(i) * scale_factor), Self::BigInt(col) => slice_cast_with(col, |i| S::from(i) * scale_factor), @@ -228,6 +239,9 @@ pub enum ColumnType { /// Mapped to bool #[serde(alias = "BOOLEAN", alias = "boolean")] Boolean, + /// Mapped to i8 + #[serde(alias = "TINYINT", alias = "tinyint")] + TinyInt, /// Mapped to i16 #[serde(alias = "SMALLINT", alias = "smallint")] SmallInt, @@ -260,7 +274,8 @@ impl ColumnType { pub fn is_numeric(&self) -> bool { matches!( self, - ColumnType::SmallInt + ColumnType::TinyInt + | ColumnType::SmallInt | ColumnType::Int | ColumnType::BigInt | ColumnType::Int128 @@ -274,13 +289,18 @@ impl ColumnType { pub fn is_integer(&self) -> bool { matches!( self, - ColumnType::SmallInt | ColumnType::Int | ColumnType::BigInt | ColumnType::Int128 + ColumnType::TinyInt + | ColumnType::SmallInt + | ColumnType::Int + | ColumnType::BigInt + | ColumnType::Int128 ) } /// Returns the number of bits in the integer type if it is an integer type. Otherwise, return None. fn to_integer_bits(self) -> Option { match self { + ColumnType::TinyInt => Some(8), ColumnType::SmallInt => Some(16), ColumnType::Int => Some(32), ColumnType::BigInt => Some(64), @@ -294,6 +314,7 @@ impl ColumnType { /// Otherwise, return None. fn from_integer_bits(bits: usize) -> Option { match bits { + 8 => Some(ColumnType::TinyInt), 16 => Some(ColumnType::SmallInt), 32 => Some(ColumnType::Int), 64 => Some(ColumnType::BigInt), @@ -322,6 +343,7 @@ impl ColumnType { #[must_use] pub fn precision_value(&self) -> Option { match self { + Self::TinyInt => Some(3_u8), Self::SmallInt => Some(5_u8), Self::Int => Some(10_u8), Self::BigInt | Self::TimestampTZ(_, _) => Some(19_u8), @@ -338,7 +360,12 @@ impl ColumnType { pub fn scale(&self) -> Option { match self { Self::Decimal75(_, scale) => Some(*scale), - Self::SmallInt | Self::Int | Self::BigInt | Self::Int128 | Self::Scalar => Some(0), + Self::TinyInt + | Self::SmallInt + | Self::Int + | Self::BigInt + | Self::Int128 + | Self::Scalar => Some(0), Self::Boolean | Self::VarChar => None, Self::TimestampTZ(tu, _) => match tu { PoSQLTimeUnit::Second => Some(0), @@ -354,6 +381,7 @@ impl ColumnType { pub fn byte_size(&self) -> usize { match self { Self::Boolean => size_of::(), + Self::TinyInt => size_of::(), Self::SmallInt => size_of::(), Self::Int => size_of::(), Self::BigInt | Self::TimestampTZ(_, _) => size_of::(), @@ -372,9 +400,12 @@ impl ColumnType { #[must_use] pub const fn is_signed(&self) -> bool { match self { - Self::SmallInt | Self::Int | Self::BigInt | Self::Int128 | Self::TimestampTZ(_, _) => { - true - } + Self::TinyInt + | Self::SmallInt + | Self::Int + | Self::BigInt + | Self::Int128 + | Self::TimestampTZ(_, _) => true, Self::Decimal75(_, _) | Self::Scalar | Self::VarChar | Self::Boolean => false, } } @@ -386,6 +417,7 @@ impl From<&ColumnType> for DataType { fn from(column_type: &ColumnType) -> Self { match column_type { ColumnType::Boolean => DataType::Boolean, + ColumnType::TinyInt => DataType::Int8, ColumnType::SmallInt => DataType::Int16, ColumnType::Int => DataType::Int32, ColumnType::BigInt => DataType::Int64, @@ -417,6 +449,7 @@ impl TryFrom for ColumnType { fn try_from(data_type: DataType) -> Result { match data_type { DataType::Boolean => Ok(ColumnType::Boolean), + DataType::Int8 => Ok(ColumnType::TinyInt), DataType::Int16 => Ok(ColumnType::SmallInt), DataType::Int32 => Ok(ColumnType::Int), DataType::Int64 => Ok(ColumnType::BigInt), @@ -447,6 +480,7 @@ impl Display for ColumnType { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { ColumnType::Boolean => write!(f, "BOOLEAN"), + ColumnType::TinyInt => write!(f, "TINYINT"), ColumnType::SmallInt => write!(f, "SMALLINT"), ColumnType::Int => write!(f, "INT"), ColumnType::BigInt => write!(f, "BIGINT"), @@ -566,6 +600,10 @@ mod tests { let serialized = serde_json::to_string(&column_type).unwrap(); assert_eq!(serialized, r#""Boolean""#); + let column_type = ColumnType::TinyInt; + let serialized = serde_json::to_string(&column_type).unwrap(); + assert_eq!(serialized, r#""TinyInt""#); + let column_type = ColumnType::SmallInt; let serialized = serde_json::to_string(&column_type).unwrap(); assert_eq!(serialized, r#""SmallInt""#); @@ -607,6 +645,10 @@ mod tests { let deserialized: ColumnType = serde_json::from_str(r#""Boolean""#).unwrap(); assert_eq!(deserialized, expected_column_type); + let expected_column_type = ColumnType::TinyInt; + let deserialized: ColumnType = serde_json::from_str(r#""TinyInt""#).unwrap(); + assert_eq!(deserialized, expected_column_type); + let expected_column_type = ColumnType::SmallInt; let deserialized: ColumnType = serde_json::from_str(r#""SmallInt""#).unwrap(); assert_eq!(deserialized, expected_column_type); @@ -619,6 +661,10 @@ mod tests { let deserialized: ColumnType = serde_json::from_str(r#""BigInt""#).unwrap(); assert_eq!(deserialized, expected_column_type); + let expected_column_type = ColumnType::TinyInt; + let deserialized: ColumnType = serde_json::from_str(r#""TINYINT""#).unwrap(); + assert_eq!(deserialized, expected_column_type); + let expected_column_type = ColumnType::SmallInt; let deserialized: ColumnType = serde_json::from_str(r#""SMALLINT""#).unwrap(); assert_eq!(deserialized, expected_column_type); @@ -672,6 +718,14 @@ mod tests { serde_json::from_str::(r#""BIGINT""#).unwrap(), ColumnType::BigInt ); + assert_eq!( + serde_json::from_str::(r#""TINYINT""#).unwrap(), + ColumnType::TinyInt + ); + assert_eq!( + serde_json::from_str::(r#""tinyint""#).unwrap(), + ColumnType::TinyInt + ); assert_eq!( serde_json::from_str::(r#""SMALLINT""#).unwrap(), ColumnType::SmallInt @@ -739,6 +793,9 @@ mod tests { let deserialized: Result = serde_json::from_str(r#""BooLean""#); assert!(deserialized.is_err()); + let deserialized: Result = serde_json::from_str(r#""Tinyint""#); + assert!(deserialized.is_err()); + let deserialized: Result = serde_json::from_str(r#""Smallint""#); assert!(deserialized.is_err()); @@ -775,6 +832,14 @@ mod tests { boolean ); + let tinyint = ColumnType::TinyInt; + let tinyint_json = serde_json::to_string(&tinyint).unwrap(); + assert_eq!(tinyint_json, "\"TinyInt\""); + assert_eq!( + serde_json::from_str::(&tinyint_json).unwrap(), + tinyint + ); + let smallint = ColumnType::SmallInt; let smallint_json = serde_json::to_string(&smallint).unwrap(); assert_eq!(smallint_json, "\"SmallInt\""); @@ -845,6 +910,10 @@ mod tests { assert_eq!(column.len(), 3); assert!(!column.is_empty()); + let column = Column::::TinyInt(&[1, 2, 3]); + assert_eq!(column.len(), 3); + assert!(!column.is_empty()); + let column = Column::::SmallInt(&[1, 2, 3]); assert_eq!(column.len(), 3); assert!(!column.is_empty()); @@ -885,6 +954,10 @@ mod tests { assert_eq!(column.len(), 0); assert!(column.is_empty()); + let column = Column::::TinyInt(&[]); + assert_eq!(column.len(), 0); + assert!(column.is_empty()); + let column = Column::::SmallInt(&[]); assert_eq!(column.len(), 0); assert!(column.is_empty()); @@ -968,6 +1041,10 @@ mod tests { assert_eq!(column.column_type().byte_size(), 1); assert_eq!(column.column_type().bit_size(), 8); + let column = Column::::TinyInt(&[1, 2, 3, 4]); + assert_eq!(column.column_type().byte_size(), 1); + assert_eq!(column.column_type().bit_size(), 8); + let column = Column::::SmallInt(&[1, 2, 3, 4]); assert_eq!(column.column_type().byte_size(), 2); assert_eq!(column.column_type().bit_size(), 16); diff --git a/crates/proof-of-sql/src/base/database/column_operation.rs b/crates/proof-of-sql/src/base/database/column_operation.rs index f7a82a0f0..a9da45987 100644 --- a/crates/proof-of-sql/src/base/database/column_operation.rs +++ b/crates/proof-of-sql/src/base/database/column_operation.rs @@ -932,6 +932,12 @@ mod test { #[test] fn we_can_add_numeric_types() { // lhs and rhs are integers with the same precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::TinyInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::TinyInt; + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::SmallInt; let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); @@ -939,6 +945,12 @@ mod test { assert_eq!(expected, actual); // lhs and rhs are integers with different precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::SmallInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::SmallInt; + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::Int; let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); @@ -946,6 +958,12 @@ mod test { assert_eq!(expected, actual); // lhs is an integer and rhs is a scalar + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Scalar; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::Scalar; + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::Scalar; let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); @@ -953,6 +971,12 @@ mod test { assert_eq!(expected, actual); // lhs is a decimal with nonnegative scale and rhs is an integer + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let rhs = ColumnType::TinyInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); + assert_eq!(expected, actual); + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); let rhs = ColumnType::SmallInt; let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); @@ -967,6 +991,12 @@ mod test { assert_eq!(expected, actual); // lhs is an integer and rhs is a decimal with negative scale + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); @@ -991,6 +1021,13 @@ mod test { #[test] fn we_cannot_add_non_numeric_types() { + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::VarChar; assert!(matches!( @@ -1030,6 +1067,12 @@ mod test { #[test] fn we_can_subtract_numeric_types() { // lhs and rhs are integers with the same precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::TinyInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::TinyInt; + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::SmallInt; let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); @@ -1037,6 +1080,12 @@ mod test { assert_eq!(expected, actual); // lhs and rhs are integers with different precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::SmallInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::SmallInt; + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::Int; let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); @@ -1044,6 +1093,12 @@ mod test { assert_eq!(expected, actual); // lhs is an integer and rhs is a scalar + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Scalar; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::Scalar; + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::Scalar; let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); @@ -1051,6 +1106,12 @@ mod test { assert_eq!(expected, actual); // lhs is a decimal and rhs is an integer + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let rhs = ColumnType::TinyInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); + assert_eq!(expected, actual); + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); let rhs = ColumnType::SmallInt; let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); @@ -1065,6 +1126,12 @@ mod test { assert_eq!(expected, actual); // lhs is an integer and rhs is a decimal with negative scale + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); @@ -1089,6 +1156,13 @@ mod test { #[test] fn we_cannot_subtract_non_numeric_types() { + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::VarChar; assert!(matches!( @@ -1128,6 +1202,12 @@ mod test { #[test] fn we_can_multiply_numeric_types() { // lhs and rhs are integers with the same precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::TinyInt; + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::TinyInt; + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::SmallInt; let actual = try_multiply_column_types(lhs, rhs).unwrap(); @@ -1135,6 +1215,12 @@ mod test { assert_eq!(expected, actual); // lhs and rhs are integers with different precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::SmallInt; + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::SmallInt; + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::Int; let actual = try_multiply_column_types(lhs, rhs).unwrap(); @@ -1142,6 +1228,12 @@ mod test { assert_eq!(expected, actual); // lhs is an integer and rhs is a scalar + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Scalar; + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Scalar; + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::Scalar; let actual = try_multiply_column_types(lhs, rhs).unwrap(); @@ -1149,6 +1241,12 @@ mod test { assert_eq!(expected, actual); // lhs is a decimal and rhs is an integer + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let rhs = ColumnType::TinyInt; + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(14).unwrap(), 2); + assert_eq!(expected, actual); + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); let rhs = ColumnType::SmallInt; let actual = try_multiply_column_types(lhs, rhs).unwrap(); @@ -1163,6 +1261,12 @@ mod test { assert_eq!(expected, actual); // lhs is an integer and rhs is a decimal with negative scale + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(14).unwrap(), -2); + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); let actual = try_multiply_column_types(lhs, rhs).unwrap(); @@ -1187,6 +1291,13 @@ mod test { #[test] fn we_cannot_multiply_non_numeric_types() { + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_multiply_column_types(lhs, rhs), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::VarChar; assert!(matches!( @@ -1246,6 +1357,12 @@ mod test { #[test] fn we_can_divide_numeric_types() { // lhs and rhs are integers with the same precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::TinyInt; + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::TinyInt; + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::SmallInt; let actual = try_divide_column_types(lhs, rhs).unwrap(); @@ -1253,6 +1370,12 @@ mod test { assert_eq!(expected, actual); // lhs and rhs are integers with different precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::SmallInt; + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::SmallInt; + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::Int; let actual = try_divide_column_types(lhs, rhs).unwrap(); @@ -1260,6 +1383,12 @@ mod test { assert_eq!(expected, actual); // lhs is a decimal with nonnegative scale and rhs is an integer + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let rhs = ColumnType::TinyInt; + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(14).unwrap(), 6); + assert_eq!(expected, actual); + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); let rhs = ColumnType::SmallInt; let actual = try_divide_column_types(lhs, rhs).unwrap(); @@ -1267,6 +1396,12 @@ mod test { assert_eq!(expected, actual); // lhs is an integer and rhs is a decimal with nonnegative scale + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(16).unwrap(), 11); + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); let actual = try_divide_column_types(lhs, rhs).unwrap(); @@ -1281,6 +1416,12 @@ mod test { assert_eq!(expected, actual); // lhs is an integer and rhs is a decimal with negative scale + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(12).unwrap(), 11); + assert_eq!(expected, actual); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); let actual = try_divide_column_types(lhs, rhs).unwrap(); @@ -1305,6 +1446,13 @@ mod test { #[test] fn we_cannot_divide_non_numeric_or_scalar_types() { + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_divide_column_types(lhs, rhs), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + let lhs = ColumnType::SmallInt; let rhs = ColumnType::VarChar; assert!(matches!( @@ -1417,6 +1565,17 @@ mod test { #[test] fn we_can_eq_decimal_columns() { // lhs is integer and rhs is decimal with nonnegative scale + let lhs = [1_i8, -2, 3]; + let rhs = [100_i8, 5, -2] + .into_iter() + .map(Curve25519Scalar::from) + .collect::>(); + let left_column_type = ColumnType::TinyInt; + let right_column_type = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let actual = eq_decimal_columns(&lhs, &rhs, left_column_type, right_column_type); + let expected = vec![true, false, false]; + assert_eq!(expected, actual); + let lhs = [1_i16, -2, 3]; let rhs = [100_i16, 5, -2] .into_iter() @@ -1538,6 +1697,17 @@ mod test { #[test] fn we_can_le_decimal_columns() { // lhs is integer and rhs is decimal with nonnegative scale + let lhs = [1_i8, -2, 3]; + let rhs = [100_i8, 5, -2] + .into_iter() + .map(Curve25519Scalar::from) + .collect::>(); + let left_column_type = ColumnType::TinyInt; + let right_column_type = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let actual = le_decimal_columns(&lhs, &rhs, left_column_type, right_column_type); + let expected = vec![true, true, false]; + assert_eq!(expected, actual); + let lhs = [1_i16, -2, 3]; let rhs = [100_i16, 5, -2] .into_iter() @@ -1659,6 +1829,17 @@ mod test { #[test] fn we_can_ge_decimal_columns() { // lhs is integer and rhs is decimal with nonnegative scale + let lhs = [1_i8, -2, 3]; + let rhs = [100_i8, 5, -2] + .into_iter() + .map(Curve25519Scalar::from) + .collect::>(); + let left_column_type = ColumnType::TinyInt; + let right_column_type = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let actual = ge_decimal_columns(&lhs, &rhs, left_column_type, right_column_type); + let expected = vec![true, false, true]; + assert_eq!(expected, actual); + let lhs = [1_i16, -2, 3]; let rhs = [100_i16, 5, -2] .into_iter() @@ -1800,6 +1981,23 @@ mod test { #[test] fn we_can_try_add_decimal_columns() { // lhs is integer and rhs is decimal with nonnegative scale + let lhs = [1_i8, -2, 3]; + let rhs = [4_i8, 5, -2] + .into_iter() + .map(Curve25519Scalar::from) + .collect::>(); + let left_column_type = ColumnType::TinyInt; + let right_column_type = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let actual: (Precision, i8, Vec) = + try_add_decimal_columns(&lhs, &rhs, left_column_type, right_column_type).unwrap(); + let expected_scalars = vec![ + Curve25519Scalar::from(104), + Curve25519Scalar::from(-195), + Curve25519Scalar::from(298), + ]; + let expected = (Precision::new(11).unwrap(), 2, expected_scalars); + assert_eq!(expected, actual); + let lhs = [1_i16, -2, 3]; let rhs = [4_i16, 5, -2] .into_iter() @@ -1962,6 +2160,23 @@ mod test { #[test] fn we_can_try_subtract_decimal_columns() { // lhs is integer and rhs is decimal with nonnegative scale + let lhs = [1_i8, -2, 3]; + let rhs = [4_i8, 5, -2] + .into_iter() + .map(Curve25519Scalar::from) + .collect::>(); + let left_column_type = ColumnType::TinyInt; + let right_column_type = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let actual: (Precision, i8, Vec) = + try_subtract_decimal_columns(&lhs, &rhs, left_column_type, right_column_type).unwrap(); + let expected_scalars = vec![ + Curve25519Scalar::from(96), + Curve25519Scalar::from(-205), + Curve25519Scalar::from(302), + ]; + let expected = (Precision::new(11).unwrap(), 2, expected_scalars); + assert_eq!(expected, actual); + let lhs = [1_i16, -2, 3]; let rhs = [4_i16, 5, -2] .into_iter() @@ -2105,6 +2320,23 @@ mod test { #[test] fn we_can_try_multiply_decimal_columns() { // lhs is integer and rhs is decimal with nonnegative scale + let lhs = [1_i8, -2, 3]; + let rhs = [4_i8, 5, -2] + .into_iter() + .map(Curve25519Scalar::from) + .collect::>(); + let left_column_type = ColumnType::TinyInt; + let right_column_type = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let actual: (Precision, i8, Vec) = + try_multiply_decimal_columns(&lhs, &rhs, left_column_type, right_column_type).unwrap(); + let expected_scalars = vec![ + Curve25519Scalar::from(4), + Curve25519Scalar::from(-10), + Curve25519Scalar::from(-6), + ]; + let expected = (Precision::new(14).unwrap(), 2, expected_scalars); + assert_eq!(expected, actual); + let lhs = [1_i16, -2, 3]; let rhs = [4_i16, 5, -2] .into_iter() @@ -2268,6 +2500,23 @@ mod test { #[test] fn we_can_try_divide_decimal_columns() { // lhs is integer and rhs is decimal with nonnegative scale + let lhs = [0_i8, 2, 3]; + let rhs = [4_i8, 5, 2] + .into_iter() + .map(Curve25519Scalar::from) + .collect::>(); + let left_column_type = ColumnType::TinyInt; + let right_column_type = ColumnType::Decimal75(Precision::new(3).unwrap(), 2); + let actual: (Precision, i8, Vec) = + try_divide_decimal_columns(&lhs, &rhs, left_column_type, right_column_type).unwrap(); + let expected_scalars = vec![ + Curve25519Scalar::from(0_i64), + Curve25519Scalar::from(40_000_000_i64), + Curve25519Scalar::from(150_000_000_i64), + ]; + let expected = (Precision::new(11).unwrap(), 6, expected_scalars); + assert_eq!(expected, actual); + let lhs = [0_i16, 2, 3]; let rhs = [4_i16, 5, 2] .into_iter() @@ -2286,6 +2535,23 @@ mod test { assert_eq!(expected, actual); // lhs is decimal with negative scale and rhs is integer + let lhs = [4_i8, 15, -2] + .into_iter() + .map(Curve25519Scalar::from) + .collect::>(); + let rhs = [71_i64, -82, 23]; + let left_column_type = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); + let right_column_type = ColumnType::TinyInt; + let actual: (Precision, i8, Vec) = + try_divide_decimal_columns(&lhs, &rhs, left_column_type, right_column_type).unwrap(); + let expected_scalars = vec![ + Curve25519Scalar::from(5_633_802), + Curve25519Scalar::from(-18_292_682), + Curve25519Scalar::from(-8_695_652), + ]; + let expected = (Precision::new(18).unwrap(), 6, expected_scalars); + assert_eq!(expected, actual); + let lhs = [4_i16, 15, -2] .into_iter() .map(Curve25519Scalar::from) diff --git a/crates/proof-of-sql/src/base/database/filter_util.rs b/crates/proof-of-sql/src/base/database/filter_util.rs index 898ee2b21..f775dbd3b 100644 --- a/crates/proof-of-sql/src/base/database/filter_util.rs +++ b/crates/proof-of-sql/src/base/database/filter_util.rs @@ -44,6 +44,9 @@ pub fn filter_column_by_index<'a, S: Scalar>( Column::Boolean(col) => { Column::Boolean(alloc.alloc_slice_fill_iter(indexes.iter().map(|&i| col[i]))) } + Column::TinyInt(col) => { + Column::TinyInt(alloc.alloc_slice_fill_iter(indexes.iter().map(|&i| col[i]))) + } Column::SmallInt(col) => { Column::SmallInt(alloc.alloc_slice_fill_iter(indexes.iter().map(|&i| col[i]))) } diff --git a/crates/proof-of-sql/src/base/database/group_by_util.rs b/crates/proof-of-sql/src/base/database/group_by_util.rs index 68865c9cb..439e110ee 100644 --- a/crates/proof-of-sql/src/base/database/group_by_util.rs +++ b/crates/proof-of-sql/src/base/database/group_by_util.rs @@ -140,6 +140,7 @@ pub(crate) fn sum_aggregate_column_by_index_counts<'a, S: Scalar>( indexes: &[usize], ) -> &'a [S] { match column { + Column::TinyInt(col) => sum_aggregate_slice_by_index_counts(alloc, col, counts, indexes), Column::SmallInt(col) => sum_aggregate_slice_by_index_counts(alloc, col, counts, indexes), Column::Int(col) => sum_aggregate_slice_by_index_counts(alloc, col, counts, indexes), Column::BigInt(col) => sum_aggregate_slice_by_index_counts(alloc, col, counts, indexes), @@ -168,6 +169,7 @@ pub(crate) fn max_aggregate_column_by_index_counts<'a, S: Scalar>( ) -> &'a [Option] { match column { Column::Boolean(col) => max_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + Column::TinyInt(col) => max_aggregate_slice_by_index_counts(alloc, col, counts, indexes), Column::SmallInt(col) => max_aggregate_slice_by_index_counts(alloc, col, counts, indexes), Column::Int(col) => max_aggregate_slice_by_index_counts(alloc, col, counts, indexes), Column::BigInt(col) => max_aggregate_slice_by_index_counts(alloc, col, counts, indexes), @@ -199,6 +201,7 @@ pub(crate) fn min_aggregate_column_by_index_counts<'a, S: Scalar>( ) -> &'a [Option] { match column { Column::Boolean(col) => min_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + Column::TinyInt(col) => min_aggregate_slice_by_index_counts(alloc, col, counts, indexes), Column::SmallInt(col) => min_aggregate_slice_by_index_counts(alloc, col, counts, indexes), Column::Int(col) => min_aggregate_slice_by_index_counts(alloc, col, counts, indexes), Column::BigInt(col) => min_aggregate_slice_by_index_counts(alloc, col, counts, indexes), @@ -352,6 +355,7 @@ pub(crate) fn compare_indexes_by_columns( .iter() .map(|col| match col { Column::Boolean(col) => col[i].cmp(&col[j]), + Column::TinyInt(col) => col[i].cmp(&col[j]), Column::SmallInt(col) => col[i].cmp(&col[j]), Column::Int(col) => col[i].cmp(&col[j]), Column::BigInt(col) | Column::TimestampTZ(_, _, col) => col[i].cmp(&col[j]), @@ -377,6 +381,7 @@ pub(crate) fn compare_indexes_by_owned_columns( .iter() .map(|col| match col { OwnedColumn::Boolean(col) => col[i].cmp(&col[j]), + OwnedColumn::TinyInt(col) => col[i].cmp(&col[j]), OwnedColumn::SmallInt(col) => col[i].cmp(&col[j]), OwnedColumn::Int(col) => col[i].cmp(&col[j]), OwnedColumn::BigInt(col) | OwnedColumn::TimestampTZ(_, _, col) => col[i].cmp(&col[j]), diff --git a/crates/proof-of-sql/src/base/database/literal_value.rs b/crates/proof-of-sql/src/base/database/literal_value.rs index 6e78099c3..59888d672 100644 --- a/crates/proof-of-sql/src/base/database/literal_value.rs +++ b/crates/proof-of-sql/src/base/database/literal_value.rs @@ -13,6 +13,8 @@ use serde::{Deserialize, Serialize}; pub enum LiteralValue { /// Boolean literals Boolean(bool), + /// i8 literals + TinyInt(i8), /// i16 literals SmallInt(i16), /// i32 literals @@ -41,6 +43,7 @@ impl LiteralValue { pub fn column_type(&self) -> ColumnType { match self { Self::Boolean(_) => ColumnType::Boolean, + Self::TinyInt(_) => ColumnType::TinyInt, Self::SmallInt(_) => ColumnType::SmallInt, Self::Int(_) => ColumnType::Int, Self::BigInt(_) => ColumnType::BigInt, @@ -56,6 +59,7 @@ impl LiteralValue { pub(crate) fn to_scalar(&self) -> S { match self { Self::Boolean(b) => b.into(), + Self::TinyInt(i) => i.into(), Self::SmallInt(i) => i.into(), Self::Int(i) => i.into(), Self::BigInt(i) => i.into(), diff --git a/crates/proof-of-sql/src/base/database/owned_and_arrow_conversions.rs b/crates/proof-of-sql/src/base/database/owned_and_arrow_conversions.rs index 7c7e0ea82..579b9f362 100644 --- a/crates/proof-of-sql/src/base/database/owned_and_arrow_conversions.rs +++ b/crates/proof-of-sql/src/base/database/owned_and_arrow_conversions.rs @@ -26,7 +26,7 @@ use alloc::sync::Arc; use arrow::{ array::{ ArrayRef, BooleanArray, Decimal128Array, Decimal256Array, Int16Array, Int32Array, - Int64Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, + Int64Array, Int8Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }, datatypes::{i256, DataType, Schema, SchemaRef, TimeUnit as ArrowTimeUnit}, @@ -86,6 +86,7 @@ impl From> for ArrayRef { fn from(value: OwnedColumn) -> Self { match value { OwnedColumn::Boolean(col) => Arc::new(BooleanArray::from(col)), + OwnedColumn::TinyInt(col) => Arc::new(Int8Array::from(col)), OwnedColumn::SmallInt(col) => Arc::new(Int16Array::from(col)), OwnedColumn::Int(col) => Arc::new(Int32Array::from(col)), OwnedColumn::BigInt(col) => Arc::new(Int64Array::from(col)), @@ -162,6 +163,14 @@ impl TryFrom<&ArrayRef> for OwnedColumn { .collect::>>() .ok_or(OwnedArrowConversionError::NullNotSupportedYet)?, )), + DataType::Int8 => Ok(Self::TinyInt( + value + .as_any() + .downcast_ref::() + .unwrap() + .values() + .to_vec(), + )), DataType::Int16 => Ok(Self::SmallInt( value .as_any() diff --git a/crates/proof-of-sql/src/base/database/owned_column.rs b/crates/proof-of-sql/src/base/database/owned_column.rs index 0316ab552..700a443b2 100644 --- a/crates/proof-of-sql/src/base/database/owned_column.rs +++ b/crates/proof-of-sql/src/base/database/owned_column.rs @@ -26,6 +26,8 @@ use proof_of_sql_parser::{ pub enum OwnedColumn { /// Boolean columns Boolean(Vec), + /// i8 columns + TinyInt(Vec), /// i16 columns SmallInt(Vec), /// i32 columns @@ -50,6 +52,7 @@ impl OwnedColumn { pub fn len(&self) -> usize { match self { OwnedColumn::Boolean(col) => col.len(), + OwnedColumn::TinyInt(col) => col.len(), OwnedColumn::SmallInt(col) => col.len(), OwnedColumn::Int(col) => col.len(), OwnedColumn::BigInt(col) | OwnedColumn::TimestampTZ(_, _, col) => col.len(), @@ -63,6 +66,7 @@ impl OwnedColumn { pub fn try_permute(&self, permutation: &Permutation) -> Result { Ok(match self { OwnedColumn::Boolean(col) => OwnedColumn::Boolean(permutation.try_apply(col)?), + OwnedColumn::TinyInt(col) => OwnedColumn::TinyInt(permutation.try_apply(col)?), OwnedColumn::SmallInt(col) => OwnedColumn::SmallInt(permutation.try_apply(col)?), OwnedColumn::Int(col) => OwnedColumn::Int(permutation.try_apply(col)?), OwnedColumn::BigInt(col) => OwnedColumn::BigInt(permutation.try_apply(col)?), @@ -83,6 +87,7 @@ impl OwnedColumn { pub fn slice(&self, start: usize, end: usize) -> Self { match self { OwnedColumn::Boolean(col) => OwnedColumn::Boolean(col[start..end].to_vec()), + OwnedColumn::TinyInt(col) => OwnedColumn::TinyInt(col[start..end].to_vec()), OwnedColumn::SmallInt(col) => OwnedColumn::SmallInt(col[start..end].to_vec()), OwnedColumn::Int(col) => OwnedColumn::Int(col[start..end].to_vec()), OwnedColumn::BigInt(col) => OwnedColumn::BigInt(col[start..end].to_vec()), @@ -103,6 +108,7 @@ impl OwnedColumn { pub fn is_empty(&self) -> bool { match self { OwnedColumn::Boolean(col) => col.is_empty(), + OwnedColumn::TinyInt(col) => col.is_empty(), OwnedColumn::SmallInt(col) => col.is_empty(), OwnedColumn::Int(col) => col.is_empty(), OwnedColumn::BigInt(col) | OwnedColumn::TimestampTZ(_, _, col) => col.is_empty(), @@ -116,6 +122,7 @@ impl OwnedColumn { pub fn column_type(&self) -> ColumnType { match self { OwnedColumn::Boolean(_) => ColumnType::Boolean, + OwnedColumn::TinyInt(_) => ColumnType::TinyInt, OwnedColumn::SmallInt(_) => ColumnType::SmallInt, OwnedColumn::Int(_) => ColumnType::Int, OwnedColumn::BigInt(_) => ColumnType::BigInt, @@ -141,6 +148,15 @@ impl OwnedColumn { error: "Overflow in scalar conversions".to_string(), })?, )), + ColumnType::TinyInt => Ok(OwnedColumn::TinyInt( + scalars + .iter() + .map(|s| -> Result { TryInto::::try_into(*s) }) + .collect::, _>>() + .map_err(|_| OwnedColumnError::ScalarConversionError { + error: "Overflow in scalar conversions".to_string(), + })?, + )), ColumnType::SmallInt => Ok(OwnedColumn::SmallInt( scalars .iter() @@ -214,6 +230,15 @@ impl OwnedColumn { Self::try_from_scalars(&scalars, column_type) } + #[cfg(test)] + /// Returns an iterator over the raw data of the column + /// assuming the underlying type is [i8], panicking if it is not. + pub fn i8_iter(&self) -> impl Iterator { + match self { + OwnedColumn::TinyInt(col) => col.iter(), + _ => panic!("Expected TinyInt column"), + } + } #[cfg(test)] /// Returns an iterator over the raw data of the column /// assuming the underlying type is [i16], panicking if it is not. @@ -283,6 +308,7 @@ impl<'a, S: Scalar> From<&Column<'a, S>> for OwnedColumn { fn from(col: &Column<'a, S>) -> Self { match col { Column::Boolean(col) => OwnedColumn::Boolean(col.to_vec()), + Column::TinyInt(col) => OwnedColumn::TinyInt(col.to_vec()), Column::SmallInt(col) => OwnedColumn::SmallInt(col.to_vec()), Column::Int(col) => OwnedColumn::Int(col.to_vec()), Column::BigInt(col) => OwnedColumn::BigInt(col.to_vec()), @@ -312,6 +338,7 @@ pub(crate) fn compare_indexes_by_owned_columns_with_direction( .map(|(col, direction)| { let ordering = match col { OwnedColumn::Boolean(col) => col[i].cmp(&col[j]), + OwnedColumn::TinyInt(col) => col[i].cmp(&col[j]), OwnedColumn::SmallInt(col) => col[i].cmp(&col[j]), OwnedColumn::Int(col) => col[i].cmp(&col[j]), OwnedColumn::BigInt(col) | OwnedColumn::TimestampTZ(_, _, col) => { diff --git a/crates/proof-of-sql/src/base/database/owned_column_operation.rs b/crates/proof-of-sql/src/base/database/owned_column_operation.rs index be1603396..427ed8e88 100644 --- a/crates/proof-of-sql/src/base/database/owned_column_operation.rs +++ b/crates/proof-of-sql/src/base/database/owned_column_operation.rs @@ -63,6 +63,31 @@ impl OwnedColumn { }); } match (self, rhs) { + (Self::TinyInt(lhs), Self::TinyInt(rhs)) => Ok(Self::Boolean(slice_eq(lhs, rhs))), + (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { + Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) + } + (Self::TinyInt(lhs), Self::Int(rhs)) => { + Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) + } + (Self::TinyInt(lhs), Self::BigInt(rhs)) => { + Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) + } + (Self::TinyInt(lhs), Self::Int128(rhs)) => { + Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) + } + (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { + Ok(Self::Boolean(eq_decimal_columns( + lhs_values, + rhs_values, + self.column_type(), + rhs.column_type(), + ))) + } + + (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) + } (Self::SmallInt(lhs), Self::SmallInt(rhs)) => Ok(Self::Boolean(slice_eq(lhs, rhs))), (Self::SmallInt(lhs), Self::Int(rhs)) => { Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) @@ -81,6 +106,10 @@ impl OwnedColumn { rhs.column_type(), ))) } + + (Self::Int(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) + } (Self::Int(lhs), Self::SmallInt(rhs)) => { Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) } @@ -99,6 +128,10 @@ impl OwnedColumn { rhs.column_type(), ))) } + + (Self::BigInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) + } (Self::BigInt(lhs), Self::SmallInt(rhs)) => { Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) } @@ -117,6 +150,10 @@ impl OwnedColumn { rhs.column_type(), ))) } + + (Self::Int128(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) + } (Self::Int128(lhs), Self::SmallInt(rhs)) => { Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) } @@ -135,6 +172,15 @@ impl OwnedColumn { rhs.column_type(), ))) } + + (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { + Ok(Self::Boolean(eq_decimal_columns( + rhs_values, + lhs_values, + rhs.column_type(), + self.column_type(), + ))) + } (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { Ok(Self::Boolean(eq_decimal_columns( rhs_values, @@ -198,6 +244,31 @@ impl OwnedColumn { }); } match (self, rhs) { + (Self::TinyInt(lhs), Self::TinyInt(rhs)) => Ok(Self::Boolean(slice_le(lhs, rhs))), + (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { + Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) + } + (Self::TinyInt(lhs), Self::Int(rhs)) => { + Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) + } + (Self::TinyInt(lhs), Self::BigInt(rhs)) => { + Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) + } + (Self::TinyInt(lhs), Self::Int128(rhs)) => { + Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) + } + (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { + Ok(Self::Boolean(le_decimal_columns( + lhs_values, + rhs_values, + self.column_type(), + rhs.column_type(), + ))) + } + + (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) + } (Self::SmallInt(lhs), Self::SmallInt(rhs)) => Ok(Self::Boolean(slice_le(lhs, rhs))), (Self::SmallInt(lhs), Self::Int(rhs)) => { Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) @@ -216,6 +287,10 @@ impl OwnedColumn { rhs.column_type(), ))) } + + (Self::Int(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) + } (Self::Int(lhs), Self::SmallInt(rhs)) => { Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) } @@ -234,6 +309,10 @@ impl OwnedColumn { rhs.column_type(), ))) } + + (Self::BigInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) + } (Self::BigInt(lhs), Self::SmallInt(rhs)) => { Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) } @@ -252,6 +331,10 @@ impl OwnedColumn { rhs.column_type(), ))) } + + (Self::Int128(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) + } (Self::Int128(lhs), Self::SmallInt(rhs)) => { Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) } @@ -270,6 +353,15 @@ impl OwnedColumn { rhs.column_type(), ))) } + + (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { + Ok(Self::Boolean(ge_decimal_columns( + rhs_values, + lhs_values, + rhs.column_type(), + self.column_type(), + ))) + } (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { Ok(Self::Boolean(ge_decimal_columns( rhs_values, @@ -332,6 +424,31 @@ impl OwnedColumn { }); } match (self, rhs) { + (Self::TinyInt(lhs), Self::TinyInt(rhs)) => Ok(Self::Boolean(slice_ge(lhs, rhs))), + (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { + Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) + } + (Self::TinyInt(lhs), Self::Int(rhs)) => { + Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) + } + (Self::TinyInt(lhs), Self::BigInt(rhs)) => { + Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) + } + (Self::TinyInt(lhs), Self::Int128(rhs)) => { + Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) + } + (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { + Ok(Self::Boolean(ge_decimal_columns( + lhs_values, + rhs_values, + self.column_type(), + rhs.column_type(), + ))) + } + + (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) + } (Self::SmallInt(lhs), Self::SmallInt(rhs)) => Ok(Self::Boolean(slice_ge(lhs, rhs))), (Self::SmallInt(lhs), Self::Int(rhs)) => { Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) @@ -350,6 +467,10 @@ impl OwnedColumn { rhs.column_type(), ))) } + + (Self::Int(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) + } (Self::Int(lhs), Self::SmallInt(rhs)) => { Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) } @@ -368,6 +489,10 @@ impl OwnedColumn { rhs.column_type(), ))) } + + (Self::BigInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) + } (Self::BigInt(lhs), Self::SmallInt(rhs)) => { Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) } @@ -386,6 +511,10 @@ impl OwnedColumn { rhs.column_type(), ))) } + + (Self::Int128(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) + } (Self::Int128(lhs), Self::SmallInt(rhs)) => { Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) } @@ -404,6 +533,15 @@ impl OwnedColumn { rhs.column_type(), ))) } + + (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { + Ok(Self::Boolean(le_decimal_columns( + rhs_values, + lhs_values, + rhs.column_type(), + self.column_type(), + ))) + } (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { Ok(Self::Boolean(le_decimal_columns( rhs_values, @@ -469,6 +607,34 @@ impl Add for OwnedColumn { }); } match (&self, &rhs) { + (Self::TinyInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::TinyInt(try_add_slices(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { + Ok(Self::SmallInt(try_add_slices_with_casting(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::Int(rhs)) => { + Ok(Self::Int(try_add_slices_with_casting(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::BigInt(rhs)) => { + Ok(Self::BigInt(try_add_slices_with_casting(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::Int128(rhs)) => { + Ok(Self::Int128(try_add_slices_with_casting(lhs, rhs)?)) + } + (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { + let (new_precision, new_scale, new_values) = try_add_decimal_columns( + lhs_values, + rhs_values, + self.column_type(), + rhs.column_type(), + )?; + Ok(Self::Decimal75(new_precision, new_scale, new_values)) + } + + (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::SmallInt(try_add_slices_with_casting(rhs, lhs)?)) + } (Self::SmallInt(lhs), Self::SmallInt(rhs)) => { Ok(Self::SmallInt(try_add_slices(lhs, rhs)?)) } @@ -490,6 +656,10 @@ impl Add for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::Int(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Int(try_add_slices_with_casting(rhs, lhs)?)) + } (Self::Int(lhs), Self::SmallInt(rhs)) => { Ok(Self::Int(try_add_slices_with_casting(rhs, lhs)?)) } @@ -509,6 +679,10 @@ impl Add for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::BigInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::BigInt(try_add_slices_with_casting(rhs, lhs)?)) + } (Self::BigInt(lhs), Self::SmallInt(rhs)) => { Ok(Self::BigInt(try_add_slices_with_casting(rhs, lhs)?)) } @@ -528,6 +702,10 @@ impl Add for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::Int128(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Int128(try_add_slices_with_casting(rhs, lhs)?)) + } (Self::Int128(lhs), Self::SmallInt(rhs)) => { Ok(Self::Int128(try_add_slices_with_casting(rhs, lhs)?)) } @@ -547,6 +725,16 @@ impl Add for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { + let (new_precision, new_scale, new_values) = try_add_decimal_columns( + lhs_values, + rhs_values, + self.column_type(), + rhs.column_type(), + )?; + Ok(Self::Decimal75(new_precision, new_scale, new_values)) + } (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { let (new_precision, new_scale, new_values) = try_add_decimal_columns( lhs_values, @@ -612,6 +800,34 @@ impl Sub for OwnedColumn { }); } match (&self, &rhs) { + (Self::TinyInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::TinyInt(try_subtract_slices(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { + Ok(Self::SmallInt(try_subtract_slices_left_upcast(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::Int(rhs)) => { + Ok(Self::Int(try_subtract_slices_left_upcast(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::BigInt(rhs)) => { + Ok(Self::BigInt(try_subtract_slices_left_upcast(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::Int128(rhs)) => { + Ok(Self::Int128(try_subtract_slices_left_upcast(lhs, rhs)?)) + } + (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { + let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( + lhs_values, + rhs_values, + self.column_type(), + rhs.column_type(), + )?; + Ok(Self::Decimal75(new_precision, new_scale, new_values)) + } + + (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::SmallInt(try_subtract_slices_right_upcast(lhs, rhs)?)) + } (Self::SmallInt(lhs), Self::SmallInt(rhs)) => { Ok(Self::SmallInt(try_subtract_slices(lhs, rhs)?)) } @@ -633,6 +849,10 @@ impl Sub for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::Int(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Int(try_subtract_slices_right_upcast(lhs, rhs)?)) + } (Self::Int(lhs), Self::SmallInt(rhs)) => { Ok(Self::Int(try_subtract_slices_right_upcast(lhs, rhs)?)) } @@ -652,6 +872,10 @@ impl Sub for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::BigInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::BigInt(try_subtract_slices_right_upcast(lhs, rhs)?)) + } (Self::BigInt(lhs), Self::SmallInt(rhs)) => { Ok(Self::BigInt(try_subtract_slices_right_upcast(lhs, rhs)?)) } @@ -673,6 +897,10 @@ impl Sub for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::Int128(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Int128(try_subtract_slices_right_upcast(lhs, rhs)?)) + } (Self::Int128(lhs), Self::SmallInt(rhs)) => { Ok(Self::Int128(try_subtract_slices_right_upcast(lhs, rhs)?)) } @@ -694,6 +922,16 @@ impl Sub for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { + let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( + lhs_values, + rhs_values, + self.column_type(), + rhs.column_type(), + )?; + Ok(Self::Decimal75(new_precision, new_scale, new_values)) + } (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( lhs_values, @@ -759,6 +997,34 @@ impl Mul for OwnedColumn { }); } match (&self, &rhs) { + (Self::TinyInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::TinyInt(try_multiply_slices(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { + Ok(Self::SmallInt(try_multiply_slices_with_casting(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::Int(rhs)) => { + Ok(Self::Int(try_multiply_slices_with_casting(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::BigInt(rhs)) => { + Ok(Self::BigInt(try_multiply_slices_with_casting(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::Int128(rhs)) => { + Ok(Self::Int128(try_multiply_slices_with_casting(lhs, rhs)?)) + } + (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { + let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( + lhs_values, + rhs_values, + self.column_type(), + rhs.column_type(), + )?; + Ok(Self::Decimal75(new_precision, new_scale, new_values)) + } + + (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::SmallInt(try_multiply_slices_with_casting(rhs, lhs)?)) + } (Self::SmallInt(lhs), Self::SmallInt(rhs)) => { Ok(Self::SmallInt(try_multiply_slices(lhs, rhs)?)) } @@ -780,6 +1046,10 @@ impl Mul for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::Int(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Int(try_multiply_slices_with_casting(rhs, lhs)?)) + } (Self::Int(lhs), Self::SmallInt(rhs)) => { Ok(Self::Int(try_multiply_slices_with_casting(rhs, lhs)?)) } @@ -799,6 +1069,10 @@ impl Mul for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::BigInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::BigInt(try_multiply_slices_with_casting(rhs, lhs)?)) + } (Self::BigInt(lhs), Self::SmallInt(rhs)) => { Ok(Self::BigInt(try_multiply_slices_with_casting(rhs, lhs)?)) } @@ -820,6 +1094,10 @@ impl Mul for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::Int128(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Int128(try_multiply_slices_with_casting(rhs, lhs)?)) + } (Self::Int128(lhs), Self::SmallInt(rhs)) => { Ok(Self::Int128(try_multiply_slices_with_casting(rhs, lhs)?)) } @@ -841,6 +1119,16 @@ impl Mul for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { + let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( + lhs_values, + rhs_values, + self.column_type(), + rhs.column_type(), + )?; + Ok(Self::Decimal75(new_precision, new_scale, new_values)) + } (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( lhs_values, @@ -906,6 +1194,34 @@ impl Div for OwnedColumn { }); } match (&self, &rhs) { + (Self::TinyInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::TinyInt(try_divide_slices(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { + Ok(Self::SmallInt(try_divide_slices_left_upcast(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::Int(rhs)) => { + Ok(Self::Int(try_divide_slices_left_upcast(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::BigInt(rhs)) => { + Ok(Self::BigInt(try_divide_slices_left_upcast(lhs, rhs)?)) + } + (Self::TinyInt(lhs), Self::Int128(rhs)) => { + Ok(Self::Int128(try_divide_slices_left_upcast(lhs, rhs)?)) + } + (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { + let (new_precision, new_scale, new_values) = try_divide_decimal_columns( + lhs_values, + rhs_values, + self.column_type(), + rhs.column_type(), + )?; + Ok(Self::Decimal75(new_precision, new_scale, new_values)) + } + + (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::SmallInt(try_divide_slices_right_upcast(lhs, rhs)?)) + } (Self::SmallInt(lhs), Self::SmallInt(rhs)) => { Ok(Self::SmallInt(try_divide_slices(lhs, rhs)?)) } @@ -927,6 +1243,10 @@ impl Div for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::Int(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Int(try_divide_slices_right_upcast(lhs, rhs)?)) + } (Self::Int(lhs), Self::SmallInt(rhs)) => { Ok(Self::Int(try_divide_slices_right_upcast(lhs, rhs)?)) } @@ -946,6 +1266,10 @@ impl Div for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::BigInt(lhs), Self::TinyInt(rhs)) => { + Ok(Self::BigInt(try_divide_slices_right_upcast(lhs, rhs)?)) + } (Self::BigInt(lhs), Self::SmallInt(rhs)) => { Ok(Self::BigInt(try_divide_slices_right_upcast(lhs, rhs)?)) } @@ -967,6 +1291,10 @@ impl Div for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::Int128(lhs), Self::TinyInt(rhs)) => { + Ok(Self::Int128(try_divide_slices_right_upcast(lhs, rhs)?)) + } (Self::Int128(lhs), Self::SmallInt(rhs)) => { Ok(Self::Int128(try_divide_slices_right_upcast(lhs, rhs)?)) } @@ -988,6 +1316,16 @@ impl Div for OwnedColumn { )?; Ok(Self::Decimal75(new_precision, new_scale, new_values)) } + + (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { + let (new_precision, new_scale, new_values) = try_divide_decimal_columns( + lhs_values, + rhs_values, + self.column_type(), + rhs.column_type(), + )?; + Ok(Self::Decimal75(new_precision, new_scale, new_values)) + } (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { let (new_precision, new_scale, new_values) = try_divide_decimal_columns( lhs_values, @@ -1076,6 +1414,14 @@ mod test { Err(ColumnOperationError::DifferentColumnLength { .. }) )); + let lhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); + let rhs = OwnedColumn::::TinyInt(vec![1, 2]); + let result = lhs.clone() + rhs.clone(); + assert!(matches!( + result, + Err(ColumnOperationError::DifferentColumnLength { .. }) + )); + let lhs = OwnedColumn::::SmallInt(vec![1, 2, 3]); let rhs = OwnedColumn::::SmallInt(vec![1, 2]); let result = lhs.clone() + rhs.clone(); @@ -1105,6 +1451,26 @@ mod test { #[test] fn we_cannot_do_logical_operation_on_nonboolean_columns() { + let lhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); + let rhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); + let result = lhs.element_wise_and(&rhs); + assert!(matches!( + result, + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + + let result = lhs.element_wise_or(&rhs); + assert!(matches!( + result, + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + + let result = lhs.element_wise_not(); + assert!(matches!( + result, + Err(ColumnOperationError::UnaryOperationInvalidColumnType { .. }) + )); + let lhs = OwnedColumn::::Int(vec![1, 2, 3]); let rhs = OwnedColumn::::Int(vec![1, 2, 3]); let result = lhs.element_wise_and(&rhs); @@ -1158,6 +1524,16 @@ mod test { #[test] fn we_can_do_eq_operation() { // Integers + let lhs = OwnedColumn::::SmallInt(vec![1, 3, 2]); + let rhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); + let result = lhs.element_wise_eq(&rhs); + assert_eq!( + result, + Ok(OwnedColumn::::Boolean(vec![ + true, false, false + ])) + ); + let lhs = OwnedColumn::::Int(vec![1, 3, 2]); let rhs = OwnedColumn::::SmallInt(vec![1, 2, 3]); let result = lhs.element_wise_eq(&rhs); @@ -1216,6 +1592,18 @@ mod test { ); // Decimals and integers + let lhs_scalars = [10, 2, 30].iter().map(Curve25519Scalar::from).collect(); + let rhs = OwnedColumn::::TinyInt(vec![1, -2, 3]); + let lhs = + OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 1, lhs_scalars); + let result = lhs.element_wise_eq(&rhs); + assert_eq!( + result, + Ok(OwnedColumn::::Boolean(vec![ + true, false, true + ])) + ); + let lhs_scalars = [10, 2, 30].iter().map(Curve25519Scalar::from).collect(); let rhs = OwnedColumn::::Int(vec![1, -2, 3]); let lhs = @@ -1243,6 +1631,16 @@ mod test { ); // Integers + let lhs = OwnedColumn::::SmallInt(vec![1, 3, 2]); + let rhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); + let result = lhs.element_wise_le(&rhs); + assert_eq!( + result, + Ok(OwnedColumn::::Boolean(vec![ + true, false, true + ])) + ); + let lhs = OwnedColumn::::Int(vec![1, 3, 2]); let rhs = OwnedColumn::::SmallInt(vec![1, 2, 3]); let result = lhs.element_wise_le(&rhs); @@ -1269,6 +1667,18 @@ mod test { ); // Decimals and integers + let lhs_scalars = [10, -2, -30].iter().map(Curve25519Scalar::from).collect(); + let rhs = OwnedColumn::::TinyInt(vec![1, -20, 3]); + let lhs = + OwnedColumn::::Decimal75(Precision::new(5).unwrap(), -1, lhs_scalars); + let result = lhs.element_wise_le(&rhs); + assert_eq!( + result, + Ok(OwnedColumn::::Boolean(vec![ + false, true, true + ])) + ); + let lhs_scalars = [10, -2, -30].iter().map(Curve25519Scalar::from).collect(); let rhs = OwnedColumn::::Int(vec![1, -20, 3]); let lhs = @@ -1296,6 +1706,16 @@ mod test { ); // Integers + let lhs = OwnedColumn::::SmallInt(vec![1, 3, 2]); + let rhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); + let result = lhs.element_wise_ge(&rhs); + assert_eq!( + result, + Ok(OwnedColumn::::Boolean(vec![ + true, true, false + ])) + ); + let lhs = OwnedColumn::::Int(vec![1, 3, 2]); let rhs = OwnedColumn::::SmallInt(vec![1, 2, 3]); let result = lhs.element_wise_ge(&rhs); @@ -1322,6 +1742,18 @@ mod test { ); // Decimals and integers + let lhs_scalars = [10, -2, -30].iter().map(Curve25519Scalar::from).collect(); + let rhs = OwnedColumn::::TinyInt(vec![1_i8, -20, 3]); + let lhs = + OwnedColumn::::Decimal75(Precision::new(5).unwrap(), -1, lhs_scalars); + let result = lhs.element_wise_ge(&rhs); + assert_eq!( + result, + Ok(OwnedColumn::::Boolean(vec![ + true, true, false + ])) + ); + let lhs_scalars = [10, -2, -30].iter().map(Curve25519Scalar::from).collect(); let rhs = OwnedColumn::::BigInt(vec![1_i64, -20, 3]); let lhs = @@ -1338,6 +1770,19 @@ mod test { #[test] fn we_cannot_do_comparison_on_columns_with_incompatible_types() { // Strings can't be compared with other types + let lhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); + let rhs = OwnedColumn::::VarChar( + ["Space", "and", "Time"] + .iter() + .map(|s| s.to_string()) + .collect(), + ); + let result = lhs.element_wise_le(&rhs); + assert!(matches!( + result, + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + let lhs = OwnedColumn::::Int(vec![1, 2, 3]); let rhs = OwnedColumn::::VarChar( ["Space", "and", "Time"] @@ -1364,6 +1809,14 @@ mod test { )); // Booleans can't be compared with other types + let lhs = OwnedColumn::::Boolean(vec![true, false, true]); + let rhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); + let result = lhs.element_wise_le(&rhs); + assert!(matches!( + result, + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + let lhs = OwnedColumn::::Boolean(vec![true, false, true]); let rhs = OwnedColumn::::Int(vec![1, 2, 3]); let result = lhs.element_wise_le(&rhs); @@ -1439,6 +1892,14 @@ mod test { #[test] fn we_can_add_integer_columns() { // lhs and rhs have the same precision + let lhs = OwnedColumn::::TinyInt(vec![1_i8, 2, 3]); + let rhs = OwnedColumn::::TinyInt(vec![1_i8, 2, 3]); + let result = lhs + rhs; + assert_eq!( + result, + Ok(OwnedColumn::::TinyInt(vec![2_i8, 4, 6])) + ); + let lhs = OwnedColumn::::SmallInt(vec![1_i16, 2, 3]); let rhs = OwnedColumn::::SmallInt(vec![1_i16, 2, 3]); let result = lhs + rhs; @@ -1448,6 +1909,14 @@ mod test { ); // lhs and rhs have different precisions + let lhs = OwnedColumn::::TinyInt(vec![1_i8, 2, 3]); + let rhs = OwnedColumn::::Int(vec![1_i32, 2, 3]); + let result = lhs + rhs; + assert_eq!( + result, + Ok(OwnedColumn::::Int(vec![2_i32, 4, 6])) + ); + let lhs = OwnedColumn::::Int128(vec![1_i128, 2, 3]); let rhs = OwnedColumn::::Int(vec![1_i32, 2, 3]); let result = lhs + rhs; @@ -1496,6 +1965,21 @@ mod test { ); // lhs is integer and rhs is decimal + let lhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); + let rhs_scalars = [1, 2, 3].iter().map(Curve25519Scalar::from).collect(); + let rhs = + OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); + let result = (lhs + rhs).unwrap(); + let expected_scalars = [101, 202, 303].iter().map(Curve25519Scalar::from).collect(); + assert_eq!( + result, + OwnedColumn::::Decimal75( + Precision::new(6).unwrap(), + 2, + expected_scalars + ) + ); + let lhs = OwnedColumn::::Int(vec![1, 2, 3]); let rhs_scalars = [1, 2, 3].iter().map(Curve25519Scalar::from).collect(); let rhs = @@ -1515,6 +1999,14 @@ mod test { #[test] fn we_can_try_subtract_integer_columns() { // lhs and rhs have the same precision + let lhs = OwnedColumn::::TinyInt(vec![4_i8, 5, 2]); + let rhs = OwnedColumn::::TinyInt(vec![1_i8, 2, 3]); + let result = lhs - rhs; + assert_eq!( + result, + Ok(OwnedColumn::::TinyInt(vec![3_i8, 3, -1])) + ); + let lhs = OwnedColumn::::Int(vec![4_i32, 5, 2]); let rhs = OwnedColumn::::Int(vec![1_i32, 2, 3]); let result = lhs - rhs; @@ -1524,6 +2016,14 @@ mod test { ); // lhs and rhs have different precisions + let lhs = OwnedColumn::::TinyInt(vec![4_i8, 5, 2]); + let rhs = OwnedColumn::::BigInt(vec![1_i64, 2, 5]); + let result = lhs - rhs; + assert_eq!( + result, + Ok(OwnedColumn::::BigInt(vec![3_i64, 3, -3])) + ); + let lhs = OwnedColumn::::Int(vec![3_i32, 2, 3]); let rhs = OwnedColumn::::BigInt(vec![1_i64, 2, 5]); let result = lhs - rhs; @@ -1572,6 +2072,21 @@ mod test { ); // lhs is integer and rhs is decimal + let lhs = OwnedColumn::::TinyInt(vec![4, 5, 2]); + let rhs_scalars = [1, 2, 3].iter().map(Curve25519Scalar::from).collect(); + let rhs = + OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); + let result = (lhs - rhs).unwrap(); + let expected_scalars = [399, 498, 197].iter().map(Curve25519Scalar::from).collect(); + assert_eq!( + result, + OwnedColumn::::Decimal75( + Precision::new(6).unwrap(), + 2, + expected_scalars + ) + ); + let lhs = OwnedColumn::::Int(vec![4, 5, 2]); let rhs_scalars = [1, 2, 3].iter().map(Curve25519Scalar::from).collect(); let rhs = @@ -1591,6 +2106,14 @@ mod test { #[test] fn we_can_try_multiply_integer_columns() { // lhs and rhs have the same precision + let lhs = OwnedColumn::::TinyInt(vec![4_i8, 5, -2]); + let rhs = OwnedColumn::::TinyInt(vec![1_i8, 2, 3]); + let result = lhs * rhs; + assert_eq!( + result, + Ok(OwnedColumn::::TinyInt(vec![4_i8, 10, -6])) + ); + let lhs = OwnedColumn::::BigInt(vec![4_i64, 5, -2]); let rhs = OwnedColumn::::BigInt(vec![1_i64, 2, 3]); let result = lhs * rhs; @@ -1600,6 +2123,14 @@ mod test { ); // lhs and rhs have different precisions + let lhs = OwnedColumn::::TinyInt(vec![3_i8, 2, 3]); + let rhs = OwnedColumn::::Int128(vec![1_i128, 2, 5]); + let result = lhs * rhs; + assert_eq!( + result, + Ok(OwnedColumn::::Int128(vec![3_i128, 4, 15])) + ); + let lhs = OwnedColumn::::Int(vec![3_i32, 2, 3]); let rhs = OwnedColumn::::Int128(vec![1_i128, 2, 5]); let result = lhs * rhs; @@ -1630,6 +2161,21 @@ mod test { ); // lhs is integer and rhs is decimal + let lhs = OwnedColumn::::TinyInt(vec![4, 5, 2]); + let rhs_scalars = [1, 2, 3].iter().map(Curve25519Scalar::from).collect(); + let rhs = + OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); + let result = (lhs * rhs).unwrap(); + let expected_scalars = [4, 10, 6].iter().map(Curve25519Scalar::from).collect(); + assert_eq!( + result, + OwnedColumn::::Decimal75( + Precision::new(9).unwrap(), + 2, + expected_scalars + ) + ); + let lhs = OwnedColumn::::Int(vec![4, 5, 2]); let rhs_scalars = [1, 2, 3].iter().map(Curve25519Scalar::from).collect(); let rhs = @@ -1649,6 +2195,14 @@ mod test { #[test] fn we_can_try_divide_integer_columns() { // lhs and rhs have the same precision + let lhs = OwnedColumn::::TinyInt(vec![4_i8, 5, -2]); + let rhs = OwnedColumn::::TinyInt(vec![1_i8, 2, 3]); + let result = lhs / rhs; + assert_eq!( + result, + Ok(OwnedColumn::::TinyInt(vec![4_i8, 2, 0])) + ); + let lhs = OwnedColumn::::BigInt(vec![4_i64, 5, -2]); let rhs = OwnedColumn::::BigInt(vec![1_i64, 2, 3]); let result = lhs / rhs; @@ -1658,6 +2212,14 @@ mod test { ); // lhs and rhs have different precisions + let lhs = OwnedColumn::::TinyInt(vec![3_i8, 2, 3]); + let rhs = OwnedColumn::::Int128(vec![1_i128, 2, 5]); + let result = lhs / rhs; + assert_eq!( + result, + Ok(OwnedColumn::::Int128(vec![3_i128, 1, 0])) + ); + let lhs = OwnedColumn::::Int(vec![3_i32, 2, 3]); let rhs = OwnedColumn::::Int128(vec![1_i128, 2, 5]); let result = lhs / rhs; @@ -1691,6 +2253,24 @@ mod test { ); // lhs is integer and rhs is decimal + let lhs = OwnedColumn::::TinyInt(vec![4, 5, 3]); + let rhs_scalars = [-1, 2, 3].iter().map(Curve25519Scalar::from).collect(); + let rhs = + OwnedColumn::::Decimal75(Precision::new(3).unwrap(), 2, rhs_scalars); + let result = (lhs / rhs).unwrap(); + let expected_scalars = [-400_000_000, 250_000_000, 100_000_000] + .iter() + .map(Curve25519Scalar::from) + .collect(); + assert_eq!( + result, + OwnedColumn::::Decimal75( + Precision::new(11).unwrap(), + 6, + expected_scalars + ) + ); + let lhs = OwnedColumn::::SmallInt(vec![4, 5, 3]); let rhs_scalars = [-1, 2, 3].iter().map(Curve25519Scalar::from).collect(); let rhs = diff --git a/crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs b/crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs index 5cf61c85a..05b3cea6f 100644 --- a/crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs +++ b/crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs @@ -91,6 +91,7 @@ impl DataAccessor for OwnedTableTestA .unwrap() { OwnedColumn::Boolean(col) => Column::Boolean(col), + OwnedColumn::TinyInt(col) => Column::TinyInt(col), OwnedColumn::SmallInt(col) => Column::SmallInt(col), OwnedColumn::Int(col) => Column::Int(col), OwnedColumn::BigInt(col) => Column::BigInt(col), diff --git a/crates/proof-of-sql/src/base/database/owned_table_utility.rs b/crates/proof-of-sql/src/base/database/owned_table_utility.rs index 658dfb565..a008aa235 100644 --- a/crates/proof-of-sql/src/base/database/owned_table_utility.rs +++ b/crates/proof-of-sql/src/base/database/owned_table_utility.rs @@ -48,6 +48,27 @@ pub fn owned_table( OwnedTable::try_from_iter(iter).unwrap() } +/// Creates a (Identifier, `OwnedColumn`) pair for a tinyint column. +/// This is primarily intended for use in conjunction with [`owned_table`]. +/// # Example +/// ``` +/// use proof_of_sql::base::{database::owned_table_utility::*, scalar::Curve25519Scalar}; +/// let result = owned_table::([ +/// tinyint("a", [1_i8, 2, 3]), +/// ]); +///``` +/// # Panics +/// - Panics if `name.parse()` fails to convert the name into an `Identifier`. +pub fn tinyint( + name: impl Deref, + data: impl IntoIterator>, +) -> (Identifier, OwnedColumn) { + ( + name.parse().unwrap(), + OwnedColumn::TinyInt(data.into_iter().map(Into::into).collect()), + ) +} + /// Creates a `(Identifier, OwnedColumn)` pair for a smallint column. /// This is primarily intended for use in conjunction with [`owned_table`]. /// # Example diff --git a/crates/proof-of-sql/src/base/database/test_accessor_utility.rs b/crates/proof-of-sql/src/base/database/test_accessor_utility.rs index 815532fc2..220d8c860 100644 --- a/crates/proof-of-sql/src/base/database/test_accessor_utility.rs +++ b/crates/proof-of-sql/src/base/database/test_accessor_utility.rs @@ -2,7 +2,7 @@ use crate::base::database::ColumnType; use arrow::{ array::{ Array, BooleanArray, Decimal128Array, Decimal256Array, Int16Array, Int32Array, Int64Array, - StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, + Int8Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }, datatypes::{i256, DataType, Field, Schema, TimeUnit}, @@ -69,7 +69,14 @@ pub fn make_random_test_accessor_data( let boolean_values: Vec = values.iter().map(|x| x % 2 != 0).collect(); columns.push(Arc::new(BooleanArray::from(boolean_values))); } - + ColumnType::TinyInt => { + column_fields.push(Field::new(*col_name, DataType::Int8, false)); + let values: Vec = values + .iter() + .map(|x| ((*x >> 56) as i8)) // Shift right to align the lower 8 bits + .collect(); + columns.push(Arc::new(Int8Array::from(values))); + } ColumnType::SmallInt => { column_fields.push(Field::new(*col_name, DataType::Int16, false)); let values: Vec = values @@ -179,6 +186,7 @@ mod tests { ("c", ColumnType::Int128), ("d", ColumnType::SmallInt), ("e", ColumnType::Int), + ("f", ColumnType::TinyInt), ]; let data1 = make_random_test_accessor_data(&mut rng, &cols, &descriptor); diff --git a/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs b/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs index bce68ffb7..4425ac056 100644 --- a/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs +++ b/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs @@ -103,6 +103,7 @@ impl MultilinearExtension for &Column<'_, S> { Column::Scalar(c) | Column::VarChar((_, c)) | Column::Decimal75(_, _, c) => { c.inner_product(evaluation_vec) } + Column::TinyInt(c) => c.inner_product(evaluation_vec), Column::SmallInt(c) => c.inner_product(evaluation_vec), Column::Int(c) => c.inner_product(evaluation_vec), Column::BigInt(c) | Column::TimestampTZ(_, _, c) => c.inner_product(evaluation_vec), @@ -116,6 +117,7 @@ impl MultilinearExtension for &Column<'_, S> { Column::Scalar(c) | Column::VarChar((_, c)) | Column::Decimal75(_, _, c) => { c.mul_add(res, multiplier); } + Column::TinyInt(c) => c.mul_add(res, multiplier), Column::SmallInt(c) => c.mul_add(res, multiplier), Column::Int(c) => c.mul_add(res, multiplier), Column::BigInt(c) | Column::TimestampTZ(_, _, c) => c.mul_add(res, multiplier), @@ -129,6 +131,7 @@ impl MultilinearExtension for &Column<'_, S> { Column::Scalar(c) | Column::VarChar((_, c)) | Column::Decimal75(_, _, c) => { c.to_sumcheck_term(num_vars) } + Column::TinyInt(c) => c.to_sumcheck_term(num_vars), Column::SmallInt(c) => c.to_sumcheck_term(num_vars), Column::Int(c) => c.to_sumcheck_term(num_vars), Column::BigInt(c) | Column::TimestampTZ(_, _, c) => c.to_sumcheck_term(num_vars), @@ -142,6 +145,7 @@ impl MultilinearExtension for &Column<'_, S> { Column::Scalar(c) | Column::VarChar((_, c)) | Column::Decimal75(_, _, c) => { MultilinearExtension::::id(c) } + Column::TinyInt(c) => MultilinearExtension::::id(c), Column::SmallInt(c) => MultilinearExtension::::id(c), Column::Int(c) => MultilinearExtension::::id(c), Column::BigInt(c) | Column::TimestampTZ(_, _, c) => MultilinearExtension::::id(c), diff --git a/crates/proof-of-sql/src/base/scalar/mod.rs b/crates/proof-of-sql/src/base/scalar/mod.rs index e58ddfe57..6ba7e8208 100644 --- a/crates/proof-of-sql/src/base/scalar/mod.rs +++ b/crates/proof-of-sql/src/base/scalar/mod.rs @@ -38,6 +38,7 @@ pub trait Scalar: + num_traits::Zero + for<'a> core::convert::From<&'a Self> // Required for `Column` to implement `MultilinearExtension` + for<'a> core::convert::From<&'a bool> // Required for `Column` to implement `MultilinearExtension` + + for<'a> core::convert::From<&'a i8> // Required for `Column` to implement `MultilinearExtension` + for<'a> core::convert::From<&'a i16> // Required for `Column` to implement `MultilinearExtension` + for<'a> core::convert::From<&'a i32> // Required for `Column` to implement `MultilinearExtension` + for<'a> core::convert::From<&'a i64> // Required for `Column` to implement `MultilinearExtension` @@ -67,6 +68,7 @@ pub trait Scalar: + core::convert::From + core::convert::From + core::convert::From + + core::convert::From + core::convert::From + core::convert::Into + TryFrom diff --git a/crates/proof-of-sql/src/proof_primitive/dory/dory_commitment_helper_cpu.rs b/crates/proof-of-sql/src/proof_primitive/dory/dory_commitment_helper_cpu.rs index 0ab4403b9..611e84f1b 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/dory_commitment_helper_cpu.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/dory_commitment_helper_cpu.rs @@ -60,6 +60,7 @@ fn compute_dory_commitment( ) -> DoryCommitment { match committable_column { CommittableColumn::Scalar(column) => compute_dory_commitment_impl(column, offset, setup), + CommittableColumn::TinyInt(column) => compute_dory_commitment_impl(column, offset, setup), CommittableColumn::SmallInt(column) => compute_dory_commitment_impl(column, offset, setup), CommittableColumn::Int(column) => compute_dory_commitment_impl(column, offset, setup), CommittableColumn::BigInt(column) => compute_dory_commitment_impl(column, offset, setup), diff --git a/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_commitment_helper_cpu.rs b/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_commitment_helper_cpu.rs index 6b2a7a93b..7187f5abe 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_commitment_helper_cpu.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_commitment_helper_cpu.rs @@ -45,13 +45,13 @@ fn compute_dory_commitment( setup: &ProverSetup, ) -> DynamicDoryCommitment { match committable_column { + CommittableColumn::Scalar(column) => compute_dory_commitment_impl(column, offset, setup), + CommittableColumn::TinyInt(column) => compute_dory_commitment_impl(column, offset, setup), CommittableColumn::SmallInt(column) => compute_dory_commitment_impl(column, offset, setup), CommittableColumn::Int(column) => compute_dory_commitment_impl(column, offset, setup), CommittableColumn::BigInt(column) => compute_dory_commitment_impl(column, offset, setup), CommittableColumn::Int128(column) => compute_dory_commitment_impl(column, offset, setup), - CommittableColumn::VarChar(column) - | CommittableColumn::Scalar(column) - | CommittableColumn::Decimal75(_, _, column) => { + CommittableColumn::VarChar(column) | CommittableColumn::Decimal75(_, _, column) => { compute_dory_commitment_impl(column, offset, setup) } CommittableColumn::Boolean(column) => compute_dory_commitment_impl(column, offset, setup), diff --git a/crates/proof-of-sql/src/proof_primitive/dory/offset_to_bytes.rs b/crates/proof-of-sql/src/proof_primitive/dory/offset_to_bytes.rs index 82287b5c5..f884813b3 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/offset_to_bytes.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/offset_to_bytes.rs @@ -8,6 +8,13 @@ impl OffsetToBytes<1> for u8 { } } +impl OffsetToBytes<1> for i8 { + fn offset_to_bytes(&self) -> [u8; 1] { + let shifted = self.wrapping_sub(i8::MIN); + shifted.to_le_bytes() + } +} + impl OffsetToBytes<2> for i16 { fn offset_to_bytes(&self) -> [u8; 2] { let shifted = self.wrapping_sub(i16::MIN); diff --git a/crates/proof-of-sql/src/proof_primitive/dory/pack_scalars.rs b/crates/proof-of-sql/src/proof_primitive/dory/pack_scalars.rs index 1ce4d4fce..e77410cb4 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/pack_scalars.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/pack_scalars.rs @@ -57,6 +57,7 @@ fn output_bit_table( /// * `column_type` - The type of a committable column. const fn min_as_f(column_type: ColumnType) -> F { match column_type { + ColumnType::TinyInt => MontFp!("-128"), ColumnType::SmallInt => MontFp!("-32768"), ColumnType::Int => MontFp!("-2147483648"), ColumnType::BigInt | ColumnType::TimestampTZ(_, _) => MontFp!("-9223372036854775808"), @@ -375,6 +376,17 @@ pub fn bit_table_and_scalars_for_packed_msm( .iter() .enumerate() .for_each(|(i, column)| match column { + CommittableColumn::TinyInt(column) => { + pack_bit( + column, + &mut packed_scalars, + cumulative_bit_sum_table[i], + offset, + committable_columns[i].column_type().byte_size(), + bit_table_full_sum_in_bytes, + num_matrix_commitment_columns, + ); + } CommittableColumn::SmallInt(column) => { pack_bit( column, diff --git a/crates/proof-of-sql/src/sql/proof/provable_query_result.rs b/crates/proof-of-sql/src/sql/proof/provable_query_result.rs index 32bbcc7e1..57d037780 100644 --- a/crates/proof-of-sql/src/sql/proof/provable_query_result.rs +++ b/crates/proof-of-sql/src/sql/proof/provable_query_result.rs @@ -119,6 +119,7 @@ impl ProvableQueryResult { for index in self.indexes.iter() { let (x, sz) = match field.data_type() { ColumnType::Boolean => decode_and_convert::(&self.data[offset..]), + ColumnType::TinyInt => decode_and_convert::(&self.data[offset..]), ColumnType::SmallInt => decode_and_convert::(&self.data[offset..]), ColumnType::Int => decode_and_convert::(&self.data[offset..]), ColumnType::BigInt => decode_and_convert::(&self.data[offset..]), @@ -172,6 +173,11 @@ impl ProvableQueryResult { offset += num_read; Ok((field.name(), OwnedColumn::Boolean(col))) } + ColumnType::TinyInt => { + let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?; + offset += num_read; + Ok((field.name(), OwnedColumn::TinyInt(col))) + } ColumnType::SmallInt => { let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?; offset += num_read; diff --git a/crates/proof-of-sql/src/sql/proof/provable_result_column.rs b/crates/proof-of-sql/src/sql/proof/provable_result_column.rs index 633eff1c2..bf883245b 100644 --- a/crates/proof-of-sql/src/sql/proof/provable_result_column.rs +++ b/crates/proof-of-sql/src/sql/proof/provable_result_column.rs @@ -35,6 +35,7 @@ impl ProvableResultColumn for Column<'_, S> { fn num_bytes(&self, selection: &Indexes) -> usize { match self { Column::Boolean(col) => col.num_bytes(selection), + Column::TinyInt(col) => col.num_bytes(selection), Column::SmallInt(col) => col.num_bytes(selection), Column::Int(col) => col.num_bytes(selection), Column::BigInt(col) | Column::TimestampTZ(_, _, col) => col.num_bytes(selection), @@ -47,6 +48,7 @@ impl ProvableResultColumn for Column<'_, S> { fn write(&self, out: &mut [u8], selection: &Indexes) -> usize { match self { Column::Boolean(col) => col.write(out, selection), + Column::TinyInt(col) => col.write(out, selection), Column::SmallInt(col) => col.write(out, selection), Column::Int(col) => col.write(out, selection), Column::BigInt(col) | Column::TimestampTZ(_, _, col) => col.write(out, selection), diff --git a/crates/proof-of-sql/src/sql/proof/verifiable_query_result.rs b/crates/proof-of-sql/src/sql/proof/verifiable_query_result.rs index 8a65d66e7..78282967d 100644 --- a/crates/proof-of-sql/src/sql/proof/verifiable_query_result.rs +++ b/crates/proof-of-sql/src/sql/proof/verifiable_query_result.rs @@ -159,6 +159,7 @@ fn make_empty_query_result(result_fields: Vec) -> QueryR field.name(), match field.data_type() { ColumnType::Boolean => OwnedColumn::Boolean(vec![]), + ColumnType::TinyInt => OwnedColumn::TinyInt(vec![]), ColumnType::SmallInt => OwnedColumn::SmallInt(vec![]), ColumnType::Int => OwnedColumn::Int(vec![]), ColumnType::BigInt => OwnedColumn::BigInt(vec![]), diff --git a/crates/proof-of-sql/tests/integration_tests.rs b/crates/proof-of-sql/tests/integration_tests.rs index 76425c8d1..864b64c83 100644 --- a/crates/proof-of-sql/tests/integration_tests.rs +++ b/crates/proof-of-sql/tests/integration_tests.rs @@ -211,6 +211,7 @@ fn we_can_prove_a_basic_query_containing_extrema_with_curve25519() { accessor.add_table( "sxt.table".parse().unwrap(), owned_table([ + tinyint("tinyint", [i8::MIN, 0, i8::MAX]), smallint("smallint", [i16::MIN, 0, i16::MAX]), int("int", [i32::MIN, 0, i32::MAX]), bigint("bigint", [i64::MIN, 0, i64::MAX]), @@ -231,6 +232,7 @@ fn we_can_prove_a_basic_query_containing_extrema_with_curve25519() { .unwrap() .table; let expected_result = owned_table([ + tinyint("tinyint", [i8::MIN, 0, i8::MAX]), smallint("smallint", [i16::MIN, 0, i16::MAX]), int("int", [i32::MIN, 0, i32::MAX]), bigint("bigint", [i64::MIN, 0, i64::MAX]), @@ -252,6 +254,7 @@ fn we_can_prove_a_basic_query_containing_extrema_with_dory() { accessor.add_table( "sxt.table".parse().unwrap(), owned_table([ + tinyint("tinyint", [i8::MIN, 0, i8::MAX]), smallint("smallint", [i16::MIN, 0, i16::MAX]), int("int", [i32::MIN, 0, i32::MAX]), bigint("bigint", [i64::MIN, 0, i64::MAX]), @@ -277,6 +280,7 @@ fn we_can_prove_a_basic_query_containing_extrema_with_dory() { .unwrap() .table; let expected_result = owned_table([ + tinyint("tinyint", [i8::MIN, 0, i8::MAX]), smallint("smallint", [i16::MIN, 0, i16::MAX]), int("int", [i32::MIN, 0, i32::MAX]), bigint("bigint", [i64::MIN, 0, i64::MAX]), @@ -840,3 +844,34 @@ fn we_can_prove_a_query_with_overflow_with_dory() { Err(QueryError::Overflow) )); } + +#[test] +#[cfg(feature = "blitzar")] +fn we_can_perform_arithmetic_and_conditional_operations_on_tinyint() { + let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); + accessor.add_table( + "sxt.table".parse().unwrap(), + owned_table([ + tinyint("a", [3_i8, 5, 2, 1]), + tinyint("b", [2_i8, 1, 3, 4]), + tinyint("c", [1_i8, 4, 5, 2]), + ]), + 0, + ); + let query = QueryExpr::try_new( + "SELECT a*b+b+c as result FROM table WHERE a>b OR c=4" + .parse() + .unwrap(), + "sxt".parse().unwrap(), + &accessor, + ) + .unwrap(); + let (proof, serialized_result) = + QueryProof::::new(query.proof_expr(), &accessor, &()); + let owned_table_result = proof + .verify(query.proof_expr(), &accessor, &serialized_result, &()) + .unwrap() + .table; + let expected_result = owned_table([tinyint("result", [9_i8, 10])]); + assert_eq!(owned_table_result, expected_result); +} diff --git a/crates/proof-of-sql/tests/timestamp_integration_tests.rs b/crates/proof-of-sql/tests/timestamp_integration_tests.rs index de8ffb948..e803d5507 100644 --- a/crates/proof-of-sql/tests/timestamp_integration_tests.rs +++ b/crates/proof-of-sql/tests/timestamp_integration_tests.rs @@ -29,6 +29,7 @@ fn we_can_prove_a_basic_query_containing_rfc3339_timestamp_with_dory() { accessor.add_table( "sxt.table".parse().unwrap(), owned_table([ + tinyint("tinyint", [i8::MIN, 0, i8::MAX]), smallint("smallint", [i16::MIN, 0, i16::MAX]), int("int", [i32::MIN, 0, i32::MAX]), bigint("bigint", [i64::MIN, 0, i64::MAX]), diff --git a/docs/SQLSyntaxSpecification.md b/docs/SQLSyntaxSpecification.md index 571a7bfe1..20584a0f2 100644 --- a/docs/SQLSyntaxSpecification.md +++ b/docs/SQLSyntaxSpecification.md @@ -15,6 +15,7 @@ FROM table * DataTypes - Bool / Boolean - Numeric Types + * TinyInt (8 bits) * SmallInt (16 bits) * Int / Integer (32 bits) * BigInt (64 bits)