From c4d35e6a63005efde113536c097828b23dc21b0a Mon Sep 17 00:00:00 2001
From: Ian Joiner <14581281+iajoiner@users.noreply.github.com>
Date: Tue, 26 Nov 2024 16:51:44 -0500
Subject: [PATCH] feat: add `sort_merge_join`
---
.../src/base/database/join_util.rs | 449 +++++++++++++++++-
crates/proof-of-sql/src/base/database/mod.rs | 1 +
.../base/database/table_operation_error.rs | 18 +-
3 files changed, 465 insertions(+), 3 deletions(-)
diff --git a/crates/proof-of-sql/src/base/database/join_util.rs b/crates/proof-of-sql/src/base/database/join_util.rs
index 7cb614a9e..659370104 100644
--- a/crates/proof-of-sql/src/base/database/join_util.rs
+++ b/crates/proof-of-sql/src/base/database/join_util.rs
@@ -1,6 +1,17 @@
-use super::{ColumnRepeatOp, ElementwiseRepeatOp, RepetitionOp, Table, TableOptions};
-use crate::base::scalar::Scalar;
+use super::{
+ apply_column_to_indexes,
+ order_by_util::{compare_indexes_by_columns, compare_single_row_of_tables},
+ Column, ColumnRepeatOp, ElementwiseRepeatOp, RepetitionOp, Table, TableOperationError,
+ TableOperationResult, TableOptions,
+};
+use crate::base::{
+ map::{IndexMap, IndexSet},
+ scalar::Scalar,
+};
use bumpalo::Bump;
+use core::cmp::Ordering;
+use itertools::Itertools;
+use proof_of_sql_parser::Identifier;
/// Compute the CROSS JOIN / cartesian product of two tables.
///
@@ -34,6 +45,145 @@ pub fn cross_join<'a, S: Scalar>(
.expect("Table creation should not fail")
}
+/// Compute the JOIN of two tables using a sort-merge join.
+///
+/// Currently we only support INNER JOINs and only support joins on equalities.
+/// A row of `left` is paired with a row of `right` when the key formed from the `left_on`
+/// columns compares equal to the key formed from the `right_on` columns. Every
+/// `(ident, alias)` pair selects column `ident` from its table and exposes it as `alias`
+/// in the result; left selections come first, followed by right selections.
+///
+/// # Errors
+/// * [`TableOperationError::DuplicateColumn`] if the result aliases are not unique.
+/// * [`TableOperationError::ColumnDoesNotExist`] if a selected identifier is missing from
+///   its table.
+/// * Any error reported by `compare_single_row_of_tables` while comparing join keys, or by
+///   `apply_column_to_indexes` while gathering the result rows.
+///
+/// # Panics
+/// The function panics if we feed in incorrect data, i.e. if some column of `left_on`
+/// (resp. `right_on`) has a length different from the number of rows of `left`
+/// (resp. `right`).
+#[allow(clippy::too_many_lines)]
+pub fn sort_merge_join<'a, S: Scalar>(
+    left: &Table<'a, S>,
+    right: &Table<'a, S>,
+    left_on: &[Column<'a, S>],
+    right_on: &[Column<'a, S>],
+    left_selected_column_ident_aliases: &[(Identifier, Identifier)],
+    right_selected_column_ident_aliases: &[(Identifier, Identifier)],
+    alloc: &'a Bump,
+) -> TableOperationResult<Table<'a, S>> {
+    let left_num_rows = left.num_rows();
+    let right_num_rows = right.num_rows();
+    // Check that result aliases are unique: collecting into a set drops duplicates,
+    // so a shorter set means at least one alias was repeated.
+    let aliases = left_selected_column_ident_aliases
+        .iter()
+        .map(|(_, alias)| alias)
+        .chain(
+            right_selected_column_ident_aliases
+                .iter()
+                .map(|(_, alias)| alias),
+        )
+        .collect::<IndexSet<_>>();
+    if aliases.len()
+        != left_selected_column_ident_aliases.len() + right_selected_column_ident_aliases.len()
+    {
+        return Err(TableOperationError::DuplicateColumn);
+    }
+    // Check that the number of rows is good
+    for column in left_on {
+        assert_eq!(column.len(), left_num_rows);
+    }
+    for column in right_on {
+        assert_eq!(column.len(), right_num_rows);
+    }
+    // First of all sort the row indexes of both tables by the columns we are joining on
+    let left_indexes =
+        (0..left_num_rows).sorted_unstable_by(|&a, &b| compare_indexes_by_columns(left_on, a, b));
+    let right_indexes = (0..right_num_rows)
+        .sorted_unstable_by(|&a, &b| compare_indexes_by_columns(right_on, a, b));
+    // Collect the indexes of the rows that match by walking both sorted sequences in lockstep
+    let mut left_iter = left_indexes.peekable();
+    let mut right_iter = right_indexes.peekable();
+    let mut index_pairs: Vec<(usize, usize)> = Vec::new();
+    // If we have reached the end of either table, we are done
+    while let (Some(&left_index), Some(&right_index)) = (left_iter.peek(), right_iter.peek()) {
+        // A failed comparison is propagated to the caller instead of being treated as the
+        // end of the merge, which would silently truncate the join result.
+        match compare_single_row_of_tables(left_on, right_on, left_index, right_index)? {
+            Ordering::Less => {
+                // Left key is smaller: move left forward, no pairs for this step
+                left_iter.next();
+            }
+            Ordering::Greater => {
+                // Right key is smaller: move right forward, no pairs for this step
+                right_iter.next();
+            }
+            Ordering::Equal => {
+                // Identify groups of equal keys from both sides
+                let left_group = left_iter
+                    .clone()
+                    .take_while(|&lidx| {
+                        compare_indexes_by_columns(left_on, left_index, lidx) == Ordering::Equal
+                    })
+                    .collect::<Vec<_>>();
+
+                let right_group = right_iter
+                    .clone()
+                    .take_while(|&ridx| {
+                        compare_indexes_by_columns(right_on, right_index, ridx) == Ordering::Equal
+                    })
+                    .collect::<Vec<_>>();
+
+                // All combinations of left_group x right_group
+                index_pairs.extend(
+                    left_group
+                        .iter()
+                        .cartesian_product(right_group.iter())
+                        .map(|(&l, &r)| (l, r)),
+                );
+
+                // Advance the iterators past the groups. Each group holds at least the
+                // peeked row itself, so `len() - 1` cannot underflow.
+                left_iter.nth(left_group.len() - 1);
+                right_iter.nth(right_group.len() - 1);
+            }
+        }
+    }
+    // Now we have the indexes of the rows that match, we can create the new table
+    let (left_indexes, right_indexes): (Vec<usize>, Vec<usize>) = index_pairs.into_iter().unzip();
+    let num_rows = left_indexes.len();
+    let result_columns = left_selected_column_ident_aliases
+        .iter()
+        .map(
+            |(ident, alias)| -> TableOperationResult<(Identifier, Column<'a, S>)> {
+                Ok((
+                    *alias,
+                    apply_column_to_indexes(
+                        left.inner_table().get(ident).ok_or(
+                            TableOperationError::ColumnDoesNotExist {
+                                column_ident: *ident,
+                            },
+                        )?,
+                        alloc,
+                        &left_indexes,
+                    )?,
+                ))
+            },
+        )
+        .chain(right_selected_column_ident_aliases.iter().map(
+            |(ident, alias)| -> TableOperationResult<(Identifier, Column<'a, S>)> {
+                Ok((
+                    *alias,
+                    apply_column_to_indexes(
+                        right.inner_table().get(ident).ok_or(
+                            TableOperationError::ColumnDoesNotExist {
+                                column_ident: *ident,
+                            },
+                        )?,
+                        alloc,
+                        &right_indexes,
+                    )?,
+                ))
+            },
+        ))
+        .collect::<TableOperationResult<IndexMap<_, _>>>()?;
+    // The row count is passed explicitly so that a result with no selected columns still
+    // carries the number of matched rows.
+    Ok(
+        Table::<'a, S>::try_new_with_options(result_columns, TableOptions::new(Some(num_rows)))
+            .expect("Table creation should not fail"),
+    )
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -240,4 +390,299 @@ mod tests {
assert_eq!(result.num_rows(), 0);
assert_eq!(result.num_columns(), 0);
}
+
+    /// Joins two small tables on their `b` columns and checks the matched rows.
+    #[test]
+    fn we_can_do_sort_merge_join_on_two_tables() {
+        let a = "a".parse().unwrap();
+        let b = "b".parse().unwrap();
+        let c = "c".parse().unwrap();
+        let arena = Bump::new();
+
+        // Left table: columns `a` and `b`, six rows.
+        let left_table = Table::<'_, TestScalar>::try_from_iter_with_options(
+            vec![
+                (a, Column::SmallInt(&[8_i16, 2, 5, 1, 3, 7])),
+                (b, Column::Int(&[3_i32, 5, 9, 4, 5, 7])),
+            ],
+            TableOptions::default(),
+        )
+        .expect("Table creation should not fail");
+        // Right table: columns `c` and `b`, seven rows.
+        let right_table = Table::<'_, TestScalar>::try_from_iter_with_options(
+            vec![
+                (c, Column::BigInt(&[1_i64, 2, 7, 8, 9, 7, 2])),
+                (b, Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])),
+            ],
+            TableOptions::default(),
+        )
+        .expect("Table creation should not fail");
+
+        // Join keys mirror the `b` column of each table.
+        let left_keys = [Column::Int(&[3_i32, 5, 9, 4, 5, 7])];
+        let right_keys = [Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])];
+        let left_selection = [(a, a), (b, b)];
+        let right_selection = [(c, c)];
+
+        let joined = sort_merge_join(
+            &left_table,
+            &right_table,
+            &left_keys,
+            &right_keys,
+            &left_selection,
+            &right_selection,
+            &arena,
+        )
+        .unwrap();
+
+        // One match on key 4 and a 2 x 2 block of matches on key 5.
+        assert_eq!(joined.num_rows(), 5);
+        assert_eq!(joined.num_columns(), 3);
+        let columns = joined.inner_table();
+        assert_eq!(columns[&a].as_smallint().unwrap(), &[1_i16, 2, 2, 3, 3]);
+        assert_eq!(columns[&b].as_int().unwrap(), &[4_i32, 5, 5, 5, 5]);
+        assert_eq!(columns[&c].as_bigint().unwrap(), &[7_i64, 8, 9, 8, 9]);
+    }
+
+    /// Joins two non-empty tables whose `b` columns share no value, expecting no rows.
+    #[test]
+    fn we_can_do_sort_merge_join_on_two_tables_with_empty_results() {
+        let a = "a".parse().unwrap();
+        let b = "b".parse().unwrap();
+        let c = "c".parse().unwrap();
+        let arena = Bump::new();
+
+        let left_table = Table::<'_, TestScalar>::try_from_iter_with_options(
+            vec![
+                (a, Column::SmallInt(&[8_i16, 2, 5, 1, 3, 7])),
+                (b, Column::Int(&[3_i32, 15, 9, 14, 15, 7])),
+            ],
+            TableOptions::default(),
+        )
+        .expect("Table creation should not fail");
+        let right_table = Table::<'_, TestScalar>::try_from_iter_with_options(
+            vec![
+                (c, Column::BigInt(&[1_i64, 2, 7, 8, 9, 7, 2])),
+                (b, Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])),
+            ],
+            TableOptions::default(),
+        )
+        .expect("Table creation should not fail");
+
+        // The two key columns have no value in common.
+        let left_keys = [Column::Int(&[3_i32, 15, 9, 14, 15, 7])];
+        let right_keys = [Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])];
+        let left_selection = [(a, a), (b, b)];
+        let right_selection = [(c, c)];
+
+        let joined = sort_merge_join(
+            &left_table,
+            &right_table,
+            &left_keys,
+            &right_keys,
+            &left_selection,
+            &right_selection,
+            &arena,
+        )
+        .unwrap();
+
+        // All three selected columns are present but empty.
+        assert_eq!(joined.num_rows(), 0);
+        assert_eq!(joined.num_columns(), 3);
+        let columns = joined.inner_table();
+        assert_eq!(columns[&a].as_smallint().unwrap(), &[0_i16; 0]);
+        assert_eq!(columns[&b].as_int().unwrap(), &[0_i32; 0]);
+        assert_eq!(columns[&c].as_bigint().unwrap(), &[0_i64; 0]);
+    }
+
+    /// Joins with an empty right table, an empty left table, and both tables empty;
+    /// every scenario must yield three empty columns.
+    #[allow(clippy::too_many_lines)]
+    #[test]
+    fn we_can_do_sort_merge_join_on_tables_with_no_rows() {
+        let a = "a".parse().unwrap();
+        let b = "b".parse().unwrap();
+        let c = "c".parse().unwrap();
+        let arena = Bump::new();
+
+        // Populated and empty column data shared by the scenarios below.
+        let full_left_a: &[i16] = &[8, 2, 5, 1, 3, 7];
+        let full_left_b: &[i32] = &[3, 15, 9, 14, 15, 7];
+        let full_right_c: &[i64] = &[1, 2, 7, 8, 9, 7, 2];
+        let full_right_b: &[i32] = &[10, 11, 6, 5, 5, 4, 8];
+        let no_i16: &[i16] = &[];
+        let no_i32: &[i32] = &[];
+        let no_i64: &[i64] = &[];
+
+        // Each scenario is (left `a`, left `b`, right `c`, right `b`).
+        let scenarios = [
+            // Right table has no rows
+            (full_left_a, full_left_b, no_i64, no_i32),
+            // Left table has no rows
+            (no_i16, no_i32, full_right_c, full_right_b),
+            // Both tables have no rows
+            (no_i16, no_i32, no_i64, no_i32),
+        ];
+
+        for (left_a, left_b, right_c, right_b) in scenarios {
+            let left_table = Table::<'_, TestScalar>::try_from_iter_with_options(
+                vec![(a, Column::SmallInt(left_a)), (b, Column::Int(left_b))],
+                TableOptions::default(),
+            )
+            .expect("Table creation should not fail");
+            let right_table = Table::<'_, TestScalar>::try_from_iter_with_options(
+                vec![(c, Column::BigInt(right_c)), (b, Column::Int(right_b))],
+                TableOptions::default(),
+            )
+            .expect("Table creation should not fail");
+
+            // Join keys mirror the `b` column of each table.
+            let left_keys = [Column::Int(left_b)];
+            let right_keys = [Column::Int(right_b)];
+            let left_selection = [(a, a), (b, b)];
+            let right_selection = [(c, c)];
+
+            let joined = sort_merge_join(
+                &left_table,
+                &right_table,
+                &left_keys,
+                &right_keys,
+                &left_selection,
+                &right_selection,
+                &arena,
+            )
+            .unwrap();
+
+            assert_eq!(joined.num_rows(), 0);
+            assert_eq!(joined.num_columns(), 3);
+            let columns = joined.inner_table();
+            assert_eq!(columns[&a].as_smallint().unwrap(), &[0_i16; 0]);
+            assert_eq!(columns[&b].as_int().unwrap(), &[0_i32; 0]);
+            assert_eq!(columns[&c].as_bigint().unwrap(), &[0_i64; 0]);
+        }
+    }
+
+    /// Selecting alias `b` from both sides must be rejected as a duplicate column.
+    #[test]
+    fn we_can_not_do_sort_merge_join_with_duplicate_aliases() {
+        let a = "a".parse().unwrap();
+        let b = "b".parse().unwrap();
+        let c = "c".parse().unwrap();
+        let arena = Bump::new();
+
+        let left_table = Table::<'_, TestScalar>::try_from_iter_with_options(
+            vec![
+                (a, Column::SmallInt(&[8_i16, 2, 5, 1, 3, 7])),
+                (b, Column::Int(&[3_i32, 5, 9, 4, 5, 7])),
+            ],
+            TableOptions::default(),
+        )
+        .expect("Table creation should not fail");
+        let right_table = Table::<'_, TestScalar>::try_from_iter_with_options(
+            vec![
+                (c, Column::BigInt(&[1_i64, 2, 7, 8, 9, 7, 2])),
+                (b, Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])),
+            ],
+            TableOptions::default(),
+        )
+        .expect("Table creation should not fail");
+
+        let left_keys = [Column::Int(&[3_i32, 5, 9, 4, 5, 7])];
+        let right_keys = [Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])];
+        // Alias `b` appears in both selections.
+        let left_selection = [(a, a), (b, b)];
+        let right_selection = [(b, b), (c, c)];
+
+        let outcome = sort_merge_join(
+            &left_table,
+            &right_table,
+            &left_keys,
+            &right_keys,
+            &left_selection,
+            &right_selection,
+            &arena,
+        );
+        assert_eq!(outcome.unwrap_err(), TableOperationError::DuplicateColumn);
+    }
+
+    /// Selecting an identifier that the right table does not contain must fail with
+    /// `ColumnDoesNotExist`.
+    #[test]
+    fn we_can_not_do_sort_merge_join_with_wrong_column_identifiers() {
+        let a = "a".parse().unwrap();
+        let b = "b".parse().unwrap();
+        let c = "c".parse().unwrap();
+        let not_a_column = "not_a_column".parse().unwrap();
+        let arena = Bump::new();
+
+        let left_table = Table::<'_, TestScalar>::try_from_iter_with_options(
+            vec![
+                (a, Column::SmallInt(&[8_i16, 2, 5, 1, 3, 7])),
+                (b, Column::Int(&[3_i32, 5, 9, 4, 5, 7])),
+            ],
+            TableOptions::default(),
+        )
+        .expect("Table creation should not fail");
+        let right_table = Table::<'_, TestScalar>::try_from_iter_with_options(
+            vec![
+                (c, Column::BigInt(&[1_i64, 2, 7, 8, 9, 7, 2])),
+                (b, Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])),
+            ],
+            TableOptions::default(),
+        )
+        .expect("Table creation should not fail");
+
+        let left_keys = [Column::Int(&[3_i32, 5, 9, 4, 5, 7])];
+        let right_keys = [Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])];
+        let left_selection = [(a, a), (b, b)];
+        // `not_a_column` is not a column of the right table.
+        let right_selection = [(not_a_column, c)];
+
+        let outcome = sort_merge_join(
+            &left_table,
+            &right_table,
+            &left_keys,
+            &right_keys,
+            &left_selection,
+            &right_selection,
+            &arena,
+        );
+        let is_missing_column = matches!(
+            outcome,
+            Err(TableOperationError::ColumnDoesNotExist { .. })
+        );
+        assert!(is_missing_column);
+    }
}
diff --git a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs
index 5e259f6a7..03f80f23c 100644
--- a/crates/proof-of-sql/src/base/database/mod.rs
+++ b/crates/proof-of-sql/src/base/database/mod.rs
@@ -26,6 +26,7 @@ pub(super) use column_comparison_operation::{
};
mod column_index_operation;
+pub(super) use column_index_operation::apply_column_to_indexes;
mod column_repetition_operation;
pub(super) use column_repetition_operation::{ColumnRepeatOp, ElementwiseRepeatOp, RepetitionOp};
diff --git a/crates/proof-of-sql/src/base/database/table_operation_error.rs b/crates/proof-of-sql/src/base/database/table_operation_error.rs
index b0ec13d91..b631d9003 100644
--- a/crates/proof-of-sql/src/base/database/table_operation_error.rs
+++ b/crates/proof-of-sql/src/base/database/table_operation_error.rs
@@ -1,6 +1,7 @@
-use crate::base::database::{ColumnField, ColumnType};
+use super::{ColumnField, ColumnOperationError, ColumnType};
use alloc::vec::Vec;
use core::result::Result;
+use proof_of_sql_parser::Identifier;
use snafu::Snafu;
/// Errors from operations on tables.
@@ -26,6 +27,21 @@ pub enum TableOperationError {
/// The right-hand side data type
right_type: ColumnType,
},
+ /// Errors related to a column that does not exist in a table.
+ #[snafu(display("Column {column_ident:?} does not exist in table"))]
+ ColumnDoesNotExist {
+ /// The nonexistent column identifier
+ column_ident: Identifier,
+ },
+ /// Errors related to duplicate columns in a table.
+ #[snafu(display("Some column is duplicated in table"))]
+ DuplicateColumn,
+ /// Errors due to bad column operations.
+ #[snafu(transparent)]
+ ColumnOperationError {
+ /// The underlying `ColumnOperationError`
+ source: ColumnOperationError,
+ },
}
/// Result type for table operations