From 6e4526586ebe65a7ecfb5ace6bf13e8aaf53500e Mon Sep 17 00:00:00 2001 From: Ian Alexander Joiner <14581281+iajoiner@users.noreply.github.com> Date: Tue, 13 Aug 2024 19:31:16 -0400 Subject: [PATCH] feat: impl `OwnedTablePostprocessingStep` for `GroupByPostprocessing` (#86) # Rationale for this change Please review #94 before reviewing #86 or alternatively check out the second commit. # What changes are included in this PR? Add `GroupByPostprocessing` # Are these changes tested? Yes --- .../src/base/database/group_by_util.rs | 3 +- .../src/base/database/owned_column_error.rs | 2 +- .../src/sql/postprocessing/error.rs | 6 + .../postprocessing/group_by_postprocessing.rs | 128 ++++++++++- .../group_by_postprocessing_test.rs | 202 +++++++++++++++++- .../owned_table_postprocessing.rs | 11 +- .../src/sql/postprocessing/test_utility.rs | 16 +- 7 files changed, 359 insertions(+), 9 deletions(-) diff --git a/crates/proof-of-sql/src/base/database/group_by_util.rs b/crates/proof-of-sql/src/base/database/group_by_util.rs index 50713c644..5ee3c03c4 100644 --- a/crates/proof-of-sql/src/base/database/group_by_util.rs +++ b/crates/proof-of-sql/src/base/database/group_by_util.rs @@ -11,6 +11,7 @@ use rayon::prelude::ParallelSliceMut; use thiserror::Error; /// The output of the `aggregate_columns` function. +#[derive(Debug)] pub struct AggregatedColumns<'a, S: Scalar> { /// The columns that are being grouped by. These are all unique and correspond to each group. /// This is effectively just the original group_by columns filtered by the selection. @@ -26,7 +27,7 @@ pub struct AggregatedColumns<'a, S: Scalar> { /// The number of rows in each group. pub count_column: &'a [i64], } -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq, Eq)] pub enum AggregateColumnsError { #[error("Column length mismatch")] ColumnLengthMismatch, diff --git a/crates/proof-of-sql/src/base/database/owned_column_error.rs b/crates/proof-of-sql/src/base/database/owned_column_error.rs index 0e247f5e8..fc7f4b1b8 100644 --- a/crates/proof-of-sql/src/base/database/owned_column_error.rs +++ b/crates/proof-of-sql/src/base/database/owned_column_error.rs @@ -2,7 +2,7 @@ use crate::base::database::ColumnType; use thiserror::Error; /// Errors from operations related to `OwnedColumn`s. -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq, Eq)] pub enum OwnedColumnError { /// Can not perform type casting. #[error("Can not perform type casting from {from_type:?} to {to_type:?}")] diff --git a/crates/proof-of-sql/src/sql/postprocessing/error.rs b/crates/proof-of-sql/src/sql/postprocessing/error.rs index 43e68d9ea..d880c8dff 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/error.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/error.rs @@ -19,6 +19,12 @@ pub enum PostprocessingError { /// GROUP BY clause references a column not in a group by expression outside aggregate functions #[error("Invalid group by: column '{0}' must not appear outside aggregate functions or `GROUP BY` clause.")] IdentifierNotInAggregationOperatorOrGroupByClause(Identifier), + /// Errors in aggregate columns + #[error(transparent)] + AggregateColumnsError(#[from] crate::base::database::group_by_util::AggregateColumnsError), + /// Errors in `OwnedColumn` + #[error(transparent)] + OwnedColumnError(#[from] crate::base::database::OwnedColumnError), /// Nested aggregation in `GROUP BY` clause #[error("Nested aggregation in `GROUP BY` clause: {0}")] NestedAggregationInGroupByClause(String), diff --git a/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing.rs b/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing.rs index 3ddf19c4c..87f849fca 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing.rs @@ -1,10 +1,17 @@ -use super::{PostprocessingError, PostprocessingResult}; +use super::{PostprocessingError, PostprocessingResult, PostprocessingStep}; +use crate::base::{ + database::{group_by_util::aggregate_columns, Column, OwnedColumn, OwnedTable}, + scalar::Scalar, +}; +use bumpalo::Bump; use indexmap::{IndexMap, IndexSet}; +use itertools::{izip, Itertools}; use proof_of_sql_parser::{ intermediate_ast::{AggregationOperator, AliasedResultExpr, Expression}, Identifier, }; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; /// A group by expression #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -146,9 +153,10 @@ impl GroupByPostprocessing { ) }) .collect::>>()?; + let group_by_identifiers = by_ids.into_iter().unique().collect(); Ok(Self { remainder_exprs, - group_by_identifiers: by_ids, + group_by_identifiers, aggregation_expr_map, }) } @@ -169,6 +177,122 @@ impl GroupByPostprocessing { } } +impl PostprocessingStep for GroupByPostprocessing { + /// Apply the group by transformation to the given `OwnedTable`. + fn apply(&self, owned_table: OwnedTable) -> PostprocessingResult> { + // First evaluate all the aggregated columns + let alloc = Bump::new(); + let evaluated_columns: HashMap)>> = + self.aggregation_expr_map + .iter() + .map(|((agg_op, expr), id)| -> PostprocessingResult<_> { + let evaluated_owned_column = owned_table.evaluate(expr)?; + Ok((*agg_op, (*id, evaluated_owned_column))) + }) + .process_results(|iter| iter.into_group_map())?; + // Next actually do the GROUP BY + let group_by_ins = self + .group_by_identifiers + .iter() + .map(|id| { + let column = owned_table + .inner_table() + .get(id) + .ok_or(PostprocessingError::ColumnNotFound(id.to_string()))?; + Ok(Column::::from_owned_column(column, &alloc)) + }) + .collect::>>()?; + // TODO: Allow a filter + let selection_in = vec![true; owned_table.num_rows()]; + let (sum_ids, sum_ins): (Vec<_>, Vec<_>) = evaluated_columns + .get(&AggregationOperator::Sum) + .map(|tuple| { + tuple + .iter() + .map(|(id, c)| (*id, Column::::from_owned_column(c, &alloc))) + .unzip() + }) + .unwrap_or((vec![], vec![])); + let (max_ids, max_ins): (Vec<_>, Vec<_>) = evaluated_columns + .get(&AggregationOperator::Max) + .map(|tuple| { + tuple + .iter() + .map(|(id, c)| (*id, Column::::from_owned_column(c, &alloc))) + .unzip() + }) + .unwrap_or((vec![], vec![])); + let (min_ids, min_ins): (Vec<_>, Vec<_>) = evaluated_columns + .get(&AggregationOperator::Min) + .map(|tuple| { + tuple + .iter() + .map(|(id, c)| (*id, Column::::from_owned_column(c, &alloc))) + .unzip() + }) + .unwrap_or((vec![], vec![])); + let aggregation_results = aggregate_columns( + &alloc, + &group_by_ins, + &sum_ins, + &max_ins, + &min_ins, + &selection_in, + )?; + // Finally do another round of evaluation to get the final result + // Gather the results into a new OwnedTable + let group_by_outs = aggregation_results + .group_by_columns + .iter() + .zip(self.group_by_identifiers.iter()) + .map(|(column, id)| Ok((*id, OwnedColumn::from(column)))); + let sum_outs = + izip!(aggregation_results.sum_columns, sum_ids, sum_ins,).map(|(c_out, id, c_in)| { + Ok(( + id, + OwnedColumn::try_from_scalars(c_out, c_in.column_type())?, + )) + }); + let max_outs = + izip!(aggregation_results.max_columns, max_ids, max_ins,).map(|(c_out, id, c_in)| { + Ok(( + id, + OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?, + )) + }); + let min_outs = + izip!(aggregation_results.min_columns, min_ids, min_ins,).map(|(c_out, id, c_in)| { + Ok(( + id, + OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?, + )) + }); + //TODO: When we have NULLs we need to differentiate between count(1) and count(expression) + let count_column = OwnedColumn::BigInt(aggregation_results.count_column.to_vec()); + let count_outs = evaluated_columns + .get(&AggregationOperator::Count) + .into_iter() + .flatten() + .map(|(id, _)| -> PostprocessingResult<_> { Ok((*id, count_column.clone())) }); + let new_owned_table: OwnedTable = group_by_outs + .into_iter() + .chain(sum_outs) + .chain(max_outs) + .chain(min_outs) + .chain(count_outs) + .process_results(|iter| OwnedTable::try_from_iter(iter))??; + let res = self + .remainder_exprs + .iter() + .map(|aliased_expr| -> PostprocessingResult<_> { + let column = new_owned_table.evaluate(&aliased_expr.expr)?; + Ok((aliased_expr.alias, column)) + }) + .process_results(|iter| OwnedTable::try_from_iter(iter))??; + Ok(res) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing_test.rs b/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing_test.rs index 1f68b9cbc..9fdf8012a 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing_test.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing_test.rs @@ -1,6 +1,17 @@ -use crate::sql::postprocessing::{group_by_postprocessing::*, PostprocessingError}; +use crate::{ + base::{ + database::{owned_table_utility::*, OwnedTable}, + scalar::Curve25519Scalar, + }, + sql::postprocessing::{ + apply_postprocessing_steps, group_by_postprocessing::*, test_utility::*, + OwnedTablePostprocessing, PostprocessingError, + }, +}; use indexmap::indexmap; -use proof_of_sql_parser::{intermediate_ast::AggregationOperator, utility::*}; +use proof_of_sql_parser::{ + intermediate_ast::AggregationOperator, intermediate_decimal::IntermediateDecimal, utility::*, +}; #[test] fn we_cannot_have_invalid_group_bys() { @@ -48,3 +59,190 @@ fn we_can_make_group_by_postprocessing() { } ); } + +#[test] +fn we_can_do_simple_group_bys() { + // SELECT MAX(a) as max_a, MIN(b) as min_b, SUM(c) as sum_c, COUNT(d) as count_d FROM tab + let table: OwnedTable = owned_table([ + int128("a", [1_i128, 2, 3, 4]), + bigint("b", [5_i64, 6, 7, 8]), + smallint("c", [9_i16, 10, 11, 12]), + varchar("d", ["Space", "and", "Time", "rocks"]), + ]); + let postprocessing: [OwnedTablePostprocessing; 1] = [group_by_postprocessing( + &[], + &[ + aliased_expr(max(col("a")), "max_a"), + aliased_expr(min(col("b")), "min_b"), + aliased_expr(sum(col("c")), "sum_c"), + aliased_expr(count(col("d")), "count_d"), + ], + )]; + let expected_table = owned_table([ + int128("max_a", [4_i128]), + bigint("min_b", [5_i64]), + smallint("sum_c", [42_i16]), + bigint("count_d", [4_i64]), + ]); + let actual_table = apply_postprocessing_steps(table, &postprocessing).unwrap(); + assert_eq!(actual_table, expected_table); + + // SELECT a, MIN(b) as min_b, SUM(c) as sum_c, COUNT(d) as count_d FROM tab GROUP BY a + let table: OwnedTable = owned_table([ + int128("a", [1_i128, 1, 2, 2]), + bigint("b", [5_i64, 6, 7, 8]), + smallint("c", [9_i16, 10, 11, 12]), + varchar("d", ["Space", "and", "Time", "rocks"]), + ]); + let postprocessing: [OwnedTablePostprocessing; 1] = [group_by_postprocessing( + &["a"], + &[ + aliased_expr(col("a"), "a"), + aliased_expr(min(col("b")), "min_b"), + aliased_expr(sum(col("c")), "sum_c"), + aliased_expr(count(col("d")), "count_d"), + ], + )]; + let expected_table = owned_table([ + int128("a", [1_i128, 2]), + bigint("min_b", [5_i64, 7]), + smallint("sum_c", [19_i16, 23]), + bigint("count_d", [2_i64, 2]), + ]); + let actual_table = apply_postprocessing_steps(table, &postprocessing).unwrap(); + assert_eq!(actual_table, expected_table); + + // SELECT a + b as res, SUM(c) as sum_c, COUNT(d) as count_d FROM tab GROUP BY a, b, a, b, b + let table: OwnedTable = owned_table([ + int128("a", [1_i128, 5, 5, 1]), + bigint("b", [1_i64, 2, 2, 2]), + smallint("c", [9_i16, 11, 12, 10]), + varchar("d", ["Space", "and", "Time", "rocks"]), + ]); + let postprocessing: [OwnedTablePostprocessing; 1] = [group_by_postprocessing( + &["a", "b", "a", "b", "b"], + &[ + aliased_expr(add(col("a"), col("b")), "res"), + aliased_expr(sum(col("c")), "sum_c"), + aliased_expr(count(col("d")), "count_d"), + ], + )]; + let expected_table: OwnedTable = owned_table([ + int128("res", [2_i128, 3, 7]), + smallint("sum_c", [9_i16, 10, 23]), + bigint("count_d", [1_i64, 1, 2]), + ]); + let actual_table = apply_postprocessing_steps(table, &postprocessing).unwrap(); + assert_eq!(actual_table, expected_table); +} + +#[test] +fn we_can_do_complex_group_bys() { + // SELECT 2 * MAX(2 * a + 1) as max_a, MIN(b + 4) - 2.4 as min_b, SUM(c * 1.4) as sum_c, COUNT(d) + 3 as count_d FROM tab + let table: OwnedTable = owned_table([ + int128("a", [1_i128, 2, 3, 4]), + bigint("b", [5_i64, 6, 7, 8]), + smallint("c", [9_i16, 10, 11, 12]), + varchar("d", ["Space", "and", "Time", "rocks"]), + ]); + let postprocessing: [OwnedTablePostprocessing; 1] = [group_by_postprocessing( + &[], + &[ + aliased_expr( + mul(lit(2), max(add(mul(lit(2), col("a")), lit(1)))), + "max_a", + ), + aliased_expr( + sub( + min(add(col("b"), lit(4))), + lit("2.4".parse::().unwrap()), + ), + "min_b", + ), + aliased_expr( + sum(mul( + col("c"), + lit("1.4".parse::().unwrap()), + )), + "sum_c", + ), + aliased_expr(add(count(col("d")), lit(3)), "count_d"), + ], + )]; + let expected_table = owned_table([ + int128("max_a", [18_i128]), + decimal75("min_b", 21, 1, [66]), + decimal75("sum_c", 8, 1, [588]), + bigint("count_d", [7_i64]), + ]); + let actual_table = apply_postprocessing_steps(table, &postprocessing).unwrap(); + assert_eq!(actual_table, expected_table); + + // SELECT count(a + 2.5) + 2 as count_a, 2 * (MAX(2 * c + 1) + SUM(2.5 * d)) as res, count(d) - 1 as count_d_alt, MIN(b + 2.4) - 3.4 as min_b, SUM(c * 1.7) as sum_c, COUNT(d) - 3 as count_d, COUNT(e) as count_e FROM tab group by a, a, a, a + let table: OwnedTable = owned_table([ + int128("a", [1_i128, 1, 1, 2]), + bigint("b", [5_i64, 6, 7, 8]), + smallint("c", [9_i16, 10, 11, 12]), + decimal75("d", 2, 1, [13, 14, 15, 16]), + varchar("e", ["Space", "and", "Time", "rocks"]), + ]); + let postprocessing: [OwnedTablePostprocessing; 1] = [group_by_postprocessing( + &["a", "a", "a", "a"], + &[ + aliased_expr( + add( + count(add( + col("a"), + lit("2.5".parse::().unwrap()), + )), + lit(2), + ), + "count_a", + ), + aliased_expr( + mul( + lit(2), + add( + max(add(mul(lit(2), col("c")), lit(1))), + sum(mul( + lit("2.5".parse::().unwrap()), + col("d"), + )), + ), + ), + "res", + ), + aliased_expr(sub(count(col("d")), lit(1)), "count_d_alt"), + aliased_expr( + sub( + min(add( + col("b"), + lit("2.4".parse::().unwrap()), + )), + lit("3.4".parse::().unwrap()), + ), + "min_b", + ), + aliased_expr( + sum(mul( + col("c"), + lit("1.7".parse::().unwrap()), + )), + "sum_c", + ), + aliased_expr(sub(count(col("d")), lit(3)), "count_d"), + aliased_expr(count(col("e")), "count_e"), + ], + )]; + let expected_table = owned_table([ + bigint("count_a", [5_i64, 3]), + decimal75("res", 42, 2, [6700, 5800]), + bigint("count_d_alt", [2_i64, 0]), + decimal75("min_b", 22, 1, [40, 70]), + decimal75("sum_c", 8, 1, [510, 204]), + bigint("count_d", [0_i64, -2]), + bigint("count_e", [3_i64, 1]), + ]); + let actual_table = apply_postprocessing_steps(table, &postprocessing).unwrap(); + assert_eq!(actual_table, expected_table); +} diff --git a/crates/proof-of-sql/src/sql/postprocessing/owned_table_postprocessing.rs b/crates/proof-of-sql/src/sql/postprocessing/owned_table_postprocessing.rs index 7dfa8e04a..6129c388f 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/owned_table_postprocessing.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/owned_table_postprocessing.rs @@ -1,6 +1,6 @@ use super::{ - OrderByPostprocessing, PostprocessingResult, PostprocessingStep, SelectPostprocessing, - SlicePostprocessing, + GroupByPostprocessing, OrderByPostprocessing, PostprocessingResult, PostprocessingStep, + SelectPostprocessing, SlicePostprocessing, }; use crate::base::{database::OwnedTable, scalar::Scalar}; @@ -13,6 +13,8 @@ pub enum OwnedTablePostprocessing { OrderBy(OrderByPostprocessing), /// Select the `OwnedTable` with the given `SelectPostprocessing`. Select(SelectPostprocessing), + /// Aggregate the `OwnedTable` with the given `GroupByPostprocessing`. + GroupBy(GroupByPostprocessing), } impl PostprocessingStep for OwnedTablePostprocessing { @@ -22,6 +24,7 @@ impl PostprocessingStep for OwnedTablePostprocessing { OwnedTablePostprocessing::Slice(slice_expr) => slice_expr.apply(owned_table), OwnedTablePostprocessing::OrderBy(order_by_expr) => order_by_expr.apply(owned_table), OwnedTablePostprocessing::Select(select_expr) => select_expr.apply(owned_table), + OwnedTablePostprocessing::GroupBy(group_by_expr) => group_by_expr.apply(owned_table), } } } @@ -39,6 +42,10 @@ impl OwnedTablePostprocessing { pub fn new_select(select_expr: SelectPostprocessing) -> Self { Self::Select(select_expr) } + /// Create a new `OwnedTablePostprocessing` with the given `GroupByPostprocessing`. + pub fn new_group_by(group_by_postprocessing: GroupByPostprocessing) -> Self { + Self::GroupBy(group_by_postprocessing) + } } /// Apply a list of postprocessing steps to an `OwnedTable`. diff --git a/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs b/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs index 144aef2fa..9519a6e68 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs @@ -1,5 +1,19 @@ use super::*; -use proof_of_sql_parser::intermediate_ast::{AliasedResultExpr, OrderBy, OrderByDirection}; +use proof_of_sql_parser::{ + intermediate_ast::{AliasedResultExpr, OrderBy, OrderByDirection}, + utility::ident, + Identifier, +}; + +pub fn group_by_postprocessing( + cols: &[&str], + result_exprs: &[AliasedResultExpr], +) -> OwnedTablePostprocessing { + let ids: Vec = cols.iter().map(|col| ident(col)).collect(); + OwnedTablePostprocessing::new_group_by( + GroupByPostprocessing::try_new(ids, result_exprs.to_vec()).unwrap(), + ) +} pub fn select_expr(result_exprs: &[AliasedResultExpr]) -> OwnedTablePostprocessing { OwnedTablePostprocessing::new_select(SelectPostprocessing::new(result_exprs.to_vec()))