From 54fbc9a1f9163cb9554755280eaab9672f1c9c2b Mon Sep 17 00:00:00 2001 From: Ian Alexander Joiner <14581281+iajoiner@users.noreply.github.com> Date: Fri, 16 Aug 2024 10:23:27 -0400 Subject: [PATCH] feat!: remove old postprocessing and polars entirely (#84) # Rationale for this change This PR is the last one in polars removal which is the removal itself. It should only be merged once everything else in the process get merged. # What changes are included in this PR? - remove `sql::transform` - remove Arrow <-> Polars conversion - remove dependencies `polars`, `dyn_partial_eq` and `typetag` # Are these changes tested? N/A --- Cargo.toml | 5 +- crates/proof-of-sql/Cargo.toml | 6 +- crates/proof-of-sql/src/base/database/mod.rs | 5 - .../record_batch_dataframe_conversion.rs | 311 ----------- crates/proof-of-sql/src/sql/mod.rs | 1 - .../src/sql/parse/result_expr_builder.rs | 81 --- .../src/sql/transform/composition_expr.rs | 43 -- .../sql/transform/composition_expr_test.rs | 47 -- .../src/sql/transform/data_frame_expr.rs | 9 - .../src/sql/transform/group_by_expr.rs | 118 ---- .../src/sql/transform/group_by_expr_test.rs | 379 ------------- crates/proof-of-sql/src/sql/transform/mod.rs | 62 --- .../src/sql/transform/order_by_exprs.rs | 95 ---- .../src/sql/transform/order_by_exprs_test.rs | 231 -------- .../src/sql/transform/polars_arithmetic.rs | 510 ------------------ .../src/sql/transform/polars_conversions.rs | 107 ---- .../src/sql/transform/record_batch_expr.rs | 32 -- .../src/sql/transform/result_expr.rs | 51 -- .../src/sql/transform/select_expr.rs | 74 --- .../src/sql/transform/select_expr_test.rs | 139 ----- .../src/sql/transform/slice_expr.rs | 40 -- .../src/sql/transform/slice_expr_test.rs | 121 ----- .../src/sql/transform/test_utility.rs | 84 --- .../src/sql/transform/to_polars_expr.rs | 60 --- 24 files changed, 2 insertions(+), 2609 deletions(-) delete mode 100644 crates/proof-of-sql/src/base/database/record_batch_dataframe_conversion.rs delete mode 100644 crates/proof-of-sql/src/sql/parse/result_expr_builder.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/composition_expr.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/composition_expr_test.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/data_frame_expr.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/group_by_expr.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/group_by_expr_test.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/mod.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/order_by_exprs.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/order_by_exprs_test.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/polars_arithmetic.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/polars_conversions.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/record_batch_expr.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/result_expr.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/select_expr.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/select_expr_test.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/slice_expr.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/slice_expr_test.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/test_utility.rs delete mode 100644 crates/proof-of-sql/src/sql/transform/to_polars_expr.rs diff --git a/Cargo.toml b/Cargo.toml index 2bbe30b2b..1ba1d81a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,14 +25,13 @@ bigdecimal = { version = "0.4.5", features = ["serde"] } blake3 = { version = "1.3.3" } blitzar = { version = "3.0.2" } bumpalo = { version = "3.11.0" } -bytemuck = {version = "1.14.2" } +bytemuck = {version = "1.16.3", features = ["derive"]} byte-slice-cast = { version = "1.2.1" } clap = { version = "4.5.4" } criterion = { version = "0.5.1" } chrono = { version = "0.4.38" } curve25519-dalek = { version = "4", features = ["rand_core"] } derive_more = { version = "0.99" } -dyn_partial_eq = { version = "0.1.2" } flexbuffers = { version = "2.0.0" } indexmap = { version = "2.1" } itertools = { version = "0.13.0" } @@ -43,7 +42,6 @@ num-traits = { version = "0.2" } num-bigint = { version = "0.4.4", default-features = false } opentelemetry = { version = "0.23.0" } opentelemetry-jaeger = { version = "0.20.0" } -polars = { version = "0.33.1", default-features = false, features = ["dtype-i16"] } postcard = { version = "1.0" } proof-of-sql = { path = "crates/proof-of-sql" } # We automatically update this line during release. So do not modify it! proof-of-sql-parser = { path = "crates/proof-of-sql-parser" } # We automatically update this line during release. So do not modify it! @@ -56,7 +54,6 @@ thiserror = { version = "1" } tracing = { version = "0.1.36" } tracing-opentelemetry = { version = "0.22.0" } tracing-subscriber = { version = "0.3.0" } -typetag = { version = "0.2.13" } wasm-bindgen = { version = "0.2.92" } zerocopy = { version = "0.7.34" } diff --git a/crates/proof-of-sql/Cargo.toml b/crates/proof-of-sql/Cargo.toml index 65ee86241..e6dc8b8c9 100644 --- a/crates/proof-of-sql/Cargo.toml +++ b/crates/proof-of-sql/Cargo.toml @@ -32,14 +32,12 @@ byte-slice-cast = { workspace = true } curve25519-dalek = { workspace = true, features = ["serde"] } chrono = {workspace = true, features = ["serde"]} derive_more = { workspace = true } -dyn_partial_eq = { workspace = true } -indexmap = { workspace = true } +indexmap = { workspace = true, features = ["serde"] } itertools = { workspace = true } lazy_static = { workspace = true } merlin = { workspace = true } num-traits = { workspace = true } num-bigint = { workspace = true, default-features = false } -polars = { workspace = true, features = ["lazy", "bigidx", "dtype-decimal", "serde-lazy"] } postcard = { workspace = true, features = ["alloc"] } proof-of-sql-parser = { workspace = true } rand = { workspace = true, optional = true } @@ -48,7 +46,6 @@ serde = { workspace = true, features = ["serde_derive"] } serde_json = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true, features = ["attributes"] } -typetag = { workspace = true } zerocopy = { workspace = true } [dev_dependencies] @@ -58,7 +55,6 @@ clap = { workspace = true, features = ["derive"] } criterion = { workspace = true, features = ["html_reports"] } opentelemetry = { workspace = true } opentelemetry-jaeger = { workspace = true } -polars = { workspace = true, features = ["lazy"] } rand = { workspace = true } rand_core = { workspace = true } serde_json = { workspace = true } diff --git a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs index 9d854795b..6db25c2b4 100644 --- a/crates/proof-of-sql/src/base/database/mod.rs +++ b/crates/proof-of-sql/src/base/database/mod.rs @@ -24,11 +24,6 @@ pub use table_ref::TableRef; mod arrow_array_to_column_conversion; pub use arrow_array_to_column_conversion::{ArrayRefExt, ArrowArrayToColumnConversionError}; -mod record_batch_dataframe_conversion; -pub(crate) use record_batch_dataframe_conversion::{ - dataframe_to_record_batch, record_batch_to_dataframe, -}; - mod record_batch_utility; pub use record_batch_utility::ToArrow; diff --git a/crates/proof-of-sql/src/base/database/record_batch_dataframe_conversion.rs b/crates/proof-of-sql/src/base/database/record_batch_dataframe_conversion.rs deleted file mode 100644 index 4c49a148d..000000000 --- a/crates/proof-of-sql/src/base/database/record_batch_dataframe_conversion.rs +++ /dev/null @@ -1,311 +0,0 @@ -use arrow::{ - array::{ - Array, BooleanArray, Decimal128Array, Int16Array, Int32Array, Int64Array, StringArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - }, - datatypes::{DataType, Field, Schema, TimeUnit as ArrowTimeUnit}, - record_batch::RecordBatch, -}; -use polars::{ - frame::DataFrame, - prelude::{ChunkedArray, NamedFrom}, - series::{IntoSeries, Series}, -}; -use std::sync::Arc; - -/// Convert a RecordBatch to a polars DataFrame -/// Note: this explicitly does not check that Decimal128(38,0) values are 38 digits. -pub fn record_batch_to_dataframe(record_batch: RecordBatch) -> Option { - let series: Option> = record_batch - .schema() - .fields() - .iter() - .zip(record_batch.columns().iter()) - .map(|(f, col)| { - Some(match f.data_type() { - arrow::datatypes::DataType::Boolean => { - let data = col - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect::>>() - .unwrap(); - - Series::new(f.name(), data) - } - arrow::datatypes::DataType::Int16 => { - let data = col - .as_any() - .downcast_ref::() - .map(|array| array.values()) - .unwrap(); - - Series::new(f.name(), data) - } - arrow::datatypes::DataType::Int32 => { - let data = col - .as_any() - .downcast_ref::() - .map(|array| array.values()) - .unwrap(); - - Series::new(f.name(), data) - } - arrow::datatypes::DataType::Int64 => { - let data = col - .as_any() - .downcast_ref::() - .map(|array| array.values()) - .unwrap(); - - Series::new(f.name(), data) - } - - arrow::datatypes::DataType::Utf8 => { - let data = col - .as_any() - .downcast_ref::() - .map(|array| (0..array.len()).map(|i| array.value(i)).collect::>()) - .unwrap(); - - Series::new(f.name(), data) - } - arrow::datatypes::DataType::Decimal128(38, 0) => { - let data = col - .as_any() - .downcast_ref::() - .map(|array| array.values()) - .unwrap(); - - ChunkedArray::from_vec(f.name(), data.to_vec()) - .into_decimal_unchecked(Some(38), 0) - // Note: we make this unchecked because if record batch has values that overflow 38 digits, so should the data frame. - .into_series() - } - arrow::datatypes::DataType::Timestamp(time_unit, _timezone_option) => { - match time_unit { - arrow::datatypes::TimeUnit::Second => { - let data = col - .as_any() - .downcast_ref::() - .map(|array| array.values()) - .unwrap(); - Series::new(f.name(), data) - } - arrow::datatypes::TimeUnit::Millisecond => { - let data = col - .as_any() - .downcast_ref::() - .map(|array| array.values()) - .unwrap(); - Series::new(f.name(), data) - } - arrow::datatypes::TimeUnit::Microsecond => { - let data = col - .as_any() - .downcast_ref::() - .map(|array| array.values()) - .unwrap(); - Series::new(f.name(), data) - } - arrow::datatypes::TimeUnit::Nanosecond => { - let data = col - .as_any() - .downcast_ref::() - .map(|array| array.values()) - .unwrap(); - Series::new(f.name(), data) - } - } - } - _ => None?, - }) - }) - .collect(); - - Some(DataFrame::new(series?).unwrap()) -} - -/// Convert a polars DataFrame to a RecordBatch -/// Note: this does not check that Decimal128(38,0) values are 38 digits. -pub fn dataframe_to_record_batch(data: DataFrame) -> Option { - assert!(!data.is_empty()); - - let mut column_fields: Vec<_> = Vec::with_capacity(data.width()); - let mut columns: Vec> = Vec::with_capacity(data.width()); - - for (field, series) in data.fields().iter().zip(data.get_columns().iter()) { - let dt = match field.data_type() { - polars::datatypes::DataType::Boolean => { - let col = series - .bool() - .unwrap() - .into_iter() - .collect::>>() - .unwrap(); - - columns.push(Arc::new(BooleanArray::from(col))); - - DataType::Boolean - } - polars::datatypes::DataType::Int16 => { - let col = series.i16().unwrap().cont_slice().unwrap(); - - columns.push(Arc::new(Int16Array::from(col.to_vec()))); - - DataType::Int16 - } - polars::datatypes::DataType::Int32 => { - let col = series.i32().unwrap().cont_slice().unwrap(); - - columns.push(Arc::new(Int32Array::from(col.to_vec()))); - - DataType::Int32 - } - polars::datatypes::DataType::Int64 => { - let col = series.i64().unwrap().cont_slice().unwrap(); - - columns.push(Arc::new(Int64Array::from(col.to_vec()))); - - DataType::Int64 - } - // This code handles a specific case where a Polars DataFrame has an unsigned 64-bit integer (u64) data type, - // which only occurs when using the `count` function for aggregation. - polars::datatypes::DataType::UInt64 => { - // Retrieve the column as a contiguous slice of u64 values. - let col = series.u64().unwrap().cont_slice().unwrap(); - - // Cast the column to a supported i64 data type. - // Note that this operation should never overflow - // unless the database has around 2^64 rows, which is unfeasible. - let col = col.iter().map(|v| *v as i64).collect::>(); - - columns.push(Arc::new(Int64Array::from(col))); - - DataType::Int64 - } - polars::datatypes::DataType::Utf8 => { - let col: Vec<_> = series - .utf8() - .unwrap() - .into_iter() - .map(|opt_v| opt_v.unwrap()) - .collect(); - - columns.push(Arc::new(StringArray::from(col))); - - DataType::Utf8 - } - polars::datatypes::DataType::Decimal(Some(38), Some(0)) => { - let col = series.decimal().unwrap().cont_slice().unwrap(); - - columns.push(Arc::new( - Decimal128Array::from(col.to_vec()) - .with_precision_and_scale(38, 0) - .unwrap(), - )); - - DataType::Decimal128(38, 0) - } - // NOTE: Polars does not support seconds - polars::datatypes::DataType::Datetime(timeunit, timezone) => { - let col = series.i64().unwrap().cont_slice().unwrap(); - let timezone_arc = timezone.as_ref().map(|tz| Arc::from(tz.as_str())); - let arrow_array: Arc = match timeunit { - polars::datatypes::TimeUnit::Nanoseconds => { - Arc::new(TimestampNanosecondArray::with_timezone_opt( - col.to_vec().into(), - timezone_arc, - )) - } - polars::datatypes::TimeUnit::Microseconds => { - Arc::new(TimestampMicrosecondArray::with_timezone_opt( - col.to_vec().into(), - timezone_arc, - )) - } - polars::datatypes::TimeUnit::Milliseconds => { - Arc::new(TimestampMillisecondArray::with_timezone_opt( - col.to_vec().into(), - timezone_arc, - )) - } - }; - columns.push(arrow_array); - DataType::Timestamp( - match timeunit { - polars::datatypes::TimeUnit::Nanoseconds => ArrowTimeUnit::Nanosecond, - polars::datatypes::TimeUnit::Microseconds => ArrowTimeUnit::Microsecond, - polars::datatypes::TimeUnit::Milliseconds => ArrowTimeUnit::Millisecond, - }, - None, - ) - } - _ => return None, - }; - - column_fields.push(Field::new(field.name().as_str(), dt, false)); - } - - let schema = Arc::new(Schema::new(column_fields)); - - RecordBatch::try_new(schema, columns).ok() -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::record_batch; - - #[test] - fn we_can_convert_record_batches_to_dataframes() { - let recordbatch = record_batch!( - "boolean" => [true, false, true, false], - "bigint" => [3214_i64, 34, 999, 888], - "varchar" => ["a", "fg", "zzz", "yyy"], - "int128" => [123_i128, 1010, i128::MAX, i128::MIN + 1] - ); - //Note: to_string() can't handle i128:MIN within a dataframe. - let mut dataframe = polars::df!( - "boolean" => [true, false, true, false], - "bigint" => [3214_i64, 34_i64, 999, 888], - "varchar" => ["a", "fg", "zzz", "yyy"] - ) - .unwrap(); - dataframe - .with_column( - ChunkedArray::from_vec("int128", vec![123_i128, 1010, i128::MAX, i128::MIN + 1]) - .into_decimal_unchecked(Some(38), 0) - .into_series(), - ) - .unwrap(); - let df = record_batch_to_dataframe(recordbatch).unwrap(); - assert_eq!(dataframe.to_string(), df.to_string()); - } - - #[test] - fn we_can_convert_dataframes_to_record_batches() { - let recordbatch = record_batch!( - "boolean" => [true, false, true, false], - "bigint" => [3214_i64, 34, 999, 888], - "varchar" => ["a", "fg", "zzz", "yyy"], - "int128" => [123_i128, 1010, i128::MAX, i128::MIN] - ); - let mut dataframe = polars::df!( - "boolean" => [true, false, true, false], - "bigint" => [3214_i64, 34_i64, 999, 888], - "varchar" => ["a", "fg", "zzz", "yyy"] - ) - .unwrap(); - dataframe - .with_column( - ChunkedArray::from_vec("int128", vec![123_i128, 1010, i128::MAX, i128::MIN]) - .into_decimal_unchecked(Some(38), 0) - .into_series(), - ) - .unwrap(); - assert_eq!(dataframe_to_record_batch(dataframe).unwrap(), recordbatch); - } -} diff --git a/crates/proof-of-sql/src/sql/mod.rs b/crates/proof-of-sql/src/sql/mod.rs index f0ab23b0f..2f8f13453 100644 --- a/crates/proof-of-sql/src/sql/mod.rs +++ b/crates/proof-of-sql/src/sql/mod.rs @@ -3,4 +3,3 @@ pub mod ast; pub mod parse; pub mod postprocessing; pub mod proof; -pub mod transform; diff --git a/crates/proof-of-sql/src/sql/parse/result_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/result_expr_builder.rs deleted file mode 100644 index 0dda9bb26..000000000 --- a/crates/proof-of-sql/src/sql/parse/result_expr_builder.rs +++ /dev/null @@ -1,81 +0,0 @@ -use crate::sql::transform::{CompositionExpr, GroupByExpr, OrderByExprs, SelectExpr, SliceExpr}; -use proof_of_sql_parser::{ - intermediate_ast::{AliasedResultExpr, Expression, OrderBy, Slice}, - Identifier, -}; - -/// A builder for `ResultExpr` nodes. -#[derive(Default)] -pub struct ResultExprBuilder { - composition: CompositionExpr, -} - -impl ResultExprBuilder { - /// Chain a new `GroupByExpr` to the current `ResultExpr`. - pub fn add_group_by_exprs( - mut self, - by_exprs: &[Identifier], - aliased_exprs: &[AliasedResultExpr], - ) -> Self { - if by_exprs.is_empty() { - return self; - } - self.composition - .add(Box::new(GroupByExpr::new(by_exprs, aliased_exprs))); - self - } - - /// Chain a new `SelectExpr` to the current `ResultExpr`. - pub fn add_select_exprs(mut self, aliased_exprs: &[AliasedResultExpr]) -> Self { - assert!(!aliased_exprs.is_empty()); - if !self.composition.is_empty() { - // The only transformation before a select is a group by. - // GROUP BY modifies the schema, so we need to - // update the code to reflect the changes. - let exprs: Vec<_> = aliased_exprs - .iter() - .map(|aliased_expr| Expression::Column(aliased_expr.alias)) - .collect(); - self.composition - .add(Box::new(SelectExpr::new_from_expressions(&exprs))); - } else { - self.composition - .add(Box::new(SelectExpr::new_from_aliased_result_exprs( - aliased_exprs, - ))); - } - self - } - - /// Chain a new `OrderByExprs` to the current `ResultExpr`. - pub fn add_order_by_exprs(mut self, by_exprs: Vec) -> Self { - if !by_exprs.is_empty() { - self.composition.add(Box::new(OrderByExprs::new(by_exprs))); - } - self - } - - /// Chain a new `SliceExpr` to the current `ResultExpr`. - pub fn add_slice_expr(mut self, slice: &Option) -> Self { - let (number_rows, offset_value) = match slice { - Some(Slice { - number_rows, - offset_value, - }) => (*number_rows, *offset_value), - None => (u64::MAX, 0), - }; - - // we don't need to add a slice transformation if - // we are not limiting or shifting the number of rows - if number_rows != u64::MAX || offset_value != 0 { - self.composition - .add(Box::new(SliceExpr::new(number_rows, offset_value))); - } - self - } - - /// Build a `ResultExpr` from the current state of the builder. - pub fn build(self) -> crate::sql::transform::ResultExpr { - crate::sql::transform::ResultExpr::new(Box::new(self.composition)) - } -} diff --git a/crates/proof-of-sql/src/sql/transform/composition_expr.rs b/crates/proof-of-sql/src/sql/transform/composition_expr.rs deleted file mode 100644 index 6474dc75e..000000000 --- a/crates/proof-of-sql/src/sql/transform/composition_expr.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::sql::transform::RecordBatchExpr; -use arrow::record_batch::RecordBatch; -use dyn_partial_eq::DynPartialEq; -use serde::{Deserialize, Serialize}; - -/// A node representing a list of transformations to be applied to a `LazyFrame`. -#[derive(Debug, Default, DynPartialEq, PartialEq, Serialize, Deserialize)] -pub struct CompositionExpr { - transformations: Vec>, -} - -impl CompositionExpr { - /// Create a new `CompositionExpr` node. - pub fn new(transformation: Box) -> Self { - Self { - transformations: vec![transformation], - } - } - - /// Verify if the `CompositionExpr` node is empty. - pub fn is_empty(&self) -> bool { - self.transformations.is_empty() - } - - /// Append a new transformation to the end of the current `CompositionExpr` node. - pub fn add(&mut self, transformation: Box) { - self.transformations.push(transformation); - } -} - -#[typetag::serde] -impl RecordBatchExpr for CompositionExpr { - /// Apply the transformations to the `RecordBatch`. - fn apply_transformation(&self, record_batch: RecordBatch) -> Option { - let mut record_batch = record_batch; - - for transformation in self.transformations.iter() { - record_batch = transformation.apply_transformation(record_batch)?; - } - - Some(record_batch) - } -} diff --git a/crates/proof-of-sql/src/sql/transform/composition_expr_test.rs b/crates/proof-of-sql/src/sql/transform/composition_expr_test.rs deleted file mode 100644 index c17fb0e35..000000000 --- a/crates/proof-of-sql/src/sql/transform/composition_expr_test.rs +++ /dev/null @@ -1,47 +0,0 @@ -use crate::{ - record_batch, - sql::transform::{ - test_utility::{composite_result, orders, slice}, - CompositionExpr, - }, -}; -use proof_of_sql_parser::intermediate_ast::OrderByDirection::Desc; - -#[test] -fn we_can_chain_expressions() { - let limit = 2; - let offset = 1; - let data = record_batch!("c" => [-5_i64, 1, -56, 2], "a" => ["d", "a", "f", "b"]); - let mut composition = CompositionExpr::new(orders(&["c"], &[Desc])); - composition.add(slice(limit, offset)); - - let result_expr = composite_result(vec![Box::new(composition)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("c" => [1_i64, -5], "a" => ["a", "d"]); - assert_eq!(data, expected_data); -} - -#[test] -fn the_order_that_we_chain_expressions_is_relevant() { - let limit = 2; - let offset = 1; - let data = record_batch!("c" => [-5_i64, 1, -56, 2], "a" => ["d", "a", "f", "b"]); - - let mut composition1 = CompositionExpr::new(orders(&["c"], &[Desc])); - composition1.add(slice(limit, offset)); - let result_expr1 = composite_result(vec![Box::new(composition1)]); - let data1 = result_expr1.transform_results(data.clone()).unwrap(); - - let mut composition2 = CompositionExpr::new(slice(limit, offset)); - composition2.add(orders(&["c"], &[Desc])); - let result_expr2 = composite_result(vec![Box::new(composition2)]); - let data2 = result_expr2.transform_results(data).unwrap(); - - assert_ne!(data1, data2); - - let expected_data1 = record_batch!("c" => [1_i64, -5], "a" => ["a", "d"]); - assert_eq!(data1, expected_data1); - - let expected_data2 = record_batch!("c" => [1_i64, -56], "a" => ["a", "f"]); - assert_eq!(data2, expected_data2); -} diff --git a/crates/proof-of-sql/src/sql/transform/data_frame_expr.rs b/crates/proof-of-sql/src/sql/transform/data_frame_expr.rs deleted file mode 100644 index 25b0140c8..000000000 --- a/crates/proof-of-sql/src/sql/transform/data_frame_expr.rs +++ /dev/null @@ -1,9 +0,0 @@ -use polars::prelude::LazyFrame; -use std::fmt::Debug; - -/// A trait for nodes that can apply transformations to a `LazyFrame`. -#[deprecated = "Use `RecordBatchExpr` instead"] -pub trait DataFrameExpr: Debug + Send + Sync { - /// Apply the transformation to the `LazyFrame` and return the result. - fn lazy_transformation(&self, lazy_frame: LazyFrame, num_input_rows: usize) -> LazyFrame; -} diff --git a/crates/proof-of-sql/src/sql/transform/group_by_expr.rs b/crates/proof-of-sql/src/sql/transform/group_by_expr.rs deleted file mode 100644 index 56cc7d3e7..000000000 --- a/crates/proof-of-sql/src/sql/transform/group_by_expr.rs +++ /dev/null @@ -1,118 +0,0 @@ -#[allow(deprecated)] -use super::DataFrameExpr; -use super::{ToPolarsExpr, INT128_PRECISION, INT128_SCALE}; -use dyn_partial_eq::DynPartialEq; -use polars::prelude::{col, DataType, Expr, GetOutput, LazyFrame, NamedFrom, Series}; -use proof_of_sql_parser::{intermediate_ast::AliasedResultExpr, Identifier}; -use serde::{Deserialize, Serialize}; - -/// A group by expression -#[derive(Debug, DynPartialEq, PartialEq, Serialize, Deserialize)] -pub struct GroupByExpr { - /// A list of aggregation column expressions - agg_exprs: Vec, - - /// A list of group by column expressions - by_exprs: Vec, -} - -impl GroupByExpr { - /// Create a new group by expression containing the group by and aggregation expressions - pub fn new(by_ids: &[Identifier], aliased_exprs: &[AliasedResultExpr]) -> Self { - let by_exprs = Vec::from_iter(by_ids.iter().map(|id| col(id.as_str()))); - let agg_exprs = Vec::from_iter(aliased_exprs.iter().map(ToPolarsExpr::to_polars_expr)); - assert!(!agg_exprs.is_empty(), "Agg expressions must not be empty"); - assert!( - !by_exprs.is_empty(), - "Group by expressions must not be empty" - ); - - Self { - by_exprs, - agg_exprs, - } - } -} - -super::impl_record_batch_expr_for_data_frame_expr!(GroupByExpr); -#[allow(deprecated)] -impl DataFrameExpr for GroupByExpr { - fn lazy_transformation(&self, lazy_frame: LazyFrame, num_input_rows: usize) -> LazyFrame { - // TODO: polars currently lacks support for min/max aggregation in data frames - // with either zero or one element when a group by operation is applied. - // We remove the group by clause to temporarily work around this limitation. - // Issue created to track progress: https://github.com/pola-rs/polars/issues/11232 - if num_input_rows == 0 { - return lazy_frame.select(&self.agg_exprs).limit(0); - } - - if num_input_rows == 1 { - return lazy_frame.select(&self.agg_exprs); - } - - // Add invalid column aliases to group by expressions so that we can - // exclude them from the final result. - let by_expr_aliases = (0..self.by_exprs.len()) - .map(|pos| "#$".to_owned() + pos.to_string().as_str()) - .collect::>(); - - let by_exprs: Vec<_> = self - .by_exprs - .clone() - .into_iter() - .zip(by_expr_aliases.iter()) - .map(|(expr, alias)| expr.alias(alias.as_str())) - // TODO: remove this mapping once Polars supports decimal columns inside group by - // Issue created to track progress: https://github.com/pola-rs/polars/issues/11078 - .map(group_by_map_to_utf8_if_decimal) - .collect(); - - // We use `groupby_stable` instead of `groupby` - // to avoid non-deterministic results with our tests. - lazy_frame - .group_by_stable(&by_exprs) - .agg(&self.agg_exprs) - .select(&[col("*").exclude(by_expr_aliases)]) - } -} - -pub(crate) fn group_by_map_i128_to_utf8(v: i128) -> String { - // use big end to allow - // skipping leading zeros - v.to_be_bytes() - .into_iter() - // skip leading zeros - .skip_while(|x| *x == 0) - // in the worst case scenario, - // 16 bytes per decimal - // is mapped to 32 bytes per char - // this is not ideal. - // but keeping as it is for now - // since this must be a temporary solution. - .map(char::from) - // Using `Binary` type would consume less space - // But it would be an issue when we try to convert - // the result data frame into a record batch - // since we don't support this data type. - .collect::() -} - -// Polars doesn't support Decimal columns inside group by. -// So we need to remap them to the supported UTF8 type. -fn group_by_map_to_utf8_if_decimal(expr: Expr) -> Expr { - expr.map( - |series| match series.dtype().clone() { - DataType::Decimal(Some(INT128_PRECISION), Some(INT128_SCALE)) => { - let utf8_data: Vec<_> = series - .decimal() - .unwrap() - .into_no_null_iter() - .map(group_by_map_i128_to_utf8) - .collect(); - Ok(Some(Series::new(series.name(), &utf8_data))) - } - _ => Ok(Some(series)), - }, - GetOutput::from_type(DataType::Utf8), - ) -} diff --git a/crates/proof-of-sql/src/sql/transform/group_by_expr_test.rs b/crates/proof-of-sql/src/sql/transform/group_by_expr_test.rs deleted file mode 100644 index 8a3d7922c..000000000 --- a/crates/proof-of-sql/src/sql/transform/group_by_expr_test.rs +++ /dev/null @@ -1,379 +0,0 @@ -use super::group_by_map_i128_to_utf8; -use crate::{ - record_batch, - sql::transform::test_utility::{col, composite_result, groupby, lit}, -}; -use arrow::record_batch::RecordBatch; -use rand::Rng; - -#[test] -fn we_can_transform_batch_using_group_by_with_a_varchar_column() { - let data = record_batch!("a" => ["a", "d", "a", "b"], "b" => [1_i64, -5, 1, 2], "c" => [-1_i128, 0, -1, 3]); - let by_exprs = vec![col("a")]; - let agg_exprs = vec![ - col("a").first().alias("a"), - col("b").first().alias("b"), - col("c").first().alias("c"), - ]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = - record_batch!("a" => ["a", "d", "b"], "b" => [1_i64, -5, 2],"c" => [-1_i128, 0, 3]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_group_by_with_a_i64_column() { - let data = record_batch!("a" => ["a", "d", "a", "b"], "b" => [1_i64, -5, 1, 2], "c" => [-1_i128, 0, -1, 3]); - let by_exprs = vec![col("b")]; - let agg_exprs = vec![ - col("a").first().alias("a"), - col("b").first().alias("b"), - col("c").first().alias("c"), - ]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = - record_batch!("a" => ["a", "d", "b"], "b" => [1_i64, -5, 2],"c" => [-1_i128, 0, 3]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_group_by_with_a_i128_column() { - let data = record_batch!("a" => ["a", "d", "a", "b"], "b" => [1_i64, -5, 1, 2], "c" => [-1_i128, 0, -1, 3]); - let by_exprs = vec![col("c")]; - let agg_exprs = vec![ - col("a").first().alias("a"), - col("b").first().alias("b"), - col("c").first().alias("c"), - ]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = - record_batch!("a" => ["a", "d", "b"], "b" => [1_i64, -5, 2],"c" => [-1_i128, 0, 3]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_the_same_group_bys_with_the_same_alias() { - let data = record_batch!("c" => [1_i64, -5, 7, 7, 2], "a" => ["a", "d", "a", "a", "b"]); - let by_exprs = vec![col("a"), col("a")]; - let result_expr = composite_result(vec![groupby(by_exprs, vec![col("c").sum().alias("c")])]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("c" => [15_i64, -5, 2]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_different_group_bys_with_different_aliases() { - let data = record_batch!("c" => [1_i64, -5, 7, 7, 2], "a" => ["a", "d", "a", "a", "b"]); - let by_exprs = vec![col("a"), col("c")]; - let result_expr = composite_result(vec![groupby( - by_exprs, - vec![col("a").first().alias("a"), col("c").first().alias("c")], - )]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("a" => ["a", "d", "a", "b"], "c" => [1_i64, -5, 7, 2]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_simple_group_by_with_max_aggregation() { - let data = record_batch!("b" => [1_i64, -5, -3, 7, 2], "c" => [1_i128, -5, -3, 1, -3], "a" => ["a", "d", "b", "a", "b"]); - let by_exprs = vec![col("a"), col("c")]; - let agg_exprs = vec![(col("b") + col("c")).max().alias("bc")]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("bc" => [8_i128, -10, -1]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_simple_group_by_with_min_aggregation() { - let data = record_batch!("b" => [1_i64, -5, -3, 7, 2], "c" => [1_i128, -5, -3, 1, -3], "a" => ["a", "d", "b", "a", "b"]); - let by_exprs = vec![col("a"), col("c")]; - let agg_exprs = vec![(col("b") * col("c")).min().alias("bc")]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("bc" => [1_i128, 25, -6]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_simple_group_by_with_sum_aggregation() { - let data = record_batch!("b" => [1_i64, -5, -3, 7, 2], "c" => [1_i128, -5, -3, 1, -3], "a" => ["a", "d", "b", "a", "b"]); - let by_exprs = vec![col("a"), col("c")]; - let agg_exprs = vec![(col("b") - col("c")).sum().alias("bc")]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("bc" => [6_i128, 0, 5]); - assert_eq!(data, expected_data); -} - -#[test] -#[should_panic] -fn sum_aggregation_can_overflow() { - let data = record_batch!("c" => [i64::MAX, 1], "a" => ["a", "a"]); - let by_exprs = vec![col("a")]; - let agg_exprs = vec![col("c").sum().alias("c")]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - result_expr.transform_results(data).unwrap(); -} - -#[test] -fn we_can_transform_batch_using_simple_group_by_with_count_aggregation() { - let data = record_batch!("b" => [1_i64, -5, -3, 7, 2], "c" => [1_i128, -5, -3, 1, -3], "a" => ["a", "d", "b", "a", "b"]); - let by_exprs = vec![col("a"), col("c")]; - let agg_exprs = vec![ - col("a").first().alias("a"), - (lit(-53) * col("b") - lit(45) * col("c") + lit(103)) - .count() - .alias("bc"), - ]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("a" => ["a", "d", "b"], "bc" => [2_i64, 1, 2]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_simple_group_by_with_first_aggregation() { - let data = record_batch!("a" => ["a", "d", "b", "a", "b"]); - let by_exprs = vec![col("a")]; - let agg_exprs = vec![ - col("a").first().alias("a_col"), - col("a").first().alias("a2_col"), - ]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("a_col" => ["a", "d", "b"], "a2_col" => ["a", "d", "b"]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_group_by_with_the_same_name_as_the_aggregation_expression() { - let data = - record_batch!("c" => [1_i64, -5, -3, 7, 2, 1], "a" => ["a", "d", "b", "a", "b", "f"]); - let by_exprs = vec![col("c")]; - let agg_exprs = vec![col("c").min().alias("c")]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("c" => [1_i64, -5, -3, 7, 2]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_min_aggregation_with_non_numeric_columns() { - let data = - record_batch!("c" => [1_i64, -5, -3, 7, 2, 1], "a" => ["abd", "d", "b", "a", "b", "abc"]); - let by_exprs = vec![col("c")]; - let agg_exprs = vec![col("c").first().alias("c"), col("a").min().alias("a_min")]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = - record_batch!("c" => [1_i64, -5, -3, 7, 2], "a_min" => ["abc", "d", "b", "a", "b"]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_max_aggregation_with_non_numeric_columns() { - let data = - record_batch!("c" => [1_i64, -5, -3, 7, -5, 1], "a" => ["abd", "a", "b", "a", "aa", "abc"]); - let by_exprs = vec![col("c")]; - let agg_exprs = vec![col("c").first().alias("c"), col("a").max().alias("a_max")]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = - record_batch!("c" => [1_i64, -5, -3, 7], "a_max" => ["abd", "aa", "b", "a"]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_count_aggregation_with_non_numeric_columns() { - let data = - record_batch!("c" => [1_i64, -5, -3, 7, 2, 1], "a" => ["a", "d", "b", "a", "b", "f"]); - let by_exprs = vec![col("c")]; - let agg_exprs = vec![col("a").count().alias("a_count")]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("a_count" => [2_i64, 1, 1, 1, 1]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_simple_group_by_with_multiple_aggregations() { - let data = record_batch!("c" => [1_i128, -5, -3, 7, 2], "a" => ["a", "d", "b", "a", "b"]); - let by_exprs = vec![col("a")]; - let agg_exprs = vec![ - col("c").max().alias("c_max"), - col("a").first().alias("a"), - col("c").min().alias("c_min"), - ]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("c_max" => [7_i128, -5, 2], "a" => ["a", "d", "b"], "c_min" => [1_i128, -5, -3]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_multiple_group_bys_with_multiple_aggregations() { - let data = record_batch!("c" => [1_i64, -5, -3, 7, -3], "a" => ["a", "d", "b", "a", "b"], "d" => [523_i64, -25, 343, -7, 435]); - let by_exprs = vec![col("a"), col("c")]; - let agg_exprs = vec![ - col("a").first().alias("a_group"), - col("d").max().alias("d_max"), - col("c").count().alias("c_count"), - ]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("a_group" => ["a", "d", "b", "a"], "d_max" => [523_i64, -25, 435, -7], "c_count" => [1_i64, 1, 2, 1]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_different_aliases_associated_with_the_same_group_by_column() { - let data = record_batch!("a" => ["a", "b"], "d" => [523_i64, -25]); - let by_exprs = vec![col("a")]; - let result_expr = composite_result(vec![groupby( - by_exprs, - vec![col("a").first().alias("a1"), col("a").first().alias("a2")], - )]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("a1" => ["a", "b"], "a2" => ["a", "b"]); - assert_eq!(data, expected_data); -} - -#[test] -#[should_panic] -fn we_cannot_transform_batch_using_an_empty_group_by_expression() { - let agg_exprs = vec![col("b").max().alias("b")]; - composite_result(vec![groupby(vec![], agg_exprs)]); -} - -#[test] -#[should_panic] -fn we_cannot_transform_batch_using_an_empty_agg_expression() { - let group_bys = vec![col("b")]; - composite_result(vec![groupby(group_bys, vec![])]); -} - -#[test] -fn we_can_transform_batch_using_arithmetic_expressions_in_the_aggregation() { - let data = record_batch!( - "c" => [1_i64, -5, -3, 7, -3], - "a" => ["a", "d", "b", "a", "b"], - "d" => [523_i64, -25, 343, -7, 435] - ); - let by_exprs = vec![col("a")]; - let agg_exprs = vec![(col("d") * col("c")).sum().alias("cd_sum")]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("cd_sum" => [474_i64, 125, -2334]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_batch_using_arithmetic_outside_the_aggregation_exprs() { - let data = record_batch!( - "c" => [1_i128, -5, -3, -5, 7, -3], - "d" => [-1_i64, -5, 0, -5, 7, 7] - ); - let by_exprs = vec![col("d"), col("c")]; - let agg_exprs = vec![ - (col("c").first() + col("d").first()).alias("sum_cd1"), - (col("c") + col("d")).sum().alias("sum_cd2"), - ]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!( - "sum_cd1" => [0_i128, -10, -3, 14, 4], - "sum_cd2" => [0_i128, -20, -3, 14, 4], - ); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_use_decimal_columns_inside_group_by() { - let nines: i128 = "9".repeat(38).parse::().unwrap(); - let data = record_batch!( - "h" => [-1_i128, 1, -nines, 0, -2, nines, -3, -1, -3, 1, 11], - "j" => [0_i64, 12, 5, 3, -2, -1, 4, 4, 100, 0, 31], - ); - let by_exprs = vec![col("h")]; - let agg_exprs = vec![(col("j") + col("h")).sum().alias("h_sum")]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!( - "h_sum" => [2_i128, 14, -nines + 5, 3, -2 - 2, nines - 1, -6 + 100 + 4, 11 + 31], - ); - assert_eq!(data, expected_data); -} - -#[test] -fn transforming_a_batch_of_size_zero_with_min_max_agg_and_decimal_column_is_fine() { - let data = record_batch!("h" => [-1_i128], "i" => [2_i128], "j" => [2_i128], "k" => [2_i64]); - let empty_batch = RecordBatch::new_empty(data.schema().clone()); - let by_exprs = vec![col("h")]; - let agg_exprs = vec![ - col("h").max().alias("h"), - col("i").min().alias("i"), - col("j").sum().alias("j"), - col("k").count().alias("k"), - ]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(empty_batch.clone()).unwrap(); - let expected_data = empty_batch; - assert_eq!(data, expected_data); -} - -#[test] -fn transforming_a_batch_of_size_one_with_min_max_agg_and_decimal_column_is_fine() { - let input_data = - record_batch!("h" => [-1_i128], "i" => [2_i128], "j" => [2_i128], "k" => [2_i128]); - let by_exprs = vec![col("h")]; - let agg_exprs = vec![ - col("h").max().alias("h"), - col("i").min().alias("i"), - col("j").sum().alias("j"), - col("k").count().alias("k"), - ]; - let result_expr = composite_result(vec![groupby(by_exprs, agg_exprs)]); - let data = result_expr.transform_results(input_data.clone()).unwrap(); - let expected_data = - record_batch!("h" => [-1_i128], "i" => [2_i128], "j" => [2_i128], "k" => [1_i64]); - assert_eq!(data, expected_data); -} - -fn validate_group_by_map_i128_to_utf8(s: Vec) { - let expected_len = s.len(); - - // no collision happens - assert_eq!( - expected_len, - s.iter().collect::>().len() - ); - - assert_eq!( - expected_len, - s.into_iter() - .map(group_by_map_i128_to_utf8) - .collect::>() - .len(), - ); -} - -#[test] -fn group_by_with_consecutive_range_doesnt_have_collisions() { - validate_group_by_map_i128_to_utf8((-300000..300000).collect()); -} - -#[test] -fn group_by_with_random_data_doesnt_have_collisions() { - let mut rng = rand::thread_rng(); - let nines = "9".repeat(38).parse::().unwrap(); - validate_group_by_map_i128_to_utf8( - (-300000..300000) - .map(|_| rng.gen_range(-nines..nines + 1)) - .collect(), - ); -} diff --git a/crates/proof-of-sql/src/sql/transform/mod.rs b/crates/proof-of-sql/src/sql/transform/mod.rs deleted file mode 100644 index 6b832da8b..000000000 --- a/crates/proof-of-sql/src/sql/transform/mod.rs +++ /dev/null @@ -1,62 +0,0 @@ -//! This module contains postprocessing for non-provable components. -/// The precision for [ColumnType::INT128] values -pub const INT128_PRECISION: usize = 38; - -/// The scale for [ColumnType::INT128] values -pub const INT128_SCALE: usize = 0; - -mod result_expr; -pub use result_expr::ResultExpr; - -#[cfg(test)] -pub mod test_utility; - -mod composition_expr; -pub use composition_expr::CompositionExpr; - -#[cfg(test)] -pub mod composition_expr_test; - -mod data_frame_expr; -#[allow(deprecated)] -pub(crate) use data_frame_expr::DataFrameExpr; -mod record_batch_expr; -pub(crate) use record_batch_expr::impl_record_batch_expr_for_data_frame_expr; -pub use record_batch_expr::RecordBatchExpr; - -mod order_by_exprs; -pub use order_by_exprs::OrderByExprs; - -#[cfg(test)] -mod order_by_exprs_test; - -#[cfg(test)] -pub(crate) use order_by_exprs::order_by_map_i128_to_utf8; - -mod slice_expr; -pub use slice_expr::SliceExpr; - -#[cfg(test)] -mod slice_expr_test; - -mod select_expr; -pub use select_expr::SelectExpr; - -#[cfg(test)] -mod select_expr_test; - -mod group_by_expr; -#[cfg(test)] -pub(crate) use group_by_expr::group_by_map_i128_to_utf8; -pub use group_by_expr::GroupByExpr; - -#[cfg(test)] -mod group_by_expr_test; - -mod polars_conversions; -pub use polars_conversions::LiteralConversion; - -mod polars_arithmetic; -pub use polars_arithmetic::SafeDivision; -mod to_polars_expr; -pub(crate) use to_polars_expr::ToPolarsExpr; diff --git a/crates/proof-of-sql/src/sql/transform/order_by_exprs.rs b/crates/proof-of-sql/src/sql/transform/order_by_exprs.rs deleted file mode 100644 index 8c2132ec8..000000000 --- a/crates/proof-of-sql/src/sql/transform/order_by_exprs.rs +++ /dev/null @@ -1,95 +0,0 @@ -#[allow(deprecated)] -use super::DataFrameExpr; -use super::{INT128_PRECISION, INT128_SCALE}; -use arrow::datatypes::ArrowNativeTypeOp; -use dyn_partial_eq::DynPartialEq; -use polars::prelude::{col, DataType, Expr, GetOutput, LazyFrame, NamedFrom, Series}; -use proof_of_sql_parser::intermediate_ast::{OrderBy, OrderByDirection}; -use serde::{Deserialize, Serialize}; - -/// A node representing a list of `OrderBy` expressions. -#[derive(Debug, DynPartialEq, PartialEq, Serialize, Deserialize)] -pub struct OrderByExprs { - by_exprs: Vec, -} - -impl OrderByExprs { - /// Create a new `OrderByExprs` node. - pub fn new(by_exprs: Vec) -> Self { - Self { by_exprs } - } -} - -super::impl_record_batch_expr_for_data_frame_expr!(OrderByExprs); -#[allow(deprecated)] -impl DataFrameExpr for OrderByExprs { - /// Sort the `LazyFrame` by the `OrderBy` expressions. - fn lazy_transformation(&self, lazy_frame: LazyFrame, _: usize) -> LazyFrame { - assert!(!self.by_exprs.is_empty()); - - let maintain_order = true; - let nulls_last = false; - let reverse: Vec<_> = self - .by_exprs - .iter() - .map(|v| v.direction == OrderByDirection::Desc) - .collect(); - let by_column: Vec<_> = self - .by_exprs - .iter() - .map(|v| order_by_map_to_utf8_if_decimal(col(v.expr.name()))) - .collect(); - - lazy_frame.sort_by_exprs(by_column, reverse, nulls_last, maintain_order) - } -} - -/// Converts a signed 128-bit integer to a UTF-8 string that preserves -/// the order of the original integer array when sorted. -/// -/// For any given two integers `a` and `b` we have: -/// * `a < b` if and only if `map_i128_to_utf8(a) < map_i128_to_utf8(b)`. -/// * `a == b` if and only if `map_i128_to_utf8(a) == map_i128_to_utf8(b)`. -/// * `a > b` if and only if `map_i128_to_utf8(a) > map_i128_to_utf8(b)`. -pub(crate) fn order_by_map_i128_to_utf8(v: i128) -> String { - let is_neg = v.is_negative() as u8; - v.abs() - // use big-endian order to allow skipping the leading zero bytes - .to_be_bytes() - .into_iter() - // skip the leading zero bytes to save space - .skip_while(|c| c.is_zero()) - .collect::>() - .into_iter() - // reverse back to little-endian order - .rev() - // append a byte that indicates the number of leading zero bits - // this is necessary because "12" is lexicographically smaller than "9" - // which is not the case for the original integer array as 9 < 12. - // so we append the number of leading zero bits to guarantee that "{byte}9" < "{byte}12" - .chain(std::iter::once((255 - v.abs().leading_zeros()) as u8 + 1)) - // transform the bytes of negative values so that smaller negative numbers converted - // to strings can appear before larger negative numbers converted to strings - .map(|c| (255 - c) * is_neg + c * (1 - is_neg)) - .map(char::from) - .rev() - .collect() -} - -// Polars doesn't support Decimal columns inside order by. -// So we need to remap them to the supported UTF8 type. -fn order_by_map_to_utf8_if_decimal(expr: Expr) -> Expr { - expr.map( - |series| match series.dtype().clone() { - DataType::Decimal(Some(INT128_PRECISION), Some(INT128_SCALE)) => { - let i128_data = series.decimal().unwrap().into_no_null_iter(); - // TODO: remove this mapping once Polars supports decimal columns inside order by - // Issue created to track progress: https://github.com/pola-rs/polars/issues/11079 - let utf8_data = i128_data.map(order_by_map_i128_to_utf8).collect::>(); - Ok(Some(Series::new(series.name(), utf8_data))) - } - _ => Ok(Some(series)), - }, - GetOutput::from_type(DataType::Utf8), - ) -} diff --git a/crates/proof-of-sql/src/sql/transform/order_by_exprs_test.rs b/crates/proof-of-sql/src/sql/transform/order_by_exprs_test.rs deleted file mode 100644 index 16b797e95..000000000 --- a/crates/proof-of-sql/src/sql/transform/order_by_exprs_test.rs +++ /dev/null @@ -1,231 +0,0 @@ -use super::order_by_map_i128_to_utf8; -use crate::{ - base::database::ToArrow, - record_batch, - sql::transform::test_utility::{composite_result, orders}, -}; -use proof_of_sql_parser::intermediate_ast::OrderByDirection::{Asc, Desc}; -use rand::{distributions::uniform::SampleUniform, seq::SliceRandom, Rng}; - -#[test] -fn we_can_transform_a_result_using_a_single_order_by_in_ascending_direction() { - let data = record_batch!("c" => [1_i64, -5, 2], "a" => ["a", "d", "b"]); - let result_expr = composite_result(vec![orders(&["a"], &[Asc])]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("c" => [1_i64, 2, -5], "a" => ["a", "b", "d"]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_a_result_using_a_single_order_by_in_ascending_direction_with_i128_data() { - let data = record_batch!("c" => [1_i128, -5, 2], "a" => ["a", "d", "b"]); - let result_expr = composite_result(vec![orders(&["a"], &[Asc])]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("c" => [1_i128, 2, -5], "a" => ["a", "b", "d"]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_a_result_using_a_single_order_by_in_descending_direction() { - let data = record_batch!("c" => [1_i64, -5, 2], "a" => ["a", "d", "b"]); - let result_expr = composite_result(vec![orders(&["c"], &[Desc])]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("c" => [2_i64, 1, -5], "a" => ["b", "a", "d"]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_a_result_ordering_by_the_first_column_then_the_second_column() { - let data = record_batch!( - "a" => [123_i64, 342, -234, 777, 123, 34], - "d" => ["alfa", "beta", "abc", "f", "kl", "f"] - ); - let result_expr = composite_result(vec![orders(&["a", "d"], &[Desc, Desc])]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!( - "a" => [777_i64, 342, 123, 123, 34, -234], - "d" => ["f", "beta", "kl", "alfa", "f", "abc"] - ); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_transform_a_result_ordering_by_the_second_column_then_the_first_column() { - let data = record_batch!( - "a" => [123_i64, 342, -234, 777, 123, 34], - "d" => ["alfa", "beta", "abc", "f", "kl", "f"] - ); - let result_expr = composite_result(vec![orders(&["d", "a"], &[Desc, Asc])]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!( - "a" => [123_i64, 34, 777, 342, 123, -234], - "d" => ["kl", "f", "f", "beta", "alfa", "abc", ] - ); - assert_eq!(data, expected_data); -} - -#[test] -fn order_by_preserve_order_with_equal_elements() { - let data = record_batch!("c" => [1_i64, -5, 1, 2], "a" => ["f", "d", "a", "b"]); - let result_expr = composite_result(vec![orders(&["c"], &[Desc])]); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("c" => [2_i64, 1, 1, -5], "a" => ["b", "f", "a", "d"]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_use_decimal_columns_inside_order_by_in_desc_order() { - let nines = "9".repeat(38).parse::().unwrap(); - let s = [ - -1_i128, 1, -nines, -nines, 0, -2, nines, -3, nines, -1, -3, 1, -nines, 11, -nines, - ]; - - let data = record_batch!("h" => s, "j" => s); - let result_expr = composite_result(vec![orders(&["j", "h"], &[Desc, Asc])]); - let data = result_expr.transform_results(data).unwrap(); - - let mut sorted_s = s; - sorted_s.sort_unstable(); - let reverse_sorted_s = sorted_s.into_iter().rev().collect::>(); - - let expected_data: arrow::record_batch::RecordBatch = record_batch!( - "h" => reverse_sorted_s.clone(), - "j" => reverse_sorted_s, - ); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_use_decimal_columns_inside_order_by_in_asc_order() { - let nines = "9".repeat(38).parse::().unwrap(); - let s = [ - -1_i128, 1, -nines, -nines, 0, -2, nines, -3, nines, -1, -3, 1, -nines, 11, -nines, - ]; - - let data = record_batch!("h" => s, "j" => s); - let result_expr = composite_result(vec![orders(&["j", "h"], &[Asc, Desc])]); - let data = result_expr.transform_results(data).unwrap(); - - let mut sorted_s = s; - sorted_s.sort_unstable(); - - let expected_data: arrow::record_batch::RecordBatch = record_batch!( - "h" => sorted_s.clone(), - "j" => sorted_s, - ); - assert_eq!(data, expected_data); -} - -fn validate_integer_columns_with_order_by(low: T, high: T, range: Vec) -where - T: SampleUniform + Clone + Ord, - Vec: ToArrow, -{ - let mut rng = rand::thread_rng(); - let data: Vec = range - .iter() - .map(|_| rng.gen_range(low.clone()..high.clone())) - .chain(range.clone()) - .collect(); - - let (shuffled_data, sorted_data) = { - let mut shuffled_s = data.clone(); - shuffled_s.shuffle(&mut rng); - let mut sorted_s = data.clone(); - sorted_s.sort_unstable(); - (shuffled_s, sorted_s) - }; - - let data = record_batch!("h" => shuffled_data); - let expected_data = record_batch!("h" => sorted_data); - let result_expr = composite_result(vec![orders(&["h"], &[Asc])]); - let data = result_expr.transform_results(data).unwrap(); - assert_eq!(data, expected_data); -} - -#[test] -fn order_by_with_random_i64_data() { - validate_integer_columns_with_order_by::(i64::MIN, i64::MAX, (-300000..300000).collect()); -} - -#[test] -fn order_by_with_random_i128_data() { - let nines = "9".repeat(38).parse::().unwrap(); - validate_integer_columns_with_order_by::(-nines, nines + 1, (-300000..300000).collect()); -} - -#[test] -fn map_i128_to_utf8_not_equals_is_valid() { - assert!( - order_by_map_i128_to_utf8(-99999999999999999999999999999999999999) - < order_by_map_i128_to_utf8(124) - ); - assert!(order_by_map_i128_to_utf8(-121) < order_by_map_i128_to_utf8(122)); - assert!(order_by_map_i128_to_utf8(-123) < order_by_map_i128_to_utf8(-122)); - assert!(order_by_map_i128_to_utf8(-123) < order_by_map_i128_to_utf8(124)); - assert!(order_by_map_i128_to_utf8(-123) < order_by_map_i128_to_utf8(0)); - assert!(order_by_map_i128_to_utf8(-1) < order_by_map_i128_to_utf8(0)); - assert!(order_by_map_i128_to_utf8(0) < order_by_map_i128_to_utf8(1)); - assert!(order_by_map_i128_to_utf8(0) < order_by_map_i128_to_utf8(124)); - assert!( - order_by_map_i128_to_utf8(124) - < order_by_map_i128_to_utf8(99999999999999999999999999999999999999) - ); - assert!( - order_by_map_i128_to_utf8(-99999999999999999999999999999999999999) - < order_by_map_i128_to_utf8(99999999999999999999999999999999999999) - ); -} - -fn validate_order_by_map_i128_to_utf8_with_array(s: Vec) { - let mut sorted_s: Vec<_> = s.clone(); - sorted_s.sort_unstable(); - - let mut utf8_sorted_s: Vec<_> = s.iter().map(|v| order_by_map_i128_to_utf8(*v)).collect(); - utf8_sorted_s.sort_unstable(); - - // ordering is preserved - assert_eq!( - sorted_s - .iter() - .map(|&v| order_by_map_i128_to_utf8(v)) - .collect::>(), - utf8_sorted_s - ); - - // no collision happens - assert_eq!( - sorted_s.iter().collect::>().len(), - utf8_sorted_s - .iter() - .collect::>() - .len(), - ); -} - -#[test] -fn order_by_with_consecutive_range_preserves_ordering() { - validate_order_by_map_i128_to_utf8_with_array((-300000..300000).collect()); -} - -#[test] -fn order_by_with_random_data_preserves_ordering() { - let mut rng = rand::thread_rng(); - let nines = "9".repeat(38).parse::().unwrap(); - validate_order_by_map_i128_to_utf8_with_array( - (-300000..300000) - .map(|_| rng.gen_range(-nines..nines + 1)) - .collect(), - ); -} - -#[test] -#[should_panic] -fn order_by_panics_with_min_out_of_range_value() { - order_by_map_i128_to_utf8(i128::MIN); -} - -#[test] -fn order_by_do_not_panic_with_max_out_of_range_value() { - order_by_map_i128_to_utf8(i128::MAX); -} diff --git a/crates/proof-of-sql/src/sql/transform/polars_arithmetic.rs b/crates/proof-of-sql/src/sql/transform/polars_arithmetic.rs deleted file mode 100644 index 0f0aafbe7..000000000 --- a/crates/proof-of-sql/src/sql/transform/polars_arithmetic.rs +++ /dev/null @@ -1,510 +0,0 @@ -use polars::{ - error::ErrString, - prelude::{DataType, Expr, GetOutput, PolarsError, PolarsResult, Series}, -}; - -fn series_to_i64_slice(series: &Series) -> &[i64] { - series - .i64() - .unwrap() - .cont_slice() - .expect("slice cannot contain nulls") -} - -fn series_to_i128_slice(series: &Series) -> &[i128] { - series - .decimal() - .unwrap() - .cont_slice() - .expect("slice cannot contain nulls") -} - -fn has_zero_in_series(series: &Series) -> bool { - match series.dtype().clone() { - DataType::Decimal(Some(_), Some(_)) => series_to_i128_slice(series).iter().any(|&v| v == 0), - DataType::Int64 => series_to_i64_slice(series).iter().any(|&v| v == 0), - _ => false, - } -} - -fn will_div_overflow(num: &Series, den: &Series) -> bool { - match (num.dtype(), den.dtype()) { - (DataType::Int64, DataType::Int64) => { - let num = series_to_i64_slice(num); - let den = series_to_i64_slice(den); - - num.iter() - .zip(den.iter()) - .any(|(n, d)| *n == i64::MIN && *d == -1) - } - _ => false, - } -} - -fn checked_div(series: &mut [Series]) -> PolarsResult> { - let [num, den] = [&series[0], &series[1]]; - - if has_zero_in_series(den) { - return Err(PolarsError::InvalidOperation(ErrString::from( - "division by zero is not allowed", - ))); - } - - if will_div_overflow(num, den) { - return Err(PolarsError::InvalidOperation(ErrString::from( - "attempt to divide i64 with overflow", - ))); - } - - Ok(Some(num / den)) -} - -/// Trait that provides a safe division operation for polars expressions. -pub trait SafeDivision { - /// Division operation that returns an error if the denominator is zero or if the division will overflow. - fn checked_div(self, rhs: Expr) -> Expr; -} - -impl SafeDivision for Expr { - fn checked_div(self, rhs: Expr) -> Expr { - self.map_many(checked_div, &[rhs], GetOutput::default()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - record_batch as batch, - sql::transform::{polars_conversions::LiteralConversion, test_utility::select, ResultExpr}, - }; - use polars::prelude::col; - use rand::{distributions::Uniform, Rng}; - - const MAX_I64: i128 = i64::MAX as i128; - const MIN_I64: i128 = i64::MIN as i128; - const MAX_DECIMAL: i128 = 10_i128.pow(38) - 1; - const MIN_DECIMAL: i128 = -(10_i128.pow(38) - 1); - - macro_rules! test_expr { - ($expr:expr, $expected:expr) => { - let data = batch!("" => [0_i64]); - let result = ResultExpr::new(select(&[$expr.alias("res")])).transform_results(data).unwrap(); - assert_eq!(result, $expected); - }; - ($expr:expr, $expected:expr, $data:expr) => { - assert_eq!(ResultExpr::new(select(&[$expr.alias("res")])).transform_results($data).unwrap(), $expected); - }; - } - - macro_rules! safe_arithmetic { - ($op:expr, $x:expr, $y:expr, $x_e:expr, $y_e:expr) => { - let data = batch!("x" => [$x], "y" => [$y]); - - match $op { - "add" => { - if $x.checked_add($y).is_some() && ($x + $y) <= MAX_DECIMAL && ($x + $y) >= MIN_DECIMAL { - test_expr!($x_e + $y_e, batch!("res" => [$x + $y]), data); - } - } - "sub" => { - if $x.checked_sub($y).is_some() && ($x - $y) <= MAX_DECIMAL && ($x - $y) >= MIN_DECIMAL { - test_expr!($x.to_lit() - $y.to_lit(), batch!("res" => [$x - $y]), data); - } - } - "mul" => { - if $x.checked_mul($y).is_some() && ($x * $y) <= MAX_DECIMAL && ($x * $y) >= MIN_DECIMAL { - test_expr!($x.to_lit() * $y.to_lit(), batch!("res" => [$x * $y]), data); - } - } - "div" => { - if $y != 0 { - test_expr!($x.to_lit().checked_div($y.to_lit()), batch!("res" => [$x / $y]), data); - } - } - _ => panic!("Invalid operation"), - } - }; - } - - macro_rules! batch_execute_test { - ($batch:expr) => { - for [x, y] in $batch { - for [x, y] in [[x, y], [y, x]] { - for op in ["add", "sub", "mul", "div"].into_iter() { - safe_arithmetic!(op, x, y, x.to_lit(), y.to_lit()); - safe_arithmetic!(op, x, y, x.to_lit(), col("y")); - safe_arithmetic!(op, x, y, col("x"), y.to_lit()); - safe_arithmetic!(op, x, y, col("x"), col("y")); - - /////////////////////////////////////////////////////////////////////////////// - // TODO: Address Precision Loss between decimal and i64 columns - /////////////////////////////////////////////////////////////////////////////// - // The following tests encounter issues due to the automatic - // casting of i64 to f64 in Polars, resulting in precision loss. - // A fix has been proposed in this pull request: - // https://github.com/pola-rs/polars/pull/11166. - // - // However, since the merge may take time, - // I plan to implement a workaround in a subsequent pull request. - // This workaround involves explicit casting to decimal(38, 0) - // when i64 columns are utilized. This will work. - /////////////////////////////////////////////////////////////////////////////// - // if x >= i64::MIN as i128 && x <= i64::MAX as i128 { - // safe_arithmetic!(op, x, y, col("x").cast(DataType::Int64), col("y")); - // safe_arithmetic!(op, x, y, col("x").cast(DataType::Int64), y.to_lit()); - // } - /////////////////////////////////////////////////////////////////////////////// - // if i64::try_from(x).is_ok() { - // safe_arithmetic!(op, x, y, col("x").cast(DataType::Int64), y.to_lit()); - // safe_arithmetic!(op, x, y, col("x").cast(DataType::Int64), col("y")); - // } - /////////////////////////////////////////////////////////////////////////////// - // if i64::try_from(y).is_ok() { - // safe_arithmetic!(op, x, y, col("x"), col("y").cast(DataType::Int64)); - // safe_arithmetic!(op, x, y, x.to_lit(), col("y").cast(DataType::Int64)); - // } - /////////////////////////////////////////////////////////////////////////////// - // if i64::try_from(x).is_ok() && i64::try_from(y).is_ok() { - // safe_arithmetic!(op, x, y, col("x").cast(DataType::Int64), col("y").cast(DataType::Int64)); - // safe_arithmetic!(op, x, y, x.to_lit().cast(DataType::Int64), y.to_lit().cast(DataType::Int64)); - // } - /////////////////////////////////////////////////////////////////////////////// - } - } - } - }; - } - - #[test] - #[should_panic] - fn conversion_to_literal_with_i128_min_overflows() { - test_expr! {i128::MIN.to_lit(), batch!("res" => [i128::MIN])}; - } - - #[test] - #[should_panic] - fn conversion_to_literal_with_i128_max_overflows() { - test_expr! {i128::MAX.to_lit(), batch!("res" => [i128::MAX])}; - } - - #[test] - #[should_panic] - fn conversion_to_lit_with_i128_bigger_than_max_decimal_overflows() { - test_expr! {(MAX_DECIMAL + 1).to_lit(), batch!("res" => [(MAX_DECIMAL + 1)])}; - } - - #[test] - #[should_panic] - fn conversion_to_literal_with_i128_smaller_than_min_decimal_overflows() { - test_expr! {(MIN_DECIMAL - 1).to_lit(), batch!("res" => [(MIN_DECIMAL - 1)])}; - } - - #[test] - #[should_panic] - fn conversion_to_literal_with_i128_bigger_than_max_decimal_overflows() { - test_expr! {(MAX_DECIMAL + 1).to_lit(), batch!("res" => [(MAX_DECIMAL + 1)])}; - } - - #[test] - #[should_panic] - fn add_two_i128_literals_overflowing_will_panic() { - test_expr!( - MAX_DECIMAL.to_lit() + (1_i128).to_lit(), - batch!("res" => [MAX_DECIMAL + 1]) - ); - } - - #[test] - #[should_panic] - fn add_literal_i128_and_column_overflowing_will_panic() { - test_expr!( - MAX_DECIMAL.to_lit() + col("x"), - batch!("res" => [MAX_DECIMAL + 1]), - batch!("x" => [1_i128]) - ); - } - - #[test] - #[should_panic] - fn add_two_i128_and_columns_overflowing_will_panic() { - test_expr!( - col("y") + col("x"), - batch!("res" => [MAX_DECIMAL + 1]), - batch!("x" => [1_i128], "y" => [MAX_DECIMAL]) - ); - } - - #[test] - fn sub_two_i128_literals_can_overflow_but_may_not_panic() { - test_expr!( - MIN_DECIMAL.to_lit() - (MIN_DECIMAL / 10).to_lit(), - batch!("res" => [MIN_DECIMAL - (MIN_DECIMAL/10)]) - ); - } - - #[test] - #[should_panic] - fn mul_two_i128_literals_overflows() { - test_expr!( - 10_i128.to_lit() * (10_i128.pow(37)).to_lit(), - batch!("res" => [MAX_DECIMAL + 1]) - ); - } - - #[test] - #[should_panic] - fn mul_i128_column_and_literal_overflows() { - test_expr!( - col("x") * 10_i128.to_lit(), - batch!("res" => [MAX_DECIMAL + 1]), - batch!("x" => [10_i128.pow(37)]) - ); - } - - #[test] - #[should_panic] - fn mul_i128_literal_and_column_overflows() { - test_expr!( - 10_i128.to_lit() * col("x"), - batch!("res" => [MAX_DECIMAL + 1]), - batch!("x" => [10_i128.pow(37)]) - ); - } - - #[test] - #[should_panic] - fn mul_two_i128_columns_overflows() { - test_expr!( - col("x") * col("y"), - batch!("res" => [MAX_DECIMAL + 1]), - batch!("x" => [10_i128.pow(37)], "y" => [10_i128]) - ); - } - - #[test] - fn we_can_execute_multiple_arithmetic_operations_between_expressions() { - batch_execute_test!([ - [0, -10], - [MAX_DECIMAL, -1], - [MIN_DECIMAL, 1], - [MAX_DECIMAL, MIN_DECIMAL], - [i64::MAX as i128, i64::MAX as i128], - [i64::MIN as i128, i64::MIN as i128], - [i64::MIN as i128, i64::MAX as i128], - [-4654825170126467706_i128, 4654825170126467706_i128], - ]); - } - - #[test] - fn we_can_execute_multiple_random_arithmetic_operations_between_expressions() { - const NUM_RANDOM_VALUES: usize = 1000; - let mut rng = rand::thread_rng(); - - let rand_samples: Vec<_> = (0..NUM_RANDOM_VALUES) - .flat_map(|_| { - let lit1d = rng.sample(Uniform::new(MIN_DECIMAL, MAX_DECIMAL + 1)); - let lit2d = rng.sample(Uniform::new(MIN_DECIMAL, MAX_DECIMAL + 1)); - - let lit1i = rng.sample(Uniform::new(MIN_I64, MAX_I64 + 1)); - let lit2i = rng.sample(Uniform::new(MIN_I64, MAX_I64 + 1)); - - [[lit1i, lit2i], [lit1d, lit2d]] - }) - .collect(); - - batch_execute_test!(rand_samples); - } - - #[test] - #[should_panic] - fn valid_i128_with_i64_sub_will_incorrectly_overflow() { - let v = -4654825170126467706_i64; - test_expr!( - col("y") - col("x").cast(DataType::Int64), - batch!("res" => [0_i128]), - batch!("y" => [v as i128], "x" => [v as i128]) - ); - } - - #[test] - #[should_panic] - fn division_with_zero_i64_numerator_zero_i64_denominator_will_error() { - test_expr!( - col("i1").checked_div(col("i")), - batch!("res" => [0_i64]), - batch!("i1" => [0_i64], "i" => [0_i64]) - ); - } - - #[test] - #[should_panic] - fn division_with_non_zero_i64_numerator_zero_i64_denominator_will_error() { - test_expr!( - col("i1").checked_div(col("i")), - batch!("res" => [0_i64]), - batch!("i1" => [1_i64], "i" => [0_i64]) - ); - } - - #[test] - #[should_panic] - fn division_with_non_zero_i128_numerator_zero_i128_denominator_will_error() { - test_expr!( - col("d1").checked_div(col("d")), - batch!("res" => [0_i128]), - batch!("d1" => [1_i128], "d" => [0_i128]) - ); - } - - #[test] - #[should_panic] - fn division_with_zero_i128_numerator_zero_i128_denominator_will_error() { - test_expr!( - col("d1").checked_div(col("d")), - batch!("res" => [0_i128]), - batch!("d1" => [0_i128], "d" => [0_i128]) - ); - } - - #[test] - #[should_panic] - fn division_with_non_zero_i64_numerator_zero_i128_denominator_will_error() { - test_expr!( - col("i").checked_div(col("d")), - batch!("res" => [0_i128]), - batch!("i" => [1_i64], "d" => [0_i128]) - ); - } - - #[test] - #[should_panic] - fn division_with_zero_i64_numerator_zero_i128_denominator_will_error() { - test_expr!( - col("i").checked_div(col("d")), - batch!("res" => [0_i128]), - batch!("i" => [0_i64], "d" => [0_i128]) - ); - } - - #[test] - #[should_panic] - fn division_with_non_zero_i128_numerator_zero_i64_denominator_will_error() { - test_expr!( - col("d").checked_div(col("i")), - batch!("res" => [0_i128]), - batch!("i" => [0_i64], "d" => [1_i128]) - ); - } - - #[test] - #[should_panic] - fn polars_will_panic_with_i64_numerator_and_denominator_and_division_overflowing_even_in_release_mode( - ) { - test_expr!( - col("i1").checked_div(col("i2")), - batch!("res" => [MIN_I64 as i64]), - batch!("i1" => [MIN_I64 as i64], - "i2" => [-1_i64]) - ); - } - - #[test] - fn division_with_different_values_of_numerator_and_denominator_is_valid() { - let range = (-31..31).chain([ - MAX_I64, - MAX_I64, - MAX_DECIMAL, - MIN_DECIMAL, - MAX_I64 - 1, - MIN_I64 + 1, - MAX_DECIMAL - 1, - MIN_DECIMAL + 1, - MAX_I64 / 10, - MIN_I64 / 10, - MAX_DECIMAL / 10, - MIN_DECIMAL / 10, - ]); - - for num in range.clone() { - for den in range.clone() { - if den != 0 { - if (MIN_I64..=MAX_I64).contains(&num) && (MIN_I64..=MAX_I64).contains(&den) { - let (div_res, will_overflow) = (num as i64).overflowing_div(den as i64); - - if !will_overflow { - test_expr!( - col("num").checked_div(col("den")), - batch!("res" => [div_res]), - batch!("num" => [num as i64], - "den" => [den as i64]) - ); - } - } - - if (MIN_I64..=MAX_I64).contains(&num) { - test_expr!( - col("num") - .cast(DataType::Decimal(Some(38), Some(0))) - .checked_div(col("den")), - batch!("res" => [num / den]), - batch!("num" => [num as i64], - "den" => [den]) - ); - } - - if (MIN_I64..=MAX_I64).contains(&den) { - test_expr!( - col("num") - .checked_div(col("den").cast(DataType::Decimal(Some(38), Some(0)))), - batch!("res" => [num / den]), - batch!("num" => [num], - "den" => [den as i64]) - ); - } - - test_expr!( - col("num").checked_div(col("den")), - batch!("res" => [num / den]), - batch!("num" => [num], - "den" => [den]) - ); - } - } - } - } - - #[test] - fn we_can_use_compound_arithmetic_expressions() { - let range = (-31..31).chain([ - MIN_I64, - MAX_I64, - MAX_I64 - 1, - MIN_I64 + 1, - MAX_I64, - MIN_I64, - MAX_DECIMAL / 1000, - MIN_DECIMAL / 1000, - ]); - - for v1 in range.clone() { - for v2 in range.clone() { - let expr = 5_i64.to_lit() - * ((2_i64.to_lit() + col("v1") * 3_i64.to_lit() - col("v1")) - .checked_div(col("v2") + (-2_i64).to_lit() * col("v2"))) - + 77_i64.to_lit(); - - let num = 2_i128 + v1 * 3 - v1; - let den = v2 - 2 * v2; - - if den != 0 { - test_expr!( - expr, - batch!("res" => [5 * (num / den) + 77]), - batch!("v1" => [v1], "v2" => [v2]) - ); - } - } - } - } -} diff --git a/crates/proof-of-sql/src/sql/transform/polars_conversions.rs b/crates/proof-of-sql/src/sql/transform/polars_conversions.rs deleted file mode 100644 index 5c7db486c..000000000 --- a/crates/proof-of-sql/src/sql/transform/polars_conversions.rs +++ /dev/null @@ -1,107 +0,0 @@ -use super::{INT128_PRECISION, INT128_SCALE}; -use polars::prelude::{DataType, Expr, Literal, LiteralValue, Series}; - -/// Convert a Rust type to a Polars `Expr` type. -pub trait LiteralConversion { - /// Convert the Rust type to a Polars `Expr` type. - fn to_lit(&self) -> Expr; -} - -impl LiteralConversion for bool { - fn to_lit(&self) -> Expr { - Expr::Literal(LiteralValue::Boolean(*self)) - } -} - -impl LiteralConversion for i16 { - fn to_lit(&self) -> Expr { - Expr::Literal(LiteralValue::Int16(*self)) - } -} - -impl LiteralConversion for i32 { - fn to_lit(&self) -> Expr { - Expr::Literal(LiteralValue::Int32(*self)) - } -} - -impl LiteralConversion for i64 { - fn to_lit(&self) -> Expr { - Expr::Literal(LiteralValue::Int64(*self)) - } -} - -impl LiteralConversion for i128 { - fn to_lit(&self) -> Expr { - let s = [self.abs().to_string()].into_iter().collect::(); - let l = s.lit().cast(DataType::Decimal( - Some(INT128_PRECISION), - Some(INT128_SCALE), - )); - - if self.is_negative() { - [-1].into_iter().collect::().lit() * l - } else { - l - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - record_batch as batch, - sql::transform::{test_utility::select, ResultExpr}, - }; - - const MAX_I64: i128 = i64::MAX as i128; - const MIN_I64: i128 = i64::MIN as i128; - const MAX_DECIMAL: i128 = 10_i128.pow(38) - 1; - const MIN_DECIMAL: i128 = -(10_i128.pow(38) - 1); - - macro_rules! test_expr { - ($expr:expr, $expected:expr) => { - let data = batch!("" => [0_i64]); - let result = ResultExpr::new(select(&[$expr])).transform_results(data).unwrap(); - assert_eq!(result, $expected); - }; - ($expr:expr, $expected:expr, $data:expr) => { - assert_eq!(ResultExpr::new(select(&[$expr.alias("")])).transform_results($data).unwrap(), $expected); - }; - } - - #[test] - fn boolean_can_be_properly_converted_to_lit() { - test_expr! {true.to_lit(), batch!("literal" => [true])}; - test_expr! {false.to_lit(), batch!("literal" => [false])}; - } - - #[test] - fn i64_can_be_properly_converted_to_lit() { - test_expr! {1_i64.to_lit(), batch!("literal" => [1_i64])}; - test_expr! {0_i64.to_lit(), batch!("literal" => [0_i64])}; - test_expr! {(-1_i64).to_lit(), batch!("literal" => [-1_i64])}; - test_expr!(i64::MAX.to_lit(), batch!("literal" => [i64::MAX])); - test_expr!(i64::MIN.to_lit(), batch!("literal" => [i64::MIN])); - (-3000_i64..3000_i64).for_each(|i| { - test_expr! {i.to_lit(), batch!("literal" => [i])}; - }); - } - - #[test] - fn i128_can_be_properly_converted_to_lit() { - test_expr! {1_i128.to_lit(), batch!("" => [1_i128])}; - test_expr! {0_i128.to_lit(), batch!("" => [0_i128])}; - test_expr! {(-1_i128).to_lit(), batch!("" => [-1_i128])}; - test_expr! {MAX_DECIMAL.to_lit(), batch!("" => [MAX_DECIMAL])}; - test_expr! {(MIN_DECIMAL).to_lit(), batch!("" => [MIN_DECIMAL])}; - test_expr! {(MIN_DECIMAL + 1).to_lit(), batch!("" => [MIN_DECIMAL + 1])}; - test_expr! {(MAX_DECIMAL - 1).to_lit(), batch!("" => [MAX_DECIMAL - 1])}; - test_expr!(MAX_I64.to_lit(), batch!("" => [i64::MAX as i128])); - test_expr!(MIN_I64.to_lit(), batch!("" => [i64::MIN as i128])); - (-3000_i128..3000_i128).for_each(|i| { - test_expr! {i.to_lit(), batch!("" => [i])}; - }); - } -} diff --git a/crates/proof-of-sql/src/sql/transform/record_batch_expr.rs b/crates/proof-of-sql/src/sql/transform/record_batch_expr.rs deleted file mode 100644 index 0e59e2aea..000000000 --- a/crates/proof-of-sql/src/sql/transform/record_batch_expr.rs +++ /dev/null @@ -1,32 +0,0 @@ -use arrow::record_batch::RecordBatch; -use dyn_partial_eq::dyn_partial_eq; -use std::fmt::Debug; - -/// A trait for nodes that can apply transformations to a `RecordBatch`. -#[typetag::serde(tag = "type")] -#[dyn_partial_eq] -pub trait RecordBatchExpr: Debug + Send + Sync { - /// Apply the transformation to the `RecordBatch` and return the result. - fn apply_transformation(&self, record_batch: RecordBatch) -> Option; -} - -macro_rules! impl_record_batch_expr_for_data_frame_expr { - ($t:ty) => { - #[typetag::serde] - impl crate::sql::transform::record_batch_expr::RecordBatchExpr for $t { - fn apply_transformation( - &self, - record_batch: arrow::record_batch::RecordBatch, - ) -> Option { - let (lazy_frame, num_input_rows) = - crate::sql::transform::result_expr::record_batch_to_lazy_frame(record_batch)?; - #[allow(deprecated)] - crate::sql::transform::result_expr::lazy_frame_to_record_batch( - self.lazy_transformation(lazy_frame, num_input_rows), - ) - } - } - }; -} - -pub(crate) use impl_record_batch_expr_for_data_frame_expr; diff --git a/crates/proof-of-sql/src/sql/transform/result_expr.rs b/crates/proof-of-sql/src/sql/transform/result_expr.rs deleted file mode 100644 index 244a0afec..000000000 --- a/crates/proof-of-sql/src/sql/transform/result_expr.rs +++ /dev/null @@ -1,51 +0,0 @@ -use crate::{ - base::database::{dataframe_to_record_batch, record_batch_to_dataframe}, - sql::transform::RecordBatchExpr, -}; -use arrow::record_batch::RecordBatch; -use dyn_partial_eq::DynPartialEq; -use polars::prelude::{IntoLazy, LazyFrame}; -use serde::{Deserialize, Serialize}; - -/// The result expression is used to transform the results of a query -/// -/// Note: both the `transformation` and `result_schema` are -/// mutually exclusive operations. So they must not be set at the same time. -#[derive(Debug, DynPartialEq, PartialEq, Serialize, Deserialize)] -pub struct ResultExpr { - transformation: Box, -} - -impl ResultExpr { - /// Create a new `ResultExpr` node with the provided transformation to be applied to the input record batch. - pub fn new(transformation: Box) -> Self { - Self { transformation } - } -} - -pub(super) fn record_batch_to_lazy_frame(result_batch: RecordBatch) -> Option<(LazyFrame, usize)> { - let num_input_rows = result_batch.num_rows(); - let df = record_batch_to_dataframe(result_batch)?; - Some((df.lazy(), num_input_rows)) -} -pub(super) fn lazy_frame_to_record_batch(lazy_frame: LazyFrame) -> Option { - // We're currently excluding NULLs in post-processing due to a lack of - // prover support, aiming to avoid future complexities. - // The drawback is that users won't get NULL results in aggregations on empty data. - // For example, the query `SELECT MAX(i), COUNT(i), SUM(i), MIN(i) FROM table WHERE s = 'nonexist'` - // will now omit the entire row (before, it would return `null, 0, 0, null`). - // This choice is acceptable, as `SELECT MAX(i), COUNT(i), SUM(i) FROM table WHERE s = 'nonexist' GROUP BY f` - // has the same outcome. - // - // TODO: Revisit if we add NULL support to the prover. - let lazy_frame = lazy_frame.drop_nulls(None); - - dataframe_to_record_batch(lazy_frame.collect().ok()?) -} - -impl ResultExpr { - /// Transform the `RecordBatch` result of a query using the `transformation` expression - pub fn transform_results(&self, result_batch: RecordBatch) -> Option { - self.transformation.apply_transformation(result_batch) - } -} diff --git a/crates/proof-of-sql/src/sql/transform/select_expr.rs b/crates/proof-of-sql/src/sql/transform/select_expr.rs deleted file mode 100644 index 3c407fd40..000000000 --- a/crates/proof-of-sql/src/sql/transform/select_expr.rs +++ /dev/null @@ -1,74 +0,0 @@ -#[allow(deprecated)] -use super::DataFrameExpr; -use super::{ - record_batch_expr::RecordBatchExpr, - result_expr::{lazy_frame_to_record_batch, record_batch_to_lazy_frame}, - ToPolarsExpr, -}; -use arrow::record_batch::RecordBatch; -use dyn_partial_eq::DynPartialEq; -use polars::prelude::{Expr, LazyFrame}; -use proof_of_sql_parser::intermediate_ast::{AliasedResultExpr, Expression}; -use serde::{Deserialize, Serialize}; - -/// The select expression used to select, reorder, and apply alias transformations -#[derive(Debug, DynPartialEq, PartialEq, Serialize, Deserialize)] -pub struct SelectExpr { - /// The schema of the resulting lazy frame - result_schema: Vec, -} - -impl SelectExpr { - #[cfg(test)] - pub(crate) fn new(exprs: &[impl ToPolarsExpr]) -> Self { - Self::new_from_to_polars(exprs) - } - fn new_from_to_polars(exprs: &[impl ToPolarsExpr]) -> Self { - let result_schema = Vec::from_iter(exprs.iter().map(ToPolarsExpr::to_polars_expr)); - assert!(!result_schema.is_empty()); - Self { result_schema } - } - /// Create a new select expression from a slice of AliasedResultExpr - pub fn new_from_aliased_result_exprs(aliased_exprs: &[AliasedResultExpr]) -> Self { - Self::new_from_to_polars(aliased_exprs) - } - /// Create a new select expression from a slice of Expressions - pub fn new_from_expressions(exprs: &[Expression]) -> Self { - Self::new_from_to_polars(exprs) - } -} - -#[allow(deprecated)] -impl DataFrameExpr for SelectExpr { - /// Apply the select transformation to the lazy frame - fn lazy_transformation(&self, lazy_frame: LazyFrame, _: usize) -> LazyFrame { - lazy_frame.select(&self.result_schema) - } -} - -#[typetag::serde] -impl RecordBatchExpr for SelectExpr { - fn apply_transformation(&self, record_batch: RecordBatch) -> Option { - let easy_result: Option> = self - .result_schema - .iter() - .cloned() - .map(|expr| match expr { - Expr::Alias(a, b) => match *a { - Expr::Column(c) if c == b => { - Some((b.to_owned(), record_batch.column_by_name(&b)?.to_owned())) - } - _ => None, - }, - _ => None, - }) - .collect(); - - if let Some(Ok(result)) = easy_result.map(RecordBatch::try_from_iter) { - return Some(result); - } - let (lazy_frame, num_input_rows) = record_batch_to_lazy_frame(record_batch)?; - #[allow(deprecated)] - lazy_frame_to_record_batch(self.lazy_transformation(lazy_frame, num_input_rows)) - } -} diff --git a/crates/proof-of-sql/src/sql/transform/select_expr_test.rs b/crates/proof-of-sql/src/sql/transform/select_expr_test.rs deleted file mode 100644 index fd2b3f3aa..000000000 --- a/crates/proof-of-sql/src/sql/transform/select_expr_test.rs +++ /dev/null @@ -1,139 +0,0 @@ -use crate::{ - record_batch, - sql::transform::{test_utility::*, ResultExpr}, -}; -use arrow::record_batch::RecordBatch; - -#[test] -fn we_can_filter_out_record_batch_columns() { - let data = record_batch!("c" => [-5_i64, 1, -56, 2], "a" => ["d", "a", "f", "b"]); - let result_expr = ResultExpr::new(select(&[col("a").alias("a2")])); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("a2" => ["d", "a", "f", "b"]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_filter_out_record_batch_columns_with_i128_data() { - let data = record_batch!("c" => [-5_i128, 1, -56, 2], "a" => ["d", "a", "f", "b"]); - let result_expr = ResultExpr::new(select(&[col("a").alias("a2")])); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("a2" => ["d", "a", "f", "b"]); - assert_eq!(data, expected_data); -} - -#[test] -#[should_panic] -fn result_expr_panics_with_batches_containing_duplicate_columns() { - let data = record_batch!("a" => [-5_i64, 1, -56, 2], "a" => [-5_i64, 1, -56, 2]); - let result_expr = ResultExpr::new(select(&[col("a").alias("a2"), col("a").alias("a3")])); - result_expr.transform_results(data).unwrap(); -} - -#[test] -fn we_can_reorder_the_record_batch_columns_without_changing_their_names() { - let data = record_batch!("c" => [-5_i64, 1, -56, 2], "a" => ["d", "a", "f", "b"]); - let result_expr = ResultExpr::new(select(&[col("a").alias("a"), col("c").alias("c")])); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("a" => ["d", "a", "f", "b"], "c" => [-5_i64, 1, -56, 2]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_remap_the_record_batch_columns_to_different_names() { - let data = record_batch!("c" => [-5_i64, 1, -56, 2], "a" => ["d", "a", "f", "b"]); - let result_expr = ResultExpr::new(select(&[ - col("a").alias("b_test"), - col("c").alias("col_c_test"), - ])); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = - record_batch!("b_test" => ["d", "a", "f", "b"], "col_c_test" => [-5_i64, 1, -56, 2]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_remap_the_record_batch_columns_to_new_columns() { - let data = record_batch!("c" => [-5_i64, 1, -56, 2], "a" => ["d", "a", "f", "b"]); - let result_expr = ResultExpr::new(select(&[ - col("c").alias("abc"), - col("a").alias("b_test"), - col("a").alias("d2"), - col("c").alias("c"), - ])); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!("abc" => [-5_i64, 1, -56, 2], "b_test" => ["d", "a", "f", "b"], "d2" => ["d", "a", "f", "b"], "c" => [-5_i64, 1, -56, 2]); - assert_eq!(data, expected_data); -} - -#[test] -fn we_can_use_agg_with_select_expression() { - let data = record_batch!( - "c" => [1_i64, -5, -3, 7, -3], - "a" => [1_i64, 2, 3, 1, 3], - "d" => [523_i128, -25, 343, -7, 435], - "h" => [-1_i128, -2, -3, -1, -3], - "y" => ["a", "b", "c", "d", "e"] - ); - let result_expr = ResultExpr::new(select(&[ - col("c").sum().alias("c_sum"), - col("a").max().alias("a_max"), - col("d").min().alias("d_min"), - col("h").count().alias("h_count"), - (col("c").sum() * col("a").max() - col("d").min() + col("h").count() * lit(2) - lit(733)) - .alias("expr"), - ])); - let data = result_expr.transform_results(data).unwrap(); - let expected_data = record_batch!( - "c_sum" => [-3_i64], - "a_max" => [3_i64], - "d_min" => [-25_i128], - "h_count" => [5_i64], - "expr" => [-707_i128], - ); - assert_eq!(data, expected_data); -} - -#[test] -fn using_count_with_an_empty_batch_will_return_zero() { - let data = record_batch!("i" => [-5_i64], "d" => [3_i128], "s" => ["a"]); - let empty_data = RecordBatch::new_empty(data.schema()); - let result_expr = ResultExpr::new(select(&[ - col("i").count(), - col("d").count(), - col("s").count(), - ])); - let data = result_expr.transform_results(empty_data).unwrap(); - let expected_data = record_batch!("i" => [0_i64], "d" => [0_i64], "s" => [0_i64]); - assert_eq!(data, expected_data); -} - -#[test] -fn using_sum_with_an_empty_batch_will_return_zero() { - let data = record_batch!("i" => [-5_i64], "d" => [3_i128]); - let empty_data = RecordBatch::new_empty(data.schema()); - let result_expr = ResultExpr::new(select(&[col("i").sum(), col("d").sum()])); - let data = result_expr.transform_results(empty_data).unwrap(); - let expected_data = record_batch!("i" => [0_i64], "d" => [0_i128]); - assert_eq!(data, expected_data); -} - -#[test] -fn using_min_with_an_empty_batch_will_return_empty_even_with_count_or_sum_in_the_result() { - let data = record_batch!("i" => [-5_i64], "d" => [3_i128], "i1" => [3_i64]); - let empty_data = RecordBatch::new_empty(data.schema()); - let result_expr = ResultExpr::new(select(&[col("i").count(), col("d").sum(), col("i1").min()])); - let data = result_expr.transform_results(empty_data.clone()).unwrap(); - let expected_data = empty_data; - assert_eq!(data, expected_data); -} - -#[test] -fn using_max_with_an_empty_batch_will_return_empty_even_with_count_or_sum_in_the_result() { - let data = record_batch!("i" => [-5_i64], "d" => [3_i128], "i1" => [3_i64]); - let empty_data = RecordBatch::new_empty(data.schema()); - let result_expr = ResultExpr::new(select(&[col("i").count(), col("d").sum(), col("i1").max()])); - let data = result_expr.transform_results(empty_data.clone()).unwrap(); - let expected_data = empty_data; - assert_eq!(data, expected_data); -} diff --git a/crates/proof-of-sql/src/sql/transform/slice_expr.rs b/crates/proof-of-sql/src/sql/transform/slice_expr.rs deleted file mode 100644 index 22e88af66..000000000 --- a/crates/proof-of-sql/src/sql/transform/slice_expr.rs +++ /dev/null @@ -1,40 +0,0 @@ -#[allow(deprecated)] -use super::DataFrameExpr; -use dyn_partial_eq::DynPartialEq; -use polars::prelude::LazyFrame; -use serde::{Deserialize, Serialize}; - -/// A `SliceExpr` represents a slice of a `LazyFrame`. -#[derive(Debug, DynPartialEq, PartialEq, Serialize, Deserialize)] -pub struct SliceExpr { - /// number of rows to return - /// - /// - if u64::MAX, specify all rows - number_rows: u64, - - /// number of rows to skip - /// - /// - if 0, specify the first row as starting point - /// - if negative, specify the offset from the end - /// (e.g. -1 is the last row, -2 is the second to last row, etc.) - offset_value: i64, -} - -impl SliceExpr { - /// Create a new `SliceExpr` with the given `number_rows` and `offset`. - pub fn new(number_rows: u64, offset_value: i64) -> Self { - Self { - number_rows, - offset_value, - } - } -} - -super::record_batch_expr::impl_record_batch_expr_for_data_frame_expr!(SliceExpr); -#[allow(deprecated)] -impl DataFrameExpr for SliceExpr { - /// Apply the slice transformation to the given `LazyFrame`. - fn lazy_transformation(&self, lazy_frame: LazyFrame, _: usize) -> LazyFrame { - lazy_frame.slice(self.offset_value, self.number_rows) - } -} diff --git a/crates/proof-of-sql/src/sql/transform/slice_expr_test.rs b/crates/proof-of-sql/src/sql/transform/slice_expr_test.rs deleted file mode 100644 index 4ab03811c..000000000 --- a/crates/proof-of-sql/src/sql/transform/slice_expr_test.rs +++ /dev/null @@ -1,121 +0,0 @@ -use crate::{ - record_batch, - sql::transform::test_utility::{composite_result, slice}, -}; - -#[test] -fn we_can_slice_a_lazy_frame_using_only_a_positive_limit_value() { - let limit = 3_usize; - - let data_a = [123_i64, 342, -234, 777, 123, 34]; - let data_d = ["alfa", "beta", "abc", "f", "kl", "f"]; - let data_frame = record_batch!( - "a" => data_a.to_vec(), - "d" => data_d.to_vec() - ); - - let result_expr = composite_result(vec![slice(limit as u64, 0)]); - let data_frame = result_expr.transform_results(data_frame).unwrap(); - - assert_eq!( - data_frame, - record_batch!( - "a" => data_a[0..limit].to_vec(), - "d" => data_d[0..limit].to_vec() - ) - ); -} - -#[test] -fn we_can_slice_a_lazy_frame_using_only_a_zero_limit_value() { - let limit = 0; - - let data_a = [123_i64, 342, -234, 777, 123, 34]; - let data_d = ["alfa", "beta", "abc", "f", "kl", "f"]; - let data_frame = record_batch!( - "a" => data_a.to_vec(), - "d" => data_d.to_vec() - ); - - let result_expr = composite_result(vec![slice(limit as u64, 0)]); - let data_frame = result_expr.transform_results(data_frame).unwrap(); - - assert_eq!( - data_frame, - record_batch!( - "a" => Vec::::new(), - "d" => Vec::::new() - ) - ); -} - -#[test] -fn we_can_slice_a_lazy_frame_using_only_a_positive_offset_value() { - let offset = 3; - - let data_a = [123_i64, 342, -234, 777, 123, 34]; - let data_d = ["alfa", "beta", "abc", "f", "kl", "f"]; - let data_frame = record_batch!( - "a" => data_a.to_vec(), - "d" => data_d.to_vec() - ); - - let result_expr = composite_result(vec![slice(u64::MAX, offset)]); - let data_frame = result_expr.transform_results(data_frame).unwrap(); - - assert_eq!( - data_frame, - record_batch!( - "a" => data_a[(offset as usize)..].to_vec(), - "d" => data_d[(offset as usize)..].to_vec() - ) - ); -} - -#[test] -fn we_can_slice_a_lazy_frame_using_only_a_negative_offset_value() { - let offset = -2; - - let data_a = [123_i64, 342, -234, 777, 123, 34]; - let data_d = ["alfa", "beta", "abc", "f", "kl", "f"]; - let data_frame = record_batch!( - "a" => data_a.to_vec(), - "d" => data_d.to_vec() - ); - - let result_expr = composite_result(vec![slice(u64::MAX, offset)]); - let data_frame = result_expr.transform_results(data_frame).unwrap(); - - assert_eq!( - data_frame, - record_batch!( - "a" => data_a[(data_a.len() as i64 + offset) as usize..].to_vec(), - "d" => data_d[(data_a.len() as i64 + offset) as usize..].to_vec() - ) - ); -} - -#[test] -fn we_can_slice_a_lazy_frame_using_both_limit_and_offset_values() { - let offset = -2; - let limit = 1_usize; - - let data_a = [123_i64, 342, -234, 777, 123, 34]; - let data_d = ["alfa", "beta", "abc", "f", "kl", "f"]; - let data_frame = record_batch!( - "a" => data_a.to_vec(), - "d" => data_d.to_vec() - ); - - let result_expr = composite_result(vec![slice(limit as u64, offset)]); - let data_frame = result_expr.transform_results(data_frame).unwrap(); - let beg_expected_index = (data_a.len() as i64 + offset) as usize; - - assert_eq!( - data_frame, - record_batch!( - "a" => data_a[beg_expected_index..(beg_expected_index + limit)].to_vec(), - "d" => data_d[beg_expected_index..(beg_expected_index + limit)].to_vec() - ) - ); -} diff --git a/crates/proof-of-sql/src/sql/transform/test_utility.rs b/crates/proof-of-sql/src/sql/transform/test_utility.rs deleted file mode 100644 index d5f1c5b30..000000000 --- a/crates/proof-of-sql/src/sql/transform/test_utility.rs +++ /dev/null @@ -1,84 +0,0 @@ -use super::*; -use proof_of_sql_parser::intermediate_ast::*; - -pub fn lit_i64(literal: i64) -> Box { - Box::new(Expression::Literal(Literal::BigInt(literal))) -} - -pub fn lit>(literal: L) -> Box { - Box::new(Expression::Literal(literal.into())) -} -pub trait ToLit { - fn to_lit(self) -> Box; -} -impl ToLit for i64 { - fn to_lit(self) -> Box { - lit_i64(self) - } -} -pub fn col(name: &str) -> Box { - Box::new(Expression::Column(name.parse().unwrap())) -} - -pub(crate) fn select(result_schema: &[impl ToPolarsExpr]) -> Box { - #[allow(deprecated)] - Box::new(SelectExpr::new(result_schema)) -} - -pub fn schema(columns: &[(&str, &str)]) -> Vec { - columns - .iter() - .map(|(name, alias)| col(name).alias(alias)) - .collect() -} - -pub fn result(columns: &[(&str, &str)]) -> ResultExpr { - let mut composition = CompositionExpr::default(); - composition.add(Box::new(SelectExpr::new_from_aliased_result_exprs( - &schema(columns), - ))); - ResultExpr::new(Box::new(composition)) -} - -pub fn slice(limit: u64, offset: i64) -> Box { - Box::new(SliceExpr::new(limit, offset)) -} - -pub fn composite_result(transformations: Vec>) -> ResultExpr { - let mut composition = CompositionExpr::default(); - - for transformation in transformations { - composition.add(transformation); - } - - ResultExpr::new(Box::new(composition)) -} - -pub fn orders(cols: &[&str], directions: &[OrderByDirection]) -> Box { - let by_exprs = cols - .iter() - .zip(directions.iter()) - .map(|(col, direction)| OrderBy { - expr: col.parse().unwrap(), - direction: *direction, - }) - .collect(); - - Box::new(OrderByExprs::new(by_exprs)) -} - -pub fn groupby< - T: IntoIterator>, - A: IntoIterator, ->( - by_exprs: T, - agg_exprs: A, -) -> Box { - Box::new(GroupByExpr::new( - &Vec::from_iter(by_exprs.into_iter().map(|expr| match *expr { - Expression::Column(c) => c, - _ => panic!("Expected column expression"), - })), - &Vec::from_iter(agg_exprs), - )) -} diff --git a/crates/proof-of-sql/src/sql/transform/to_polars_expr.rs b/crates/proof-of-sql/src/sql/transform/to_polars_expr.rs deleted file mode 100644 index c18e4d837..000000000 --- a/crates/proof-of-sql/src/sql/transform/to_polars_expr.rs +++ /dev/null @@ -1,60 +0,0 @@ -use super::{polars_arithmetic::SafeDivision, polars_conversions::LiteralConversion}; -use polars::prelude::{col, Expr}; -use proof_of_sql_parser::intermediate_ast::*; -pub(crate) trait ToPolarsExpr { - fn to_polars_expr(&self) -> Expr; -} -#[cfg(test)] -impl ToPolarsExpr for Expr { - fn to_polars_expr(&self) -> Expr { - self.clone() - } -} -#[cfg(test)] -impl ToPolarsExpr for Box { - fn to_polars_expr(&self) -> Expr { - self.as_ref().to_polars_expr() - } -} -impl ToPolarsExpr for AliasedResultExpr { - fn to_polars_expr(&self) -> Expr { - self.expr.to_polars_expr().alias(self.alias.as_str()) - } -} -impl ToPolarsExpr for Expression { - fn to_polars_expr(&self) -> Expr { - match self { - Expression::Literal(literal) => match literal { - Literal::Boolean(value) => value.to_lit(), - Literal::BigInt(value) => value.to_lit(), - Literal::Int128(value) => value.to_lit(), - Literal::VarChar(_) => panic!("Expression not supported"), - Literal::Decimal(_) => todo!(), - Literal::Timestamp(_) => panic!("Expression not supported"), - }, - Expression::Column(identifier) => col(identifier.as_str()), - Expression::Binary { op, left, right } => { - let left = left.to_polars_expr(); - let right = right.to_polars_expr(); - match op { - BinaryOperator::Add => left + right, - BinaryOperator::Subtract => left - right, - BinaryOperator::Multiply => left * right, - BinaryOperator::Division => left.checked_div(right), - _ => panic!("Operation not supported yet"), - } - } - Expression::Aggregation { op, expr } => { - let expr = expr.to_polars_expr(); - match op { - AggregationOperator::Count => expr.count(), - AggregationOperator::Sum => expr.sum(), - AggregationOperator::Min => expr.min(), - AggregationOperator::Max => expr.max(), - AggregationOperator::First => expr.first(), - } - } - _ => panic!("Operation not supported"), - } - } -}