From 981cf7350435f2d21f30125c89a2388d636cd200 Mon Sep 17 00:00:00 2001 From: Vamshi Maskuri <117595548+varshith257@users.noreply.github.com> Date: Fri, 3 Jan 2025 17:02:00 +0530 Subject: [PATCH] refactor!: `proof_of_sql_parser::intermediate_ast::OrderBy` with `sqlparser::ast::OrderByExpr` in the proof-of-sql crate --- .../src/intermediate_ast.rs | 15 ++++++ .../src/base/database/order_by_util.rs | 39 +++++++++++--- .../src/base/database/order_by_util_test.rs | 29 +++++++++-- .../src/sql/parse/query_context.rs | 17 ++++--- .../src/sql/parse/query_context_builder.rs | 8 +-- .../proof-of-sql/src/sql/parse/query_expr.rs | 2 +- .../src/sql/postprocessing/error.rs | 6 +++ .../postprocessing/order_by_postprocessing.rs | 51 ++++++++++++------- .../src/sql/postprocessing/test_utility.rs | 24 ++++++--- 9 files changed, 141 insertions(+), 50 deletions(-) diff --git a/crates/proof-of-sql-parser/src/intermediate_ast.rs b/crates/proof-of-sql-parser/src/intermediate_ast.rs index d89696654..822d450c9 100644 --- a/crates/proof-of-sql-parser/src/intermediate_ast.rs +++ b/crates/proof-of-sql-parser/src/intermediate_ast.rs @@ -308,6 +308,21 @@ pub enum OrderByDirection { Desc, } +/// Extension trait for `OrderByDirection` to provide utility methods. +pub trait OrderByDirectionExt { + /// Converts `OrderByDirection` to `Option` for compatibility. + fn to_option_bool(&self) -> Option; +} + +impl OrderByDirectionExt for OrderByDirection { + fn to_option_bool(&self) -> Option { + match self { + OrderByDirection::Asc => Some(true), + OrderByDirection::Desc => Some(false), + } + } +} + impl Display for OrderByDirection { // This trait requires `fmt` with this exact signature. fn fmt(&self, f: &mut Formatter) -> fmt::Result { diff --git a/crates/proof-of-sql/src/base/database/order_by_util.rs b/crates/proof-of-sql/src/base/database/order_by_util.rs index 325ed49f9..b156e3d3d 100644 --- a/crates/proof-of-sql/src/base/database/order_by_util.rs +++ b/crates/proof-of-sql/src/base/database/order_by_util.rs @@ -5,7 +5,7 @@ use crate::base::{ }; use alloc::vec::Vec; use core::cmp::Ordering; -use proof_of_sql_parser::intermediate_ast::OrderByDirection; +use sqlparser::ast::{Expr, OrderByExpr}; /// Compares the tuples `(order_by[0][i], order_by[1][i], ...)` and /// `(order_by[0][j], order_by[1][j], ...)` in lexicographic order. @@ -110,22 +110,47 @@ pub(crate) fn compare_indexes_by_owned_columns( ) -> Ordering { let order_by_pairs = order_by .iter() - .map(|&col| (col.clone(), OrderByDirection::Asc)) + .map(|&col| { + ( + col.clone(), + OrderByExpr { + expr: owned_column_to_expr(col), + asc: Some(true), + nulls_first: None, + }, + ) + }) .collect::>(); compare_indexes_by_owned_columns_with_direction(&order_by_pairs, i, j) } +/// Converts an `OwnedColumn` into an SQL `Expr` for use in `OrderByExpr`. +fn owned_column_to_expr(col: &OwnedColumn) -> Expr { + match col { + OwnedColumn::Boolean(_) => Expr::Identifier("BooleanColumn".into()), + OwnedColumn::TinyInt(_) => Expr::Identifier("TinyIntColumn".into()), + OwnedColumn::SmallInt(_) => Expr::Identifier("SmallIntColumn".into()), + OwnedColumn::Int(_) => Expr::Identifier("IntColumn".into()), + OwnedColumn::BigInt(_) => Expr::Identifier("BigIntColumn".into()), + OwnedColumn::VarChar(_) => Expr::Identifier("VarCharColumn".into()), + OwnedColumn::Int128(_) => Expr::Identifier("Int128Column".into()), + OwnedColumn::Decimal75(_, _, _) => Expr::Identifier("DecimalColumn".into()), + OwnedColumn::Scalar(_) => Expr::Identifier("ScalarColumn".into()), + OwnedColumn::TimestampTZ(_, _, _) => Expr::Identifier("TimestampColumn".into()), + } +} + /// Compares the tuples `(left[0][i], left[1][i], ...)` and /// `(right[0][j], right[1][j], ...)` in lexicographic order. /// Note that direction flips the ordering. pub(crate) fn compare_indexes_by_owned_columns_with_direction( - order_by_pairs: &[(OwnedColumn, OrderByDirection)], + order_by_pairs: &[(OwnedColumn, OrderByExpr)], i: usize, j: usize, ) -> Ordering { order_by_pairs .iter() - .map(|(col, direction)| { + .map(|(col, order_by_expr)| { let ordering = match col { OwnedColumn::Boolean(col) => col[i].cmp(&col[j]), OwnedColumn::TinyInt(col) => col[i].cmp(&col[j]), @@ -139,9 +164,9 @@ pub(crate) fn compare_indexes_by_owned_columns_with_direction( OwnedColumn::Scalar(col) => col[i].cmp(&col[j]), OwnedColumn::VarChar(col) => col[i].cmp(&col[j]), }; - match direction { - OrderByDirection::Asc => ordering, - OrderByDirection::Desc => ordering.reverse(), + match order_by_expr.asc { + Some(false) => ordering.reverse(), // DESC + None | Some(true) => ordering, // Default to ASC or explicitly ASC } }) .find(|&ord| ord != Ordering::Equal) diff --git a/crates/proof-of-sql/src/base/database/order_by_util_test.rs b/crates/proof-of-sql/src/base/database/order_by_util_test.rs index 81a02862d..4f101269b 100644 --- a/crates/proof-of-sql/src/base/database/order_by_util_test.rs +++ b/crates/proof-of-sql/src/base/database/order_by_util_test.rs @@ -7,7 +7,7 @@ use crate::{ proof_primitive::dory::DoryScalar, }; use core::cmp::Ordering; -use proof_of_sql_parser::intermediate_ast::OrderByDirection; +use sqlparser::ast::{Expr, OrderByExpr}; #[test] fn we_can_compare_indexes_by_columns_with_no_columns() { @@ -267,9 +267,30 @@ fn we_can_compare_columns_with_direction() { .collect(), ); let order_by_pairs = vec![ - (col1, OrderByDirection::Asc), - (col2, OrderByDirection::Desc), - (col3, OrderByDirection::Asc), + ( + col1, + OrderByExpr { + expr: Expr::Identifier("SmallIntColumn".into()), + asc: Some(true), // Ascending + nulls_first: None, + }, + ), + ( + col2, + OrderByExpr { + expr: Expr::Identifier("VarCharColumn".into()), + asc: Some(false), // Descending + nulls_first: None, + }, + ), + ( + col3, + OrderByExpr { + expr: Expr::Identifier("DecimalColumn".into()), + asc: Some(true), // Ascending + nulls_first: None, + }, + ), ]; // Equal on col1 and col2, less on col3 assert_eq!( diff --git a/crates/proof-of-sql/src/sql/parse/query_context.rs b/crates/proof-of-sql/src/sql/parse/query_context.rs index 3c1b7c551..8f25fa111 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context.rs @@ -11,9 +11,9 @@ use crate::{ }; use alloc::{borrow::ToOwned, boxed::Box, string::ToString, vec::Vec}; use proof_of_sql_parser::intermediate_ast::{ - AggregationOperator, AliasedResultExpr, Expression, OrderBy, Slice, + AggregationOperator, AliasedResultExpr, Expression, Slice, }; -use sqlparser::ast::Ident; +use sqlparser::ast::{Expr, Ident, OrderByExpr}; #[derive(Default, Debug)] pub struct QueryContext { @@ -24,7 +24,7 @@ pub struct QueryContext { table: Option, in_result_scope: bool, has_visited_group_by: bool, - order_by_exprs: Vec, + order_by_exprs: Vec, group_by_exprs: Vec, where_expr: Option>, result_column_set: IndexSet, @@ -136,7 +136,7 @@ impl QueryContext { self.has_visited_group_by = true; } - pub fn set_order_by_exprs(&mut self, order_by_exprs: Vec) { + pub fn set_order_by_exprs(&mut self, order_by_exprs: Vec) { self.order_by_exprs = order_by_exprs; } @@ -195,14 +195,17 @@ impl QueryContext { Ok(&self.res_aliased_exprs) } - pub fn get_order_by_exprs(&self) -> ConversionResult> { + pub fn get_order_by_exprs(&self) -> ConversionResult> { // Order by must reference only aliases in the result schema for by_expr in &self.order_by_exprs { self.res_aliased_exprs .iter() - .find(|col| col.alias == by_expr.expr) + .find(|col| match &by_expr.expr { + Expr::Identifier(ident) => *ident.value == *col.alias, + _ => false, + }) .ok_or(ConversionError::InvalidOrderBy { - alias: by_expr.expr.as_str().to_string(), + alias: by_expr.expr.to_string(), })?; } diff --git a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs index 708cb8236..356f31990 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs @@ -12,8 +12,8 @@ use crate::base::{ use alloc::{boxed::Box, format, string::ToString, vec::Vec}; use proof_of_sql_parser::{ intermediate_ast::{ - AggregationOperator, AliasedResultExpr, Expression, Literal, OrderBy, SelectResultExpr, - Slice, TableExpression, + AggregationOperator, AliasedResultExpr, Expression, Literal, SelectResultExpr, Slice, + TableExpression, }, Identifier, ResourceId, }; @@ -22,7 +22,7 @@ pub struct QueryContextBuilder<'a> { context: QueryContext, schema_accessor: &'a dyn SchemaAccessor, } -use sqlparser::ast::Ident; +use sqlparser::ast::{Ident, OrderByExpr}; // Public interface impl<'a> QueryContextBuilder<'a> { @@ -77,7 +77,7 @@ impl<'a> QueryContextBuilder<'a> { Ok(self) } - pub fn visit_order_by_exprs(mut self, order_by_exprs: Vec) -> Self { + pub fn visit_order_by_exprs(mut self, order_by_exprs: Vec) -> Self { self.context.set_order_by_exprs(order_by_exprs); self } diff --git a/crates/proof-of-sql/src/sql/parse/query_expr.rs b/crates/proof-of-sql/src/sql/parse/query_expr.rs index 798406c45..cf264a4a9 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr.rs @@ -68,7 +68,7 @@ impl QueryExpr { .visit_group_by_exprs(group_by.into_iter().map(Ident::from).collect())? .visit_result_exprs(result_exprs)? .visit_where_expr(where_expr)? - .visit_order_by_exprs(ast.order_by) + .visit_order_by_exprs(ast.order_by.into_iter().map(Into::into).collect()) .visit_slice_expr(ast.slice) .build()?, }; diff --git a/crates/proof-of-sql/src/sql/postprocessing/error.rs b/crates/proof-of-sql/src/sql/postprocessing/error.rs index 054b07358..1986ac345 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/error.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/error.rs @@ -17,6 +17,12 @@ pub enum PostprocessingError { /// The column which is not found column: String, }, + /// Invalid expression encountered + #[snafu(display("Invalid expression encountered: {expression}"))] + InvalidExpression { + /// The invalid expression + expression: String, + }, /// Errors in evaluation of `Expression`s #[snafu(transparent)] ExpressionEvaluationError { diff --git a/crates/proof-of-sql/src/sql/postprocessing/order_by_postprocessing.rs b/crates/proof-of-sql/src/sql/postprocessing/order_by_postprocessing.rs index 8303e1faf..9911690ea 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/order_by_postprocessing.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/order_by_postprocessing.rs @@ -7,20 +7,20 @@ use crate::base::{ scalar::Scalar, }; use alloc::{string::ToString, vec::Vec}; -use proof_of_sql_parser::intermediate_ast::{OrderBy, OrderByDirection}; use serde::{Deserialize, Serialize}; +use sqlparser::ast::{Expr, OrderByExpr}; /// A node representing a list of `OrderBy` expressions. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct OrderByPostprocessing { - by_exprs: Vec, + order_by_exprs: Vec, } impl OrderByPostprocessing { /// Create a new `OrderByPostprocessing` node. #[must_use] - pub fn new(by_exprs: Vec) -> Self { - Self { by_exprs } + pub fn new(order_by_exprs: Vec) -> Self { + Self { order_by_exprs } } } @@ -29,25 +29,38 @@ impl PostprocessingStep for OrderByPostprocessing { fn apply(&self, owned_table: OwnedTable) -> PostprocessingResult> { // Evaluate the columns by which we order // Once we allow OrderBy for general aggregation-free expressions here we will need to call eval() - let order_by_pairs: Vec<(OwnedColumn, OrderByDirection)> = self - .by_exprs + let order_by_pairs: Vec<(OwnedColumn, OrderByExpr)> = self + .order_by_exprs .iter() .map( - |order_by| -> PostprocessingResult<(OwnedColumn, OrderByDirection)> { - let identifier: sqlparser::ast::Ident = order_by.expr.into(); - Ok(( - owned_table - .inner_table() - .get(&identifier) - .ok_or(PostprocessingError::ColumnNotFound { - column: order_by.expr.to_string(), - })? - .clone(), - order_by.direction, - )) + |order_by| -> PostprocessingResult<(OwnedColumn, OrderByExpr)> { + let identifier = match &order_by.expr { + Expr::Identifier(ident) => ident.clone(), + _ => { + return Err(PostprocessingError::InvalidExpression { + expression: order_by.expr.to_string(), + }); + } + }; + + let column = owned_table + .inner_table() + .get(&identifier) + .ok_or(PostprocessingError::ColumnNotFound { + column: order_by.expr.to_string(), + })? + .clone(); + + let order_by_expr = OrderByExpr { + expr: Expr::Identifier(identifier.clone()), + asc: order_by.asc, + nulls_first: order_by.nulls_first, + }; + + Ok((column, order_by_expr)) }, ) - .collect::, OrderByDirection)>>>()?; + .collect::, OrderByExpr)>>>()?; // Define the ordering let permutation = Permutation::unchecked_new_from_cmp(owned_table.num_rows(), |&a, &b| { compare_indexes_by_owned_columns_with_direction(&order_by_pairs, a, b) diff --git a/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs b/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs index 24f8904b4..fd09035fb 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs @@ -1,6 +1,8 @@ use super::*; -use proof_of_sql_parser::intermediate_ast::{AliasedResultExpr, OrderBy, OrderByDirection}; -use sqlparser::ast::Ident; +use proof_of_sql_parser::intermediate_ast::{ + AliasedResultExpr, OrderByDirection, OrderByDirectionExt, +}; +use sqlparser::ast::{Expr, Ident, OrderByExpr}; #[must_use] pub fn group_by_postprocessing( @@ -29,13 +31,19 @@ pub fn slice(limit: Option, offset: Option) -> OwnedTablePostprocessin #[must_use] pub fn orders(cols: &[&str], directions: &[OrderByDirection]) -> OwnedTablePostprocessing { - let by_exprs = cols + let directions_as_bool: Vec> = directions .iter() - .zip(directions.iter()) - .map(|(col, direction)| OrderBy { - expr: col.parse().unwrap(), - direction: *direction, + .map(OrderByDirectionExt::to_option_bool) + .collect(); + let order_by_exprs: Vec = cols + .iter() + .zip(directions_as_bool.iter()) + .map(|(col, &direction)| OrderByExpr { + expr: Expr::Identifier((*col).into()), + asc: direction, + nulls_first: None, }) .collect(); - OwnedTablePostprocessing::new_order_by(OrderByPostprocessing::new(by_exprs)) + + OwnedTablePostprocessing::new_order_by(OrderByPostprocessing::new(order_by_exprs)) }