Skip to content

Commit

Permalink
refactor: remove base::sqlparser::ident since into() is simpler (#…
Browse files Browse the repository at this point in the history
…437)

Please be sure to look over the pull request guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md#submit-pr.

# Please go through the following checklist
- [x] The PR title and commit messages adhere to guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md.
In particular `!` is used if and only if at least one breaking change
has been introduced.
- [x] I have run the CI check script with `source
scripts/run_ci_checks.sh`.

# Rationale for this change
`proof_of_sql::base::sqlparser::ident` does only one thing: it converts a
`&str` into an `Ident` via `From<&str>`. Calling `.into()` directly is
simpler, so this change removes the helper.
<!--
Why are you proposing this change? If this is already explained clearly
in the linked issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.

 Example:
 Add `NestedLoopJoinExec`.
 Closes #345.

Since we added `HashJoinExec` in #323 it has been possible to do
provable inner joins. However, performance is not satisfactory in some
cases. Hence we need to fix the problem by implementing
`NestedLoopJoinExec` and speeding up the code
 for `HashJoinExec`.
-->

# What changes are included in this PR?
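See title: the `base::sqlparser` module and its `ident` helper are removed, and call sites now use `.into()` (or pass the `&str` directly) instead.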
<!--
There is no need to duplicate the description in the ticket here but it
is sometimes worth providing a summary of the individual changes in this
PR.

Example:
- Add `NestedLoopJoinExec`.
- Speed up `HashJoinExec`.
- Route joins to `NestedLoopJoinExec` if the outer input is sufficiently
small.
-->

# Are these changes tested?
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example,
are they covered by existing tests?).

Example:
Yes.
-->
Yes.
  • Loading branch information
iajoiner authored Dec 16, 2024
2 parents be3853a + ce4274d commit 447f312
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 49 deletions.
1 change: 0 additions & 1 deletion crates/proof-of-sql/src/base/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ mod serialize;
pub(crate) use serialize::{impl_serde_for_ark_serde_checked, impl_serde_for_ark_serde_unchecked};
pub(crate) mod map;
pub(crate) mod slice_ops;
pub(crate) mod sqlparser;

mod rayon_cfg;
pub(crate) use rayon_cfg::if_rayon;
5 changes: 0 additions & 5 deletions crates/proof-of-sql/src/base/sqlparser.rs

This file was deleted.

41 changes: 20 additions & 21 deletions crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use crate::{
database::{ColumnRef, ColumnType, LiteralValue, TestSchemaAccessor},
map::{indexmap, IndexMap},
math::decimal::Precision,
sqlparser::ident,
},
sql::{
parse::{ConversionError, QueryExpr, WhereExprBuilder},
Expand Down Expand Up @@ -33,59 +32,59 @@ fn get_column_mappings_for_testing() -> IndexMap<Ident, ColumnRef> {
let mut column_mapping = IndexMap::default();
// Setup column mapping
column_mapping.insert(
ident("boolean_column"),
ColumnRef::new(tab_ref, ident("boolean_column"), ColumnType::Boolean),
"boolean_column".into(),
ColumnRef::new(tab_ref, "boolean_column".into(), ColumnType::Boolean),
);
column_mapping.insert(
ident("decimal_column"),
"decimal_column".into(),
ColumnRef::new(
tab_ref,
ident("decimal_column"),
"decimal_column".into(),
ColumnType::Decimal75(Precision::new(7).unwrap(), 2),
),
);
column_mapping.insert(
ident("int128_column"),
ColumnRef::new(tab_ref, ident("int128_column"), ColumnType::Int128),
"int128_column".into(),
ColumnRef::new(tab_ref, "int128_column".into(), ColumnType::Int128),
);
column_mapping.insert(
ident("bigint_column"),
ColumnRef::new(tab_ref, ident("bigint_column"), ColumnType::BigInt),
"bigint_column".into(),
ColumnRef::new(tab_ref, "bigint_column".into(), ColumnType::BigInt),
);

column_mapping.insert(
ident("varchar_column"),
ColumnRef::new(tab_ref, ident("varchar_column"), ColumnType::VarChar),
"varchar_column".into(),
ColumnRef::new(tab_ref, "varchar_column".into(), ColumnType::VarChar),
);
column_mapping.insert(
ident("timestamp_second_column"),
"timestamp_second_column".into(),
ColumnRef::new(
tab_ref,
ident("timestamp_second_column"),
"timestamp_second_column".into(),
ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()),
),
);
column_mapping.insert(
ident("timestamp_millisecond_column"),
"timestamp_millisecond_column".into(),
ColumnRef::new(
tab_ref,
ident("timestamp_millisecond_column"),
"timestamp_millisecond_column".into(),
ColumnType::TimestampTZ(PoSQLTimeUnit::Millisecond, PoSQLTimeZone::utc()),
),
);
column_mapping.insert(
ident("timestamp_microsecond_column"),
"timestamp_microsecond_column".into(),
ColumnRef::new(
tab_ref,
ident("timestamp_microsecond_column"),
"timestamp_microsecond_column".into(),
ColumnType::TimestampTZ(PoSQLTimeUnit::Microsecond, PoSQLTimeZone::utc()),
),
);
column_mapping.insert(
ident("timestamp_nanosecond_column"),
"timestamp_nanosecond_column".into(),
ColumnRef::new(
tab_ref,
ident("timestamp_nanosecond_column"),
"timestamp_nanosecond_column".into(),
ColumnType::TimestampTZ(PoSQLTimeUnit::Nanosecond, PoSQLTimeZone::utc()),
),
);
Expand Down Expand Up @@ -147,7 +146,7 @@ fn we_can_directly_check_whether_bigint_columns_ge_int128() {
let expected = DynProofExpr::try_new_inequality(
DynProofExpr::Column(ColumnExpr::new(ColumnRef::new(
"sxt.sxt_tab".parse().unwrap(),
ident("bigint_column"),
"bigint_column".into(),
ColumnType::BigInt,
))),
DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))),
Expand All @@ -169,7 +168,7 @@ fn we_can_directly_check_whether_bigint_columns_le_int128() {
let expected = DynProofExpr::try_new_inequality(
DynProofExpr::Column(ColumnExpr::new(ColumnRef::new(
"sxt.sxt_tab".parse().unwrap(),
ident("bigint_column"),
"bigint_column".into(),
ColumnType::BigInt,
))),
DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@ impl<S: Scalar> PostprocessingStep<S> for GroupByPostprocessing {
#[cfg(test)]
mod tests {
use super::*;
use crate::base::sqlparser::ident;
use proof_of_sql_parser::utility::*;

#[test]
Expand Down Expand Up @@ -400,13 +399,13 @@ mod tests {

// a + b + 1
let expr = add(add(col("a"), col("b")), lit(1));
let expected: IndexSet<Ident> = [ident("a"), ident("b")].into_iter().collect();
let expected: IndexSet<Ident> = ["a".into(), "b".into()].into_iter().collect();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);

// ! (a == b || c >= a)
let expr = not(or(equal(col("a"), col("b")), ge(col("c"), col("a"))));
let expected: IndexSet<Ident> = [ident("a"), ident("b"), ident("c")].into_iter().collect();
let expected: IndexSet<Ident> = ["a".into(), "b".into(), "c".into()].into_iter().collect();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);

Expand All @@ -418,7 +417,7 @@ mod tests {

// (COUNT(a + b) + c) * d
let expr = mul(add(count(add(col("a"), col("b"))), col("c")), col("d"));
let expected: IndexSet<Ident> = [ident("c"), ident("d")].into_iter().collect();
let expected: IndexSet<Ident> = ["c".into(), "d".into()].into_iter().collect();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);
}
Expand All @@ -433,7 +432,7 @@ mod tests {
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Sum, *col("a"))],
ident("__col_agg_0")
"__col_agg_0".into()
);
assert_eq!(remainder_expr, Ok(*add(col("__col_agg_0"), col("b"))));
assert_eq!(aggregation_expr_map.len(), 1);
Expand All @@ -444,11 +443,11 @@ mod tests {
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Sum, *col("a"))],
ident("__col_agg_0")
"__col_agg_0".into()
);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Sum, *col("b"))],
ident("__col_agg_1")
"__col_agg_1".into()
);
assert_eq!(
remainder_expr,
Expand All @@ -468,14 +467,14 @@ mod tests {
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Max, *add(col("a"), lit(1)))],
ident("__col_agg_2")
"__col_agg_2".into()
);
assert_eq!(
aggregation_expr_map[&(
AggregationOperator::Min,
*sub(mul(lit(2), col("b")), lit(4))
)],
ident("__col_agg_3")
"__col_agg_3".into()
);
assert_eq!(
remainder_expr,
Expand All @@ -492,7 +491,7 @@ mod tests {
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Count, *mul(lit(2), col("a")))],
ident("__col_agg_4")
"__col_agg_4".into()
);
assert_eq!(
remainder_expr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::{
base::{
database::{owned_table_utility::*, OwnedTable},
scalar::Curve25519Scalar,
sqlparser::ident,
},
sql::postprocessing::{
apply_postprocessing_steps, group_by_postprocessing::*, test_utility::*,
Expand All @@ -15,15 +14,15 @@ use proof_of_sql_parser::{intermediate_ast::AggregationOperator, utility::*};
fn we_cannot_have_invalid_group_bys() {
// Column in result but not in group by or aggregation
let expr = add(sum(col("a")), col("b")); // b is not in group by or aggregation
let res = GroupByPostprocessing::try_new(vec![ident("a")], vec![aliased_expr(expr, "res")]);
let res = GroupByPostprocessing::try_new(vec!["a".into()], vec![aliased_expr(expr, "res")]);
assert!(matches!(
res,
Err(PostprocessingError::IdentifierNotInAggregationOperatorOrGroupByClause { .. })
));

// Nested aggregation
let expr = sum(max(col("a"))); // Nested aggregation
let res = GroupByPostprocessing::try_new(vec![ident("a")], vec![aliased_expr(expr, "res")]);
let res = GroupByPostprocessing::try_new(vec!["a".into()], vec![aliased_expr(expr, "res")]);
assert!(matches!(
res,
Err(PostprocessingError::NestedAggregationInGroupByClause { .. })
Expand All @@ -34,14 +33,14 @@ fn we_cannot_have_invalid_group_bys() {
fn we_can_make_group_by_postprocessing() {
// SELECT SUM(a) + 2 as c0, SUM(b + a) as c1 FROM tab GROUP BY a, b
let res = GroupByPostprocessing::try_new(
vec![ident("a"), ident("b")],
vec!["a".into(), "b".into()],
vec![
aliased_expr(add(sum(col("a")), lit(2)), "c0"),
aliased_expr(sum(add(col("b"), col("a"))), "c1"),
],
)
.unwrap();
assert_eq!(res.group_by(), &[ident("a"), ident("b")]);
assert_eq!(res.group_by(), &["a".into(), "b".into()]);
assert_eq!(
res.remainder_exprs(),
&[
Expand All @@ -52,11 +51,11 @@ fn we_can_make_group_by_postprocessing() {
assert_eq!(
res.aggregation_exprs(),
&[
(AggregationOperator::Sum, *col("a"), ident("__col_agg_0")),
(AggregationOperator::Sum, *col("a"), "__col_agg_0".into()),
(
AggregationOperator::Sum,
*add(col("b"), col("a")),
ident("__col_agg_1")
"__col_agg_1".into()
),
]
);
Expand Down
3 changes: 1 addition & 2 deletions crates/proof-of-sql/src/sql/postprocessing/test_utility.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::*;
use crate::base::sqlparser::ident;
use proof_of_sql_parser::intermediate_ast::{AliasedResultExpr, OrderBy, OrderByDirection};
use sqlparser::ast::Ident;

Expand All @@ -8,7 +7,7 @@ pub fn group_by_postprocessing(
cols: &[&str],
result_exprs: &[AliasedResultExpr],
) -> OwnedTablePostprocessing {
let ids: Vec<Ident> = cols.iter().map(|col| ident(col)).collect();
let ids: Vec<Ident> = cols.iter().map(|col| (*col).into()).collect();
OwnedTablePostprocessing::new_group_by(
GroupByPostprocessing::try_new(ids, result_exprs.to_vec()).unwrap(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use crate::{
map::{indexset, IndexMap, IndexSet},
proof::ProofError,
scalar::Scalar,
sqlparser::ident,
},
sql::proof::{FirstRoundBuilder, QueryData},
};
Expand All @@ -36,7 +35,7 @@ impl ProverEvaluate for EmptyTestQueryExpr {
builder.produce_one_evaluation_length(self.length);
table_with_row_count(
(1..=self.columns)
.map(|i| borrowed_bigint(ident(format!("a{i}").as_str()), zeros.clone(), alloc)),
.map(|i| borrowed_bigint(format!("a{i}").as_str(), zeros.clone(), alloc)),
self.length,
)
}
Expand All @@ -54,7 +53,7 @@ impl ProverEvaluate for EmptyTestQueryExpr {
.collect::<Vec<_>>();
table_with_row_count(
(1..=self.columns)
.map(|i| borrowed_bigint(ident(format!("a{i}").as_str()), zeros.clone(), alloc)),
.map(|i| borrowed_bigint(format!("a{i}").as_str(), zeros.clone(), alloc)),
self.length,
)
}
Expand Down

0 comments on commit 447f312

Please sign in to comment.