
refactor!: put arrow behind a feature flag (#111)
# Rationale for this change
Please review #101 and #84 first.
To support use cases that require very small binaries, this change puts `arrow` behind a feature flag.

# What changes are included in this PR?
The `arrow` dependency is made optional and gated behind a new `arrow` feature, which stays in the default feature set. The `proof-of-sql-parser` crate drops its `arrow` dependency entirely (the `PoSQLTimeUnit` ↔ arrow `TimeUnit` conversions move inline into `proof-of-sql`), and all arrow-dependent modules, impls, methods, and tests in `proof-of-sql` now carry `#[cfg(feature = "arrow")]`.

# Are these changes tested?
Yes. CI now also dry-runs `cargo test -p proof-of-sql` with only the `arrow` feature enabled, alongside the existing all-features, `test`-only, `blitzar`-only, and no-default-features passes.
iajoiner authored Aug 16, 2024
1 parent afccbe8 commit 3e421d1
Showing 15 changed files with 85 additions and 93 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/lint-and-test.yml
@@ -73,6 +73,8 @@ jobs:
run: cargo test --all-features
- name: Dry run cargo test (proof-of-sql) (test feature only)
run: cargo test -p proof-of-sql --no-run --no-default-features --features="test"
- name: Dry run cargo test (proof-of-sql) (arrow feature only)
run: cargo test -p proof-of-sql --no-run --no-default-features --features="arrow"
- name: Dry run cargo test (proof-of-sql) (blitzar feature only)
run: cargo test -p proof-of-sql --no-run --no-default-features --features="blitzar"
- name: Dry run cargo test (proof-of-sql) (no features)
1 change: 0 additions & 1 deletion crates/proof-of-sql-parser/Cargo.toml
@@ -15,7 +15,6 @@ doctest = true
test = true

[dependencies]
arrow = { workspace = true }
arrayvec = { workspace = true, features = ["serde"] }
bigdecimal = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
23 changes: 0 additions & 23 deletions crates/proof-of-sql-parser/src/posql_time/unit.rs
@@ -1,5 +1,4 @@
use super::PoSQLTimestampError;
use arrow::datatypes::TimeUnit as ArrowTimeUnit;
use core::fmt;
use serde::{Deserialize, Serialize};

@@ -29,28 +28,6 @@ impl TryFrom<&str> for PoSQLTimeUnit {
}
}

impl From<PoSQLTimeUnit> for ArrowTimeUnit {
fn from(unit: PoSQLTimeUnit) -> Self {
match unit {
PoSQLTimeUnit::Second => ArrowTimeUnit::Second,
PoSQLTimeUnit::Millisecond => ArrowTimeUnit::Millisecond,
PoSQLTimeUnit::Microsecond => ArrowTimeUnit::Microsecond,
PoSQLTimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond,
}
}
}

impl From<ArrowTimeUnit> for PoSQLTimeUnit {
fn from(unit: ArrowTimeUnit) -> Self {
match unit {
ArrowTimeUnit::Second => PoSQLTimeUnit::Second,
ArrowTimeUnit::Millisecond => PoSQLTimeUnit::Millisecond,
ArrowTimeUnit::Microsecond => PoSQLTimeUnit::Microsecond,
ArrowTimeUnit::Nanosecond => PoSQLTimeUnit::Nanosecond,
}
}
}

impl fmt::Display for PoSQLTimeUnit {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
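Note that the two `From` impls are deleted outright rather than relocated: in `proof-of-sql`, both `PoSQLTimeUnit` and arrow's `TimeUnit` are foreign types, so the orphan rule forbids re-homing the impls there. That is why the conversions reappear later in this diff as inline `match` expressions in `column.rs`. The other idiomatic option is a feature-gated free function; a minimal sketch (the function name and import path are illustrative, not from this PR):

```rust
use proof_of_sql_parser::posql_time::PoSQLTimeUnit;

// Hypothetical helper that could live in proof-of-sql, where writing
// `impl From<PoSQLTimeUnit> for arrow::datatypes::TimeUnit` is not
// allowed (both the trait and the two types would be foreign).
#[cfg(feature = "arrow")]
fn to_arrow_time_unit(unit: PoSQLTimeUnit) -> arrow::datatypes::TimeUnit {
    use arrow::datatypes::TimeUnit as ArrowTimeUnit;
    match unit {
        PoSQLTimeUnit::Second => ArrowTimeUnit::Second,
        PoSQLTimeUnit::Millisecond => ArrowTimeUnit::Millisecond,
        PoSQLTimeUnit::Microsecond => ArrowTimeUnit::Microsecond,
        PoSQLTimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond,
    }
}
```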
10 changes: 7 additions & 3 deletions crates/proof-of-sql/Cargo.toml
@@ -21,7 +21,7 @@ ark-ff = { workspace = true }
ark-poly = { workspace = true }
ark-serialize = { workspace = true }
ark-std = { workspace = true }
arrow = { workspace = true }
arrow = { workspace = true, optional = true }
bit-iter = { workspace = true }
bigdecimal = { workspace = true }
blake3 = { workspace = true }
@@ -63,8 +63,12 @@ tracing-opentelemetry = { workspace = true }
tracing-subscriber = { workspace = true }
flexbuffers = { workspace = true }

[package.metadata.cargo-udeps.ignore]
development = ["arrow-csv"]

[features]
default = ["blitzar"]
default = ["arrow", "blitzar"]
arrow = ["dep:arrow"]
test = ["dep:rand"]

[lints]
@@ -76,7 +80,7 @@ required-features = [ "blitzar", "test" ]

[[example]]
name = "posql_db"
required-features = [ "blitzar" ]
required-features = [ "arrow", "blitzar" ]

[[bench]]
name = "criterion_benches"
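With `arrow` optional but still in `default`, existing consumers build unchanged, while size-sensitive consumers opt out with `default-features = false`. A minimal sketch of the downstream side (the feature forwarding and the `describe` function are illustrative assumptions, not part of this PR):

```rust
// In a downstream Cargo.toml (illustrative):
//
//   [dependencies]
//   proof-of-sql = { version = "*", default-features = false, features = ["blitzar"] }
//   arrow = { version = "*", optional = true }
//
//   [features]
//   arrow = ["dep:arrow", "proof-of-sql/arrow"]
//
// Downstream code then gates its arrow-touching paths on its own
// `arrow` feature, which forwards to proof-of-sql's:
#[cfg(feature = "arrow")]
use arrow::record_batch::RecordBatch;

#[cfg(feature = "arrow")]
fn describe(batch: &RecordBatch) -> String {
    // num_rows/num_columns are standard RecordBatch accessors.
    format!("{} rows x {} columns", batch.num_rows(), batch.num_columns())
}
```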
8 changes: 4 additions & 4 deletions crates/proof-of-sql/examples/posql_db/run_example.sh
@@ -1,5 +1,5 @@
cd crates/proof-of-sql/examples/posql_db
cargo run --example posql_db create -t sxt.table -c a,b -d BIGINT,VARCHAR
cargo run --example posql_db append -t sxt.table -f hello_world.csv
cargo run --example posql_db prove -q "SELECT b FROM sxt.table WHERE a = 2" -f hello.proof
cargo run --example posql_db verify -q "SELECT b FROM sxt.table WHERE a = 2" -f hello.proof
cargo run --features="arrow blitzar" --example posql_db create -t sxt.table -c a,b -d BIGINT,VARCHAR
cargo run --features="arrow blitzar" --example posql_db append -t sxt.table -f hello_world.csv
cargo run --features="arrow blitzar" --example posql_db prove -q "SELECT b FROM sxt.table WHERE a = 2" -f hello.proof
cargo run --features="arrow blitzar" --example posql_db verify -q "SELECT b FROM sxt.table WHERE a = 2" -f hello.proof
16 changes: 11 additions & 5 deletions crates/proof-of-sql/src/base/commitment/table_commitment.rs
@@ -2,13 +2,13 @@ use super::{
committable_column::CommittableColumn, AppendColumnCommitmentsError, ColumnCommitments,
ColumnCommitmentsMismatch, Commitment, DuplicateIdentifiers,
};
#[cfg(feature = "arrow")]
use crate::base::database::{ArrayRefExt, ArrowArrayToColumnConversionError};
use crate::base::{
database::{
ArrayRefExt, ArrowArrayToColumnConversionError, Column, ColumnField, CommitmentAccessor,
OwnedTable, TableRef,
},
database::{Column, ColumnField, CommitmentAccessor, OwnedTable, TableRef},
scalar::Scalar,
};
#[cfg(feature = "arrow")]
use arrow::record_batch::RecordBatch;
use bumpalo::Bump;
use proof_of_sql_parser::{Identifier, ParseError};
@@ -63,6 +63,7 @@ pub enum TableCommitmentArithmeticError {
}

/// Errors that can occur when trying to create or extend a [`TableCommitment`] from a record batch.
#[cfg(feature = "arrow")]
#[derive(Debug, Error)]
pub enum RecordBatchToColumnsError {
/// Error converting from arrow array
@@ -74,6 +75,7 @@ pub enum RecordBatchToColumnsError {
}

/// Errors that can occur when attempting to append a record batch to a [`TableCommitment`].
#[cfg(feature = "arrow")]
#[derive(Debug, Error)]
pub enum AppendRecordBatchTableCommitmentError {
/// During commitment operation, metadata indicates that operand tables cannot be the same.
@@ -354,6 +356,7 @@ impl<C: Commitment> TableCommitment<C> {
/// The row offset is assumed to be the end of the [`TableCommitment`]'s current range.
///
/// Will error on a variety of mismatches, or if the provided columns have mixed length.
#[cfg(feature = "arrow")]
pub fn try_append_record_batch(
&mut self,
batch: &RecordBatch,
@@ -380,6 +383,7 @@ }
}
}
/// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`].
#[cfg(feature = "arrow")]
pub fn try_from_record_batch(
batch: &RecordBatch,
setup: &C::PublicSetup<'_>,
@@ -388,6 +392,7 @@ impl<C: Commitment> TableCommitment<C> {
}

/// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`] with the given row offset.
#[cfg(feature = "arrow")]
pub fn try_from_record_batch_with_offset(
batch: &RecordBatch,
offset: usize,
@@ -411,6 +416,7 @@ }
}
}

#[cfg(feature = "arrow")]
fn batch_to_columns<'a, S: Scalar + 'a>(
batch: &'a RecordBatch,
alloc: &'a Bump,
@@ -446,7 +452,7 @@ fn num_rows_of_columns<'a>(
Ok(num_rows)
}

#[cfg(all(test, feature = "blitzar"))]
#[cfg(all(test, feature = "arrow", feature = "blitzar"))]
mod tests {
use super::*;
use crate::{
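The gating here is per item rather than per `impl` block: only the arrow-facing error enums and constructors carry `#[cfg(feature = "arrow")]`, so `TableCommitment` and its feature-independent methods survive in minimal builds. The same idiom as a self-contained sketch (all names below are invented for illustration):

```rust
pub struct Commitments(Vec<[u8; 32]>);

// The arrow-facing error type exists only when the feature is on.
#[cfg(feature = "arrow")]
#[derive(Debug)]
pub enum FromBatchError {
    UnsupportedColumn(String),
}

impl Commitments {
    // Compiled in every build, with or without arrow.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    // Compiled only when the `arrow` feature is enabled.
    #[cfg(feature = "arrow")]
    pub fn try_from_batch(
        batch: &arrow::record_batch::RecordBatch,
    ) -> Result<Self, FromBatchError> {
        // Illustrative body; a real implementation would commit to each column.
        Err(FromBatchError::UnsupportedColumn(format!(
            "{} columns not handled in this sketch",
            batch.num_columns()
        )))
    }
}
```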
34 changes: 26 additions & 8 deletions crates/proof-of-sql/src/base/database/column.rs
@@ -3,6 +3,7 @@ use crate::base::{
math::decimal::{scale_scalar, Precision},
scalar::Scalar,
};
#[cfg(feature = "arrow")]
use arrow::datatypes::{DataType, Field, TimeUnit as ArrowTimeUnit};
use bumpalo::Bump;
use proof_of_sql_parser::{
@@ -350,6 +351,7 @@ impl ColumnType {
}

/// Convert ColumnType values to some arrow DataType
#[cfg(feature = "arrow")]
impl From<&ColumnType> for DataType {
fn from(column_type: &ColumnType) -> Self {
match column_type {
@@ -363,15 +365,22 @@ impl From<&ColumnType> for DataType {
}
ColumnType::VarChar => DataType::Utf8,
ColumnType::Scalar => unimplemented!("Cannot convert Scalar type to arrow type"),
ColumnType::TimestampTZ(timeunit, timezone) => DataType::Timestamp(
ArrowTimeUnit::from(*timeunit),
Some(Arc::from(timezone.to_string())),
),
ColumnType::TimestampTZ(timeunit, timezone) => {
let arrow_timezone = Some(Arc::from(timezone.to_string()));
let arrow_timeunit = match timeunit {
PoSQLTimeUnit::Second => ArrowTimeUnit::Second,
PoSQLTimeUnit::Millisecond => ArrowTimeUnit::Millisecond,
PoSQLTimeUnit::Microsecond => ArrowTimeUnit::Microsecond,
PoSQLTimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond,
};
DataType::Timestamp(arrow_timeunit, arrow_timezone)
}
}
}
}

/// Convert arrow DataType values to some ColumnType
#[cfg(feature = "arrow")]
impl TryFrom<DataType> for ColumnType {
type Error = String;

@@ -385,10 +394,18 @@ impl TryFrom<DataType> for ColumnType {
DataType::Decimal256(precision, scale) if precision <= 75 => {
Ok(ColumnType::Decimal75(Precision::new(precision)?, scale))
}
DataType::Timestamp(time_unit, timezone_option) => Ok(ColumnType::TimestampTZ(
PoSQLTimeUnit::from(time_unit),
PoSQLTimeZone::try_from(&timezone_option)?,
)),
DataType::Timestamp(time_unit, timezone_option) => {
let posql_time_unit = match time_unit {
ArrowTimeUnit::Second => PoSQLTimeUnit::Second,
ArrowTimeUnit::Millisecond => PoSQLTimeUnit::Millisecond,
ArrowTimeUnit::Microsecond => PoSQLTimeUnit::Microsecond,
ArrowTimeUnit::Nanosecond => PoSQLTimeUnit::Nanosecond,
};
Ok(ColumnType::TimestampTZ(
posql_time_unit,
PoSQLTimeZone::try_from(&timezone_option)?,
))
}
DataType::Utf8 => Ok(ColumnType::VarChar),
_ => Err(format!("Unsupported arrow data type {:?}", data_type)),
}
@@ -482,6 +499,7 @@ impl ColumnField {
}

/// Convert ColumnField values to arrow Field
#[cfg(feature = "arrow")]
impl From<&ColumnField> for Field {
fn from(column_field: &ColumnField) -> Self {
Field::new(
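The conversion pair above is deliberately asymmetric: `From<&ColumnType> for DataType` is total apart from `Scalar` (which has no arrow equivalent and hits `unimplemented!`), while the reverse direction is a `TryFrom` returning `Err(String)` because arrow has many data types with no `ColumnType` counterpart. A round-trip sketch under that API (the `BigInt` ↔ `Int64` mapping and the `PartialEq` derives are assumptions; this hunk does not show them):

```rust
#[cfg(all(test, feature = "arrow"))]
mod conversion_round_trip_sketch {
    use super::*;
    use arrow::datatypes::DataType;

    #[test]
    fn bigint_round_trips_through_arrow() {
        let dt = DataType::from(&ColumnType::BigInt);
        assert_eq!(dt, DataType::Int64); // assumed mapping
        // TryFrom yields Err(String) for arrow types with no ColumnType.
        assert_eq!(ColumnType::try_from(dt), Ok(ColumnType::BigInt));
    }
}
```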
13 changes: 10 additions & 3 deletions crates/proof-of-sql/src/base/database/mod.rs
@@ -21,15 +21,19 @@ pub use literal_value::LiteralValue;
mod table_ref;
pub use table_ref::TableRef;

#[cfg(feature = "arrow")]
mod arrow_array_to_column_conversion;
#[cfg(feature = "arrow")]
pub use arrow_array_to_column_conversion::{ArrayRefExt, ArrowArrayToColumnConversionError};

#[cfg(feature = "arrow")]
mod record_batch_utility;
#[cfg(feature = "arrow")]
pub use record_batch_utility::ToArrow;

#[cfg(any(test, feature = "test"))]
#[cfg(all(feature = "arrow", any(test, feature = "test")))]
mod test_accessor_utility;
#[cfg(any(test, feature = "test"))]
#[cfg(all(feature = "arrow", any(test, feature = "test")))]
pub use test_accessor_utility::{make_random_test_accessor_data, RandomTestAccessorDescriptor};

mod owned_column;
@@ -54,9 +58,11 @@ mod expression_evaluation_error;
mod expression_evaluation_test;
pub use expression_evaluation_error::{ExpressionEvaluationError, ExpressionEvaluationResult};

#[cfg(feature = "arrow")]
mod owned_and_arrow_conversions;
#[cfg(feature = "arrow")]
pub use owned_and_arrow_conversions::OwnedArrowConversionError;
#[cfg(test)]
#[cfg(all(test, feature = "arrow"))]
mod owned_and_arrow_conversions_test;

#[cfg(any(test, feature = "test"))]
@@ -78,6 +84,7 @@ pub use owned_table_test_accessor::OwnedTableTestAccessor;
#[cfg(all(test, feature = "blitzar"))]
mod owned_table_test_accessor_test;
/// Contains traits for scalar <-> i256 conversions
#[cfg(feature = "arrow")]
pub mod scalar_and_i256_conversions;

pub(crate) mod filter_util;
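One detail worth calling out in these predicates: combining a feature requirement with the test condition takes an explicit `all(...)`, with each feature in its own `feature = "..."` check. A comma inside one string (`feature = "arrow, test"`) would test for a feature literally named `arrow, test`, which never exists, so the gated item would be silently compiled out of every build. The idiom as a standalone sketch:

```rust
// The module gate and its re-export carry identical predicates, so the
// public path vanishes cleanly whenever the condition is false.
#[cfg(all(feature = "arrow", any(test, feature = "test")))]
mod test_accessor_utility;
#[cfg(all(feature = "arrow", any(test, feature = "test")))]
pub use test_accessor_utility::{
    make_random_test_accessor_data, RandomTestAccessorDescriptor,
};
```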
crates/proof-of-sql/src/base/database/test_accessor_utility.rs
@@ -39,6 +39,7 @@ impl Default for RandomTestAccessorDescriptor {
}

/// Generate a DataFrame with random data
#[allow(dead_code)]
pub fn make_random_test_accessor_data(
rng: &mut StdRng,
cols: &[(&str, ColumnType)],
20 changes: 6 additions & 14 deletions crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs
@@ -16,13 +16,11 @@ use crate::{
},
},
};
use arrow::datatypes::{Field, Schema};
use blitzar::proof::InnerProductProof;
use bumpalo::Bump;
use curve25519_dalek::RistrettoPoint;
use indexmap::{IndexMap, IndexSet};
use proof_of_sql_parser::{Identifier, ResourceId};
use std::sync::Arc;

#[test]
fn we_can_correctly_fetch_the_query_result_schema() {
@@ -60,19 +58,13 @@ fn we_can_correctly_fetch_the_query_result_schema() {
.unwrap(),
);

let column_fields: Vec<Field> = provable_ast
.get_column_result_fields()
.iter()
.map(|v| v.into())
.collect();
let schema = Arc::new(Schema::new(column_fields));

let column_fields: Vec<ColumnField> = provable_ast.get_column_result_fields();
assert_eq!(
schema,
Arc::new(Schema::new(vec![
Field::new("a", (&ColumnType::BigInt).into(), false,),
Field::new("b", (&ColumnType::BigInt).into(), false,)
]))
column_fields,
vec![
ColumnField::new("a".parse().unwrap(), ColumnType::BigInt),
ColumnField::new("b".parse().unwrap(), ColumnType::BigInt)
]
);
}

21 changes: 6 additions & 15 deletions crates/proof-of-sql/src/sql/ast/filter_expr_test.rs
@@ -18,13 +18,11 @@ use crate::{
},
},
};
use arrow::datatypes::{Field, Schema};
use blitzar::proof::InnerProductProof;
use bumpalo::Bump;
use curve25519_dalek::RistrettoPoint;
use indexmap::{IndexMap, IndexSet};
use proof_of_sql_parser::{Identifier, ResourceId};
use std::sync::Arc;

#[test]
fn we_can_correctly_fetch_the_query_result_schema() {
@@ -53,20 +51,13 @@ fn we_can_correctly_fetch_the_query_result_schema() {
)
.unwrap(),
);

let column_fields: Vec<Field> = provable_ast
.get_column_result_fields()
.iter()
.map(|v| v.into())
.collect();
let schema = Arc::new(Schema::new(column_fields));

let column_fields: Vec<ColumnField> = provable_ast.get_column_result_fields();
assert_eq!(
schema,
Arc::new(Schema::new(vec![
Field::new("a", (&ColumnType::BigInt).into(), false,),
Field::new("b", (&ColumnType::BigInt).into(), false,)
]))
column_fields,
vec![
ColumnField::new("a".parse().unwrap(), ColumnType::BigInt),
ColumnField::new("b".parse().unwrap(), ColumnType::BigInt)
]
);
}
