feat: impl OwnedTablePostprocessingStep for GroupByPostprocessing (#86)

# Rationale for this change
Please review #94 before reviewing #86, or alternatively check out the second commit.

# What changes are included in this PR?
Add `GroupByPostprocessing` and implement `PostprocessingStep` for it, so that a `GROUP BY` clause can be evaluated directly against an `OwnedTable` as a postprocessing step. A sketch of the resulting call shape follows.
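A minimal sketch of the intended usage. The constructor name `try_new` and its argument types are assumptions (the diff only excerpts a constructor returning `Ok(Self { ... })`), and the module paths are inferred from the crate layout shown below:

use proof_of_sql::{
    base::{database::OwnedTable, scalar::Scalar},
    sql::postprocessing::{GroupByPostprocessing, PostprocessingResult, PostprocessingStep},
};
use proof_of_sql_parser::{intermediate_ast::AliasedResultExpr, Identifier};

// Hypothetical sketch: `by_ids` and `aliased_exprs` come from the parsed query,
// and `table` is the OwnedTable to be grouped.
fn run_group_by<S: Scalar>(
    by_ids: Vec<Identifier>,
    aliased_exprs: Vec<AliasedResultExpr>,
    table: OwnedTable<S>,
) -> PostprocessingResult<OwnedTable<S>> {
    let step = GroupByPostprocessing::try_new(by_ids, aliased_exprs)?;
    step.apply(table)
}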

# Are these changes tested?
Yes
iajoiner authored Aug 13, 2024
1 parent 1967d3b commit 6e45265
Showing 7 changed files with 359 additions and 9 deletions.
3 changes: 2 additions & 1 deletion crates/proof-of-sql/src/base/database/group_by_util.rs
@@ -11,6 +11,7 @@ use rayon::prelude::ParallelSliceMut;
use thiserror::Error;

/// The output of the `aggregate_columns` function.
+#[derive(Debug)]
pub struct AggregatedColumns<'a, S: Scalar> {
/// The columns that are being grouped by. These are all unique and correspond to each group.
/// This is effectively just the original group_by columns filtered by the selection.
@@ -26,7 +27,7 @@ pub struct AggregatedColumns<'a, S: Scalar> {
/// The number of rows in each group.
pub count_column: &'a [i64],
}
-#[derive(Error, Debug)]
+#[derive(Error, Debug, PartialEq, Eq)]
pub enum AggregateColumnsError {
#[error("Column length mismatch")]
ColumnLengthMismatch,
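The `PartialEq, Eq` derives added above (and to `OwnedColumnError` below) make error values directly comparable, which is what allows tests to assert on exact errors. A self-contained sketch of the pattern, not code from this PR:

use thiserror::Error;

#[derive(Error, Debug, PartialEq, Eq)]
enum DemoError {
    #[error("Column length mismatch")]
    ColumnLengthMismatch,
}

fn fallible() -> Result<(), DemoError> {
    Err(DemoError::ColumnLengthMismatch)
}

fn main() {
    // With PartialEq/Eq derived, the error can be compared directly.
    assert_eq!(fallible().unwrap_err(), DemoError::ColumnLengthMismatch);
}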
@@ -2,7 +2,7 @@ use crate::base::database::ColumnType;
use thiserror::Error;

/// Errors from operations related to `OwnedColumn`s.
-#[derive(Error, Debug)]
+#[derive(Error, Debug, PartialEq, Eq)]
pub enum OwnedColumnError {
/// Can not perform type casting.
#[error("Can not perform type casting from {from_type:?} to {to_type:?}")]
6 changes: 6 additions & 0 deletions crates/proof-of-sql/src/sql/postprocessing/error.rs
@@ -19,6 +19,12 @@ pub enum PostprocessingError {
/// GROUP BY clause references a column not in a group by expression outside aggregate functions
#[error("Invalid group by: column '{0}' must not appear outside aggregate functions or `GROUP BY` clause.")]
IdentifierNotInAggregationOperatorOrGroupByClause(Identifier),
+/// Errors in aggregate columns
+#[error(transparent)]
+AggregateColumnsError(#[from] crate::base::database::group_by_util::AggregateColumnsError),
+/// Errors in `OwnedColumn`
+#[error(transparent)]
+OwnedColumnError(#[from] crate::base::database::OwnedColumnError),
/// Nested aggregation in `GROUP BY` clause
#[error("Nested aggregation in `GROUP BY` clause: {0}")]
NestedAggregationInGroupByClause(String),
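The two new variants combine `#[error(transparent)]` with `#[from]`: the wrapped errors keep their own `Display` messages, and the generated `From` impls are what allow the new postprocessing code to use `?` on these error types. A minimal sketch of that thiserror pattern with illustrative names, not code from this PR:

use thiserror::Error;

#[derive(Error, Debug)]
enum InnerError {
    #[error("inner failure")]
    Failure,
}

#[derive(Error, Debug)]
enum OuterError {
    // `transparent` forwards Display and source() to the wrapped error;
    // `#[from]` generates the From impl that `?` relies on.
    #[error(transparent)]
    Inner(#[from] InnerError),
}

fn inner_op() -> Result<(), InnerError> {
    Err(InnerError::Failure)
}

fn outer_op() -> Result<(), OuterError> {
    inner_op()?; // InnerError auto-converts to OuterError::Inner
    Ok(())
}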
128 changes: 126 additions & 2 deletions crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing.rs
@@ -1,10 +1,17 @@
-use super::{PostprocessingError, PostprocessingResult};
+use super::{PostprocessingError, PostprocessingResult, PostprocessingStep};
use crate::base::{
database::{group_by_util::aggregate_columns, Column, OwnedColumn, OwnedTable},
scalar::Scalar,
};
use bumpalo::Bump;
use indexmap::{IndexMap, IndexSet};
use itertools::{izip, Itertools};
use proof_of_sql_parser::{
intermediate_ast::{AggregationOperator, AliasedResultExpr, Expression},
Identifier,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// A group by expression
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
@@ -146,9 +153,10 @@ impl GroupByPostprocessing {
)
})
.collect::<PostprocessingResult<Vec<AliasedResultExpr>>>()?;
+let group_by_identifiers = by_ids.into_iter().unique().collect();
Ok(Self {
remainder_exprs,
-group_by_identifiers: by_ids,
+group_by_identifiers,
aggregation_expr_map,
})
}
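The `by_ids.into_iter().unique().collect()` line added above deduplicates repeated `GROUP BY` columns while preserving first-occurrence order, so a clause like `GROUP BY a, b, a` groups by `[a, b]`. The `Itertools::unique` behavior in isolation:

use itertools::Itertools;

fn main() {
    let by_ids = vec!["a", "b", "a", "b"];
    // `unique` keeps only the first occurrence of each item, in order.
    let deduped: Vec<_> = by_ids.into_iter().unique().collect();
    assert_eq!(deduped, vec!["a", "b"]);
}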
@@ -169,6 +177,122 @@ impl GroupByPostprocessing {
}
}

impl<S: Scalar> PostprocessingStep<S> for GroupByPostprocessing {
/// Apply the group by transformation to the given `OwnedTable`.
fn apply(&self, owned_table: OwnedTable<S>) -> PostprocessingResult<OwnedTable<S>> {
// First evaluate all the aggregated columns
let alloc = Bump::new();
let evaluated_columns: HashMap<AggregationOperator, Vec<(Identifier, OwnedColumn<S>)>> =
self.aggregation_expr_map
.iter()
.map(|((agg_op, expr), id)| -> PostprocessingResult<_> {
let evaluated_owned_column = owned_table.evaluate(expr)?;
Ok((*agg_op, (*id, evaluated_owned_column)))
})
.process_results(|iter| iter.into_group_map())?;
// Next actually do the GROUP BY
let group_by_ins = self
.group_by_identifiers
.iter()
.map(|id| {
let column = owned_table
.inner_table()
.get(id)
.ok_or(PostprocessingError::ColumnNotFound(id.to_string()))?;
Ok(Column::<S>::from_owned_column(column, &alloc))
})
.collect::<PostprocessingResult<Vec<_>>>()?;
// TODO: Allow a filter
let selection_in = vec![true; owned_table.num_rows()];
let (sum_ids, sum_ins): (Vec<_>, Vec<_>) = evaluated_columns
.get(&AggregationOperator::Sum)
.map(|tuple| {
tuple
.iter()
.map(|(id, c)| (*id, Column::<S>::from_owned_column(c, &alloc)))
.unzip()
})
.unwrap_or((vec![], vec![]));
let (max_ids, max_ins): (Vec<_>, Vec<_>) = evaluated_columns
.get(&AggregationOperator::Max)
.map(|tuple| {
tuple
.iter()
.map(|(id, c)| (*id, Column::<S>::from_owned_column(c, &alloc)))
.unzip()
})
.unwrap_or((vec![], vec![]));
let (min_ids, min_ins): (Vec<_>, Vec<_>) = evaluated_columns
.get(&AggregationOperator::Min)
.map(|tuple| {
tuple
.iter()
.map(|(id, c)| (*id, Column::<S>::from_owned_column(c, &alloc)))
.unzip()
})
.unwrap_or((vec![], vec![]));
let aggregation_results = aggregate_columns(
&alloc,
&group_by_ins,
&sum_ins,
&max_ins,
&min_ins,
&selection_in,
)?;
// Finally do another round of evaluation to get the final result
// Gather the results into a new OwnedTable
let group_by_outs = aggregation_results
.group_by_columns
.iter()
.zip(self.group_by_identifiers.iter())
.map(|(column, id)| Ok((*id, OwnedColumn::from(column))));
let sum_outs =
izip!(aggregation_results.sum_columns, sum_ids, sum_ins,).map(|(c_out, id, c_in)| {
Ok((
id,
OwnedColumn::try_from_scalars(c_out, c_in.column_type())?,
))
});
let max_outs =
izip!(aggregation_results.max_columns, max_ids, max_ins,).map(|(c_out, id, c_in)| {
Ok((
id,
OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?,
))
});
let min_outs =
izip!(aggregation_results.min_columns, min_ids, min_ins,).map(|(c_out, id, c_in)| {
Ok((
id,
OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?,
))
});
//TODO: When we have NULLs we need to differentiate between count(1) and count(expression)
let count_column = OwnedColumn::BigInt(aggregation_results.count_column.to_vec());
let count_outs = evaluated_columns
.get(&AggregationOperator::Count)
.into_iter()
.flatten()
.map(|(id, _)| -> PostprocessingResult<_> { Ok((*id, count_column.clone())) });
let new_owned_table: OwnedTable<S> = group_by_outs
.into_iter()
.chain(sum_outs)
.chain(max_outs)
.chain(min_outs)
.chain(count_outs)
.process_results(|iter| OwnedTable::try_from_iter(iter))??;
let res = self
.remainder_exprs
.iter()
.map(|aliased_expr| -> PostprocessingResult<_> {
let column = new_owned_table.evaluate(&aliased_expr.expr)?;
Ok((aliased_expr.alias, column))
})
.process_results(|iter| OwnedTable::try_from_iter(iter))??;
Ok(res)
}
}
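Two itertools idioms in `apply` are worth unpacking. `process_results(|iter| iter.into_group_map())` exposes the `Ok` values of a fallible iterator as a plain iterator, short-circuits on the first `Err`, and groups `(key, value)` pairs by key; the later `process_results(...)??` lines need the double `?` because both the iteration and `OwnedTable::try_from_iter` are fallible. The grouping idiom in isolation, as a standalone example rather than code from this PR:

use itertools::Itertools;
use std::collections::HashMap;

fn main() {
    // An iterator of Results over (key, value) pairs, analogous to the
    // (aggregation operator, (identifier, column)) pairs built in `apply`.
    let pairs: Vec<Result<(&str, i32), String>> =
        vec![Ok(("sum", 1)), Ok(("max", 2)), Ok(("sum", 3))];

    // process_results yields the Ok values to the closure and returns the
    // first Err, if any; into_group_map collects values under shared keys.
    let grouped: HashMap<&str, Vec<i32>> = pairs
        .into_iter()
        .process_results(|iter| iter.into_group_map())
        .unwrap();

    assert_eq!(grouped["sum"], vec![1, 3]);
    assert_eq!(grouped["max"], vec![2]);
}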

#[cfg(test)]
mod tests {
use super::*;
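One more implementation note on `apply` above: the `Bump` arena exists because `Column` borrows its data rather than owning it. Each `OwnedColumn` is converted with `Column::from_owned_column(column, &alloc)`; for column types that need new storage rather than a direct borrow, the arena provides allocations that live long enough for `aggregate_columns`. The general bumpalo pattern, independent of this crate and with a simplified stand-in type:

use bumpalo::Bump;

// Simplified stand-in for a borrowed column type.
struct BorrowedColumn<'a> {
    values: &'a [i64],
}

fn main() {
    let alloc = Bump::new();
    let owned = vec![1i64, 2, 3];
    // Copy the owned values into the arena; the slice lives as long as `alloc`.
    let values: &[i64] = alloc.alloc_slice_copy(&owned);
    let column = BorrowedColumn { values };
    assert_eq!(column.values.iter().sum::<i64>(), 6);
}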