Skip to content

Commit

Permalink
feat!: use sql::postprocessing instead of sql::transform in `Quer…
Browse files Browse the repository at this point in the history
…yExpr` (#101)
  • Loading branch information
iajoiner authored Aug 16, 2024
1 parent 6e45265 commit 0f008d7
Show file tree
Hide file tree
Showing 12 changed files with 535 additions and 450 deletions.
5 changes: 5 additions & 0 deletions crates/proof-of-sql-parser/src/utility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ pub fn count(expr: Box<Expression>) -> Box<Expression> {
})
}

/// Count the rows
pub fn count_all() -> Box<Expression> {
count(Box::new(Expression::Wildcard))
}

/// An expression with an alias i.e. EXPR AS ALIAS
pub fn aliased_expr(expr: Box<Expression>, alias: &str) -> AliasedResultExpr {
AliasedResultExpr {
Expand Down
4 changes: 4 additions & 0 deletions crates/proof-of-sql/src/sql/parse/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ pub enum ConversionError {
#[error(transparent)]
ColumnOperationError(#[from] ColumnOperationError),

/// Errors related to postprocessing
#[error(transparent)]
PostprocessingError(#[from] crate::sql::postprocessing::PostprocessingError),

#[error("Query not provable because: {0}")]
/// Query requires unprovable feature
Unprovable(String),
Expand Down
3 changes: 0 additions & 3 deletions crates/proof-of-sql/src/sql/parse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ mod query_expr_tests;
mod query_expr;
pub use query_expr::QueryExpr;

mod result_expr_builder;
pub(crate) use result_expr_builder::ResultExprBuilder;

mod filter_expr_builder;
pub(crate) use filter_expr_builder::FilterExprBuilder;

Expand Down
49 changes: 14 additions & 35 deletions crates/proof-of-sql/src/sql/parse/query_context.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
base::{
commitment::Commitment,
database::{ColumnRef, ColumnType, LiteralValue, TableRef},
database::{ColumnRef, LiteralValue, TableRef},
},
sql::{
ast::{AliasedProvableExprPlan, ColumnExpr, GroupByExpr, ProvableExprPlan, TableExpr},
Expand Down Expand Up @@ -91,6 +91,10 @@ impl QueryContext {
self.in_agg_scope
}

pub(crate) fn has_agg(&self) -> bool {
self.agg_counter > 0 || !self.group_by_exprs.is_empty()
}

pub fn push_column_ref(&mut self, column: Identifier, column_ref: ColumnRef) {
self.col_ref_counter += 1;
self.push_result_column_ref(column);
Expand Down Expand Up @@ -130,18 +134,6 @@ impl QueryContext {
self.order_by_exprs = order_by_exprs;
}

pub fn get_any_result_column_ref(&self) -> Option<(Identifier, ColumnType)> {
// For tests to work we need to make it deterministic by sorting the columns
// In the long run we simply need to let * be *
// and get rid of this workaround altogether
let mut columns = self.result_column_set.iter().collect::<Vec<_>>();
columns.sort();
columns.first().map(|c| {
let column = self.column_mapping[*c];
(column.column_id(), *column.column_type())
})
}

pub fn is_in_group_by_exprs(&self, column: &Identifier) -> ConversionResult<bool> {
// Non-aggregated result column references must be included in the group by statement.
if self.group_by_exprs.is_empty() || self.is_in_agg_scope() || !self.is_in_result_scope() {
Expand Down Expand Up @@ -263,18 +255,8 @@ impl<C: Commitment> TryFrom<&QueryContext> for Option<GroupByExpr<C>> {
.iter()
.zip(res_group_by_columns.iter())
.all(|(ident, res)| {
//TODO: This is due to a workaround related to polars
//Need to remove it when possible (PROOF-850)
if let Expression::Aggregation {
op: AggregationOperator::First,
expr,
} = (*res.expr).clone()
{
if let Expression::Column(res_ident) = *expr {
res_ident == *ident
} else {
false
}
if let Expression::Column(res_ident) = *res.expr {
res_ident == *ident
} else {
false
}
Expand Down Expand Up @@ -305,16 +287,13 @@ impl<C: Commitment> TryFrom<&QueryContext> for Option<GroupByExpr<C>> {

// Check count(*)
let count_column = &value.res_aliased_exprs[num_result_columns - 1];
let count_column_compliant = if let Expression::Aggregation {
op: AggregationOperator::Count,
expr,
} = (*count_column.expr).clone()
{
//TODO: This is due to a workaround related to polars
matches!(*expr, Expression::Column(_))
} else {
false
};
let count_column_compliant = matches!(
*count_column.expr,
Expression::Aggregation {
op: AggregationOperator::Count,
..
}
);

if !group_by_compliance || sum_expr.is_none() || !count_column_compliant {
return Ok(None);
Expand Down
48 changes: 10 additions & 38 deletions crates/proof-of-sql/src/sql/parse/query_context_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use proof_of_sql_parser::{
},
Identifier, ResourceId,
};
use std::ops::Deref;

pub struct QueryContextBuilder<'a> {
context: QueryContext,
Expand Down Expand Up @@ -117,65 +116,38 @@ impl<'a> QueryContextBuilder<'a> {
Ok(())
}

fn visit_aliased_expr(&mut self, mut aliased_expr: AliasedResultExpr) -> ConversionResult<()> {
self.visit_expr(aliased_expr.expr.as_mut())?;
fn visit_aliased_expr(&mut self, aliased_expr: AliasedResultExpr) -> ConversionResult<()> {
self.visit_expr(&aliased_expr.expr)?;
self.context.push_aliased_result_expr(aliased_expr)?;
Ok(())
}

/// Visits the expression and returns its data type.
///
/// This function accepts the expression as a mutable reference because certain expressions
/// require replacement, such as `count(*)` being replaced with `count(some_column)`.
fn visit_expr(&mut self, expr: &mut Expression) -> ConversionResult<ColumnType> {
fn visit_expr(&mut self, expr: &Expression) -> ConversionResult<ColumnType> {
match expr {
Expression::Wildcard => self.visit_wildcard_expr(expr),
Expression::Literal(literal) => self.visit_literal(literal.deref()),
Expression::Wildcard => Ok(ColumnType::BigInt), // Since COUNT(*) = COUNT(1)
Expression::Literal(literal) => self.visit_literal(literal),
Expression::Column(_) => self.visit_column_expr(expr),
Expression::Unary { op, expr } => self.visit_unary_expr(op, expr),
Expression::Binary { op, left, right } => self.visit_binary_expr(op, left, right),
Expression::Aggregation { op, expr } => self.visit_agg_expr(op, expr),
}
}

//TODO: Actually support multicolumn expressions
fn visit_wildcard_expr(&mut self, expr: &mut Expression) -> ConversionResult<ColumnType> {
let (col_name, col_type) = match self.context.get_any_result_column_ref() {
Some((name, col_type)) => (name, col_type),
None => self.lookup_schema().into_iter().next().unwrap(),
};

// Replace `count(*)` with `count(col_name)` to overcome limitations in Polars.
*expr = Expression::Column(col_name);

// Visit the column to ensure its inclusion in the result column set.
self.visit_column_expr(expr)?;

// Return the column type
Ok(col_type)
}

fn visit_column_expr(&mut self, expr: &mut Expression) -> ConversionResult<ColumnType> {
fn visit_column_expr(&mut self, expr: &Expression) -> ConversionResult<ColumnType> {
let identifier = match expr {
Expression::Column(identifier) => *identifier,
_ => panic!("Must be a column expression"),
};

// When using `group by` clauses, result columns outside aggregation
// need to be remapped to an aggregation function. This prevents Polars
// from returning lists when the expected result is single elements.
if self.context.is_in_group_by_exprs(&identifier)? {
*expr = *Expression::Column(identifier).first();
}

self.visit_column_identifier(identifier)
}

fn visit_binary_expr(
&mut self,
op: &BinaryOperator,
left: &mut Expression,
right: &mut Expression,
left: &Expression,
right: &Expression,
) -> ConversionResult<ColumnType> {
let left_dtype = self.visit_expr(left)?;
let right_dtype = self.visit_expr(right)?;
Expand All @@ -196,7 +168,7 @@ impl<'a> QueryContextBuilder<'a> {
fn visit_unary_expr(
&mut self,
op: &UnaryOperator,
expr: &mut Expression,
expr: &Expression,
) -> ConversionResult<ColumnType> {
match op {
UnaryOperator::Not => {
Expand All @@ -215,7 +187,7 @@ impl<'a> QueryContextBuilder<'a> {
fn visit_agg_expr(
&mut self,
op: &AggregationOperator,
expr: &mut Expression,
expr: &Expression,
) -> ConversionResult<ColumnType> {
self.context.set_in_agg_scope(true)?;

Expand Down
Loading

0 comments on commit 0f008d7

Please sign in to comment.