Skip to content

Commit

Permalink
refactor!: proof_of_sql_parser::intermediate_ast::OrderBy with `sql…
Browse files Browse the repository at this point in the history
…parser::ast::OrderByExpr` in the proof-of-sql crate
  • Loading branch information
varshith257 committed Jan 9, 2025
1 parent 8baa9df commit 4775126
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 50 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* text=auto
15 changes: 15 additions & 0 deletions crates/proof-of-sql-parser/src/intermediate_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,21 @@ pub enum OrderByDirection {
Desc,
}

/// Extension trait for `OrderByDirection` to provide utility methods.
///
/// Its single method exposes the direction as an `Option<bool>`, the same shape as the
/// `asc` field that `sqlparser::ast::OrderByExpr` values are built with.
pub trait OrderByDirectionExt {
/// Converts `OrderByDirection` to `Option<bool>` for compatibility.
///
/// The implementation for `OrderByDirection` returns `Some(true)` for `Asc` and
/// `Some(false)` for `Desc`.
fn to_option_bool(&self) -> Option<bool>;
}

impl OrderByDirectionExt for OrderByDirection {
    /// Maps the direction onto an "is ascending" flag wrapped in `Some`:
    /// `Asc` becomes `Some(true)` and `Desc` becomes `Some(false)`.
    fn to_option_bool(&self) -> Option<bool> {
        // Keep the match exhaustive so that adding a variant forces this mapping to be revisited.
        let ascending = match *self {
            Self::Asc => true,
            Self::Desc => false,
        };
        Some(ascending)
    }
}

impl Display for OrderByDirection {
// This trait requires `fmt` with this exact signature.
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
Expand Down
39 changes: 32 additions & 7 deletions crates/proof-of-sql/src/base/database/order_by_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::base::{
};
use alloc::vec::Vec;
use core::cmp::Ordering;
use proof_of_sql_parser::intermediate_ast::OrderByDirection;
use sqlparser::ast::{Expr, OrderByExpr};

/// Compares the tuples `(order_by[0][i], order_by[1][i], ...)` and
/// `(order_by[0][j], order_by[1][j], ...)` in lexicographic order.
Expand Down Expand Up @@ -110,22 +110,47 @@ pub(crate) fn compare_indexes_by_owned_columns<S: Scalar>(
) -> Ordering {
let order_by_pairs = order_by
.iter()
.map(|&col| (col.clone(), OrderByDirection::Asc))
.map(|&col| {
(
col.clone(),
OrderByExpr {
expr: owned_column_to_expr(col),
asc: Some(true),
nulls_first: None,
},
)
})
.collect::<Vec<_>>();
compare_indexes_by_owned_columns_with_direction(&order_by_pairs, i, j)
}

/// Builds a placeholder SQL identifier `Expr` for an `OwnedColumn`, for use in `OrderByExpr`.
///
/// The identifier text is fixed per column variant (for example `"BigIntColumn"`); it is
/// derived only from the variant, not from the column's contents.
fn owned_column_to_expr<S: Scalar>(col: &OwnedColumn<S>) -> Expr {
    // Choose the variant-specific placeholder name first, then wrap it exactly once.
    let placeholder = match col {
        OwnedColumn::Boolean(..) => "BooleanColumn",
        OwnedColumn::TinyInt(..) => "TinyIntColumn",
        OwnedColumn::SmallInt(..) => "SmallIntColumn",
        OwnedColumn::Int(..) => "IntColumn",
        OwnedColumn::BigInt(..) => "BigIntColumn",
        OwnedColumn::VarChar(..) => "VarCharColumn",
        OwnedColumn::Int128(..) => "Int128Column",
        OwnedColumn::Decimal75(..) => "DecimalColumn",
        OwnedColumn::Scalar(..) => "ScalarColumn",
        OwnedColumn::TimestampTZ(..) => "TimestampColumn",
    };
    Expr::Identifier(placeholder.into())
}

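/// Compares row `i` and row `j` of the paired columns, i.e. the tuples
/// `(order_by_pairs[0].0[i], order_by_pairs[1].0[i], ...)` and
/// `(order_by_pairs[0].0[j], order_by_pairs[1].0[j], ...)`, in lexicographic order.
/// Note that a descending direction reverses the ordering of its column.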
pub(crate) fn compare_indexes_by_owned_columns_with_direction<S: Scalar>(
order_by_pairs: &[(OwnedColumn<S>, OrderByDirection)],
order_by_pairs: &[(OwnedColumn<S>, OrderByExpr)],
i: usize,
j: usize,
) -> Ordering {
order_by_pairs
.iter()
.map(|(col, direction)| {
.map(|(col, order_by_expr)| {
let ordering = match col {
OwnedColumn::Boolean(col) => col[i].cmp(&col[j]),
OwnedColumn::TinyInt(col) => col[i].cmp(&col[j]),
Expand All @@ -139,9 +164,9 @@ pub(crate) fn compare_indexes_by_owned_columns_with_direction<S: Scalar>(
OwnedColumn::Scalar(col) => col[i].cmp(&col[j]),
OwnedColumn::VarChar(col) => col[i].cmp(&col[j]),
};
match direction {
OrderByDirection::Asc => ordering,
OrderByDirection::Desc => ordering.reverse(),
match order_by_expr.asc {
Some(false) => ordering.reverse(), // DESC
None | Some(true) => ordering, // Default to ASC or explicitly ASC
}
})
.find(|&ord| ord != Ordering::Equal)
Expand Down
29 changes: 25 additions & 4 deletions crates/proof-of-sql/src/base/database/order_by_util_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
proof_primitive::dory::DoryScalar,
};
use core::cmp::Ordering;
use proof_of_sql_parser::intermediate_ast::OrderByDirection;
use sqlparser::ast::{Expr, OrderByExpr};

#[test]
fn we_can_compare_indexes_by_columns_with_no_columns() {
Expand Down Expand Up @@ -267,9 +267,30 @@ fn we_can_compare_columns_with_direction() {
.collect(),
);
let order_by_pairs = vec![
(col1, OrderByDirection::Asc),
(col2, OrderByDirection::Desc),
(col3, OrderByDirection::Asc),
(
col1,
OrderByExpr {
expr: Expr::Identifier("SmallIntColumn".into()),
asc: Some(true), // Ascending
nulls_first: None,
},
),
(
col2,
OrderByExpr {
expr: Expr::Identifier("VarCharColumn".into()),
asc: Some(false), // Descending
nulls_first: None,
},
),
(
col3,
OrderByExpr {
expr: Expr::Identifier("DecimalColumn".into()),
asc: Some(true), // Ascending
nulls_first: None,
},
),
];
// Equal on col1 and col2, less on col3
assert_eq!(
Expand Down
17 changes: 10 additions & 7 deletions crates/proof-of-sql/src/sql/parse/query_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ use crate::{
};
use alloc::{borrow::ToOwned, boxed::Box, string::ToString, vec::Vec};
use proof_of_sql_parser::intermediate_ast::{
AggregationOperator, AliasedResultExpr, Expression, OrderBy, Slice,
AggregationOperator, AliasedResultExpr, Expression, Slice,
};
use sqlparser::ast::Ident;
use sqlparser::ast::{Expr, Ident, OrderByExpr};

#[derive(Default, Debug)]
pub struct QueryContext {
Expand All @@ -24,7 +24,7 @@ pub struct QueryContext {
table: Option<TableRef>,
in_result_scope: bool,
has_visited_group_by: bool,
order_by_exprs: Vec<OrderBy>,
order_by_exprs: Vec<OrderByExpr>,
group_by_exprs: Vec<Ident>,
where_expr: Option<Box<Expression>>,
result_column_set: IndexSet<Ident>,
Expand Down Expand Up @@ -136,7 +136,7 @@ impl QueryContext {
self.has_visited_group_by = true;
}

pub fn set_order_by_exprs(&mut self, order_by_exprs: Vec<OrderBy>) {
pub fn set_order_by_exprs(&mut self, order_by_exprs: Vec<OrderByExpr>) {
self.order_by_exprs = order_by_exprs;
}

Expand Down Expand Up @@ -195,14 +195,17 @@ impl QueryContext {
Ok(&self.res_aliased_exprs)
}

pub fn get_order_by_exprs(&self) -> ConversionResult<Vec<OrderBy>> {
pub fn get_order_by_exprs(&self) -> ConversionResult<Vec<OrderByExpr>> {
// Order by must reference only aliases in the result schema
for by_expr in &self.order_by_exprs {
self.res_aliased_exprs
.iter()
.find(|col| col.alias == by_expr.expr)
.find(|col| match &by_expr.expr {
Expr::Identifier(ident) => *ident.value == *col.alias,
_ => false,
})
.ok_or(ConversionError::InvalidOrderBy {
alias: by_expr.expr.as_str().to_string(),
alias: by_expr.expr.to_string(),
})?;
}

Expand Down
8 changes: 4 additions & 4 deletions crates/proof-of-sql/src/sql/parse/query_context_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use crate::base::{
use alloc::{boxed::Box, format, string::ToString, vec::Vec};
use proof_of_sql_parser::{
intermediate_ast::{
AggregationOperator, AliasedResultExpr, Expression, Literal, OrderBy, SelectResultExpr,
Slice, TableExpression,
AggregationOperator, AliasedResultExpr, Expression, Literal, SelectResultExpr, Slice,
TableExpression,
},
Identifier, ResourceId,
};
Expand All @@ -22,7 +22,7 @@ pub struct QueryContextBuilder<'a> {
context: QueryContext,
schema_accessor: &'a dyn SchemaAccessor,
}
use sqlparser::ast::Ident;
use sqlparser::ast::{Ident, OrderByExpr};

// Public interface
impl<'a> QueryContextBuilder<'a> {
Expand Down Expand Up @@ -77,7 +77,7 @@ impl<'a> QueryContextBuilder<'a> {
Ok(self)
}

pub fn visit_order_by_exprs(mut self, order_by_exprs: Vec<OrderBy>) -> Self {
pub fn visit_order_by_exprs(mut self, order_by_exprs: Vec<OrderByExpr>) -> Self {
self.context.set_order_by_exprs(order_by_exprs);
self
}
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/parse/query_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl QueryExpr {
.visit_group_by_exprs(group_by.into_iter().map(Ident::from).collect())?
.visit_result_exprs(result_exprs)?
.visit_where_expr(where_expr)?
.visit_order_by_exprs(ast.order_by)
.visit_order_by_exprs(ast.order_by.into_iter().map(Into::into).collect())
.visit_slice_expr(ast.slice)
.build()?,
};
Expand Down
6 changes: 6 additions & 0 deletions crates/proof-of-sql/src/sql/postprocessing/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ pub enum PostprocessingError {
/// The column which is not found
column: String,
},
/// Invalid expression encountered
#[snafu(display("Invalid expression encountered: {expression}"))]
InvalidExpression {
/// The invalid expression
expression: String,
},
/// Errors in evaluation of `Expression`s
#[snafu(transparent)]
ExpressionEvaluationError {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@ use crate::base::{
scalar::Scalar,
};
use alloc::{string::ToString, vec::Vec};
use proof_of_sql_parser::intermediate_ast::{OrderBy, OrderByDirection};
use serde::{Deserialize, Serialize};
use sqlparser::ast::{Expr, OrderByExpr};

/// A node representing a list of `ORDER BY` expressions (`OrderByExpr`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByPostprocessing {
by_exprs: Vec<OrderBy>,
order_by_exprs: Vec<OrderByExpr>,
}

impl OrderByPostprocessing {
/// Create a new `OrderByPostprocessing` node.
#[must_use]
pub fn new(by_exprs: Vec<OrderBy>) -> Self {
Self { by_exprs }
pub fn new(order_by_exprs: Vec<OrderByExpr>) -> Self {
Self { order_by_exprs }
}
}

Expand All @@ -29,25 +29,38 @@ impl<S: Scalar> PostprocessingStep<S> for OrderByPostprocessing {
fn apply(&self, owned_table: OwnedTable<S>) -> PostprocessingResult<OwnedTable<S>> {
// Evaluate the columns by which we order
// Once we allow OrderBy for general aggregation-free expressions here we will need to call eval()
let order_by_pairs: Vec<(OwnedColumn<S>, OrderByDirection)> = self
.by_exprs
let order_by_pairs: Vec<(OwnedColumn<S>, OrderByExpr)> = self
.order_by_exprs
.iter()
.map(
|order_by| -> PostprocessingResult<(OwnedColumn<S>, OrderByDirection)> {
let identifier: sqlparser::ast::Ident = order_by.expr.into();
Ok((
owned_table
.inner_table()
.get(&identifier)
.ok_or(PostprocessingError::ColumnNotFound {
column: order_by.expr.to_string(),
})?
.clone(),
order_by.direction,
))
|order_by| -> PostprocessingResult<(OwnedColumn<S>, OrderByExpr)> {
let identifier = match &order_by.expr {
Expr::Identifier(ident) => ident.clone(),
_ => {
return Err(PostprocessingError::InvalidExpression {
expression: order_by.expr.to_string(),
});
}
};

let column = owned_table
.inner_table()
.get(&identifier)
.ok_or(PostprocessingError::ColumnNotFound {
column: order_by.expr.to_string(),
})?
.clone();

let order_by_expr = OrderByExpr {
expr: Expr::Identifier(identifier.clone()),
asc: order_by.asc,
nulls_first: order_by.nulls_first,
};

Ok((column, order_by_expr))
},
)
.collect::<PostprocessingResult<Vec<(OwnedColumn<S>, OrderByDirection)>>>()?;
.collect::<PostprocessingResult<Vec<(OwnedColumn<S>, OrderByExpr)>>>()?;
// Define the ordering
let permutation = Permutation::unchecked_new_from_cmp(owned_table.num_rows(), |&a, &b| {
compare_indexes_by_owned_columns_with_direction(&order_by_pairs, a, b)
Expand Down
24 changes: 16 additions & 8 deletions crates/proof-of-sql/src/sql/postprocessing/test_utility.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::*;
use proof_of_sql_parser::intermediate_ast::{AliasedResultExpr, OrderBy, OrderByDirection};
use sqlparser::ast::Ident;
use proof_of_sql_parser::intermediate_ast::{
AliasedResultExpr, OrderByDirection, OrderByDirectionExt,
};
use sqlparser::ast::{Expr, Ident, OrderByExpr};

#[must_use]
pub fn group_by_postprocessing(
Expand Down Expand Up @@ -29,13 +31,19 @@ pub fn slice(limit: Option<u64>, offset: Option<i64>) -> OwnedTablePostprocessin

#[must_use]
pub fn orders(cols: &[&str], directions: &[OrderByDirection]) -> OwnedTablePostprocessing {
let by_exprs = cols
let directions_as_bool: Vec<Option<bool>> = directions
.iter()
.zip(directions.iter())
.map(|(col, direction)| OrderBy {
expr: col.parse().unwrap(),
direction: *direction,
.map(OrderByDirectionExt::to_option_bool)
.collect();
let order_by_exprs: Vec<OrderByExpr> = cols
.iter()
.zip(directions_as_bool.iter())
.map(|(col, &direction)| OrderByExpr {
expr: Expr::Identifier((*col).into()),
asc: direction,
nulls_first: None,
})
.collect();
OwnedTablePostprocessing::new_order_by(OrderByPostprocessing::new(by_exprs))

OwnedTablePostprocessing::new_order_by(OrderByPostprocessing::new(order_by_exprs))
}

0 comments on commit 4775126

Please sign in to comment.