Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: inequality expr redefinition #465

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ bit-iter = { version = "1.1.1" }
bigdecimal = { version = "0.4.5", default-features = false, features = ["serde"] }
blake3 = { version = "1.3.3", default-features = false }
blitzar = { version = "4.0.0" }
bnum = { version = "0.3.0" }
bumpalo = { version = "3.11.0" }
bytemuck = {version = "1.16.3", features = ["derive"]}
byte-slice-cast = { version = "1.2.1", default-features = false }
Expand Down
8 changes: 4 additions & 4 deletions crates/proof-of-sql-parser/src/intermediate_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ pub enum BinaryOperator {
/// Comparison =
Equal,

/// Comparison <=
LessThanOrEqual,
/// Comparison <
LessThan,

/// Comparison >=
GreaterThanOrEqual,
/// Comparison >
GreaterThan,
}

/// Possible unary operators for simple expressions
Expand Down
4 changes: 2 additions & 2 deletions crates/proof-of-sql-parser/src/intermediate_ast_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,7 @@ fn we_can_parse_a_query_with_filter_lt() {
query(
cols_res(&["a"]),
tab(None, "tab"),
not(ge(col("b"), lit(4))),
lt(col("b"), lit(4)),
vec![],
),
vec![],
Expand Down Expand Up @@ -1023,7 +1023,7 @@ fn we_can_parse_a_query_with_filter_gt() {
query(
cols_res(&["a"]),
tab(None, "tab"),
not(le(col("b"), lit(4))),
gt(col("b"), lit(4)),
vec![],
),
vec![],
Expand Down
16 changes: 8 additions & 8 deletions crates/proof-of-sql-parser/src/sql.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -238,35 +238,35 @@ Expression: Box<intermediate_ast::Expression> = {
}),

#[precedence(level="4")] #[assoc(side="left")]
<left: Expression> ">=" <right: Expression> =>
<left: Expression> ">" <right: Expression> =>
Box::new(intermediate_ast::Expression::Binary {
op: intermediate_ast::BinaryOperator::GreaterThanOrEqual,
op: intermediate_ast::BinaryOperator::GreaterThan,
left,
right,
}),

<left: Expression> "<=" <right: Expression> =>
<left: Expression> "<" <right: Expression> =>
Box::new(intermediate_ast::Expression::Binary {
op: intermediate_ast::BinaryOperator::LessThanOrEqual,
op: intermediate_ast::BinaryOperator::LessThan,
left,
right,
}),

<left: Expression> ">" <right: Expression> =>
<left: Expression> ">=" <right: Expression> =>
Box::new(intermediate_ast::Expression::Unary {
op: intermediate_ast::UnaryOperator::Not,
expr: Box::new(intermediate_ast::Expression::Binary {
op: intermediate_ast::BinaryOperator::LessThanOrEqual,
op: intermediate_ast::BinaryOperator::LessThan,
left,
right,
}),
}),

<left: Expression> "<" <right: Expression> =>
<left: Expression> "<=" <right: Expression> =>
Box::new(intermediate_ast::Expression::Unary {
op: intermediate_ast::UnaryOperator::Not,
expr: Box::new(intermediate_ast::Expression::Binary {
op: intermediate_ast::BinaryOperator::GreaterThanOrEqual,
op: intermediate_ast::BinaryOperator::GreaterThan,
left,
right,
}),
Expand Down
6 changes: 3 additions & 3 deletions crates/proof-of-sql-parser/src/sqlparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ impl From<PoSqlBinaryOperator> for BinaryOperator {
PoSqlBinaryOperator::And => BinaryOperator::And,
PoSqlBinaryOperator::Or => BinaryOperator::Or,
PoSqlBinaryOperator::Equal => BinaryOperator::Eq,
PoSqlBinaryOperator::LessThanOrEqual => BinaryOperator::LtEq,
PoSqlBinaryOperator::GreaterThanOrEqual => BinaryOperator::GtEq,
PoSqlBinaryOperator::LessThan => BinaryOperator::Lt,
PoSqlBinaryOperator::GreaterThan => BinaryOperator::Gt,
PoSqlBinaryOperator::Add => BinaryOperator::Plus,
PoSqlBinaryOperator::Subtract => BinaryOperator::Minus,
PoSqlBinaryOperator::Multiply => BinaryOperator::Multiply,
Expand Down Expand Up @@ -291,7 +291,7 @@ mod test {
);
check_posql_intermediate_ast_to_sqlparser_equality("select 1 as a, 'Meow' as d, b as b from namespace.table where c = 4 order by a desc limit 10 offset 0;");
check_posql_intermediate_ast_to_sqlparser_equality(
"select true as cons, a and b or c >= 4 as comp from tab where d = 'Space and Time';",
"select true as cons, a and b or c > 4 as comp from tab where d = 'Space and Time';",
);
check_posql_intermediate_ast_to_sqlparser_equality(
"select cat as cat, true as cons, max(meow) as max_meow from tab where d = 'Space and Time' group by cat;",
Expand Down
24 changes: 22 additions & 2 deletions crates/proof-of-sql-parser/src/utility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,18 @@ pub fn equal(left: Box<Expression>, right: Box<Expression>) -> Box<Expression> {
/// Construct a new boxed `Expression` A >= B
#[must_use]
pub fn ge(left: Box<Expression>, right: Box<Expression>) -> Box<Expression> {
not(Box::new(Expression::Binary {
op: BinaryOperator::LessThan,
left,
right,
}))
}

/// Construct a new boxed `Expression` A > B
#[must_use]
pub fn gt(left: Box<Expression>, right: Box<Expression>) -> Box<Expression> {
Box::new(Expression::Binary {
op: BinaryOperator::GreaterThanOrEqual,
op: BinaryOperator::GreaterThan,
left,
right,
})
Expand All @@ -40,8 +50,18 @@ pub fn ge(left: Box<Expression>, right: Box<Expression>) -> Box<Expression> {
/// Construct a new boxed `Expression` A <= B
#[must_use]
pub fn le(left: Box<Expression>, right: Box<Expression>) -> Box<Expression> {
not(Box::new(Expression::Binary {
op: BinaryOperator::GreaterThan,
left,
right,
}))
}

/// Construct a new boxed `Expression` A < B
#[must_use]
pub fn lt(left: Box<Expression>, right: Box<Expression>) -> Box<Expression> {
Box::new(Expression::Binary {
op: BinaryOperator::LessThanOrEqual,
op: BinaryOperator::LessThan,
left,
right,
})
Expand Down
1 change: 1 addition & 0 deletions crates/proof-of-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ bit-iter = { workspace = true }
bigdecimal = { workspace = true }
blake3 = { workspace = true }
blitzar = { workspace = true, optional = true }
bnum = { workspace = true }
bumpalo = { workspace = true, features = ["collections"] }
bytemuck = { workspace = true }
byte-slice-cast = { workspace = true }
Expand Down
8 changes: 0 additions & 8 deletions crates/proof-of-sql/src/base/bit/abs_bit_mask.rs

This file was deleted.

193 changes: 71 additions & 122 deletions crates/proof-of-sql/src/base/bit/bit_distribution.rs
Original file line number Diff line number Diff line change
@@ -1,70 +1,87 @@
use crate::base::{bit::make_abs_bit_mask, scalar::Scalar};
use super::bit_mask_utils::{is_bit_mask_negative_representation, make_bit_mask};
use crate::base::scalar::{Scalar, ScalarExt};
use ark_std::iterable::Iterable;
use bit_iter::BitIter;
use bnum::types::U256;
use core::convert::Into;
use itertools::Itertools;
use serde::{Deserialize, Serialize};

/// Describe the distribution of bit values in a table column
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct BitDistribution {
/// We use two arrays to track which bits vary
/// and the constant bit values. If
/// `{x_1, ..., x_n}` represents the values [`BitDistribution`] describes, then:
/// `or_all = abs(x_1) | abs(x_2) | ... | abs(x_n)`
/// `vary_mask & (1 << i) =`
/// `1` if `x_s & (1 << i) != x_t & (1 << i)` for some `s != t`
/// 0 otherwise
pub or_all: [u64; 4],
pub vary_mask: [u64; 4],
/// Identifies all columns that are identical to the leading column (the sign column). The lead bit indicates if the sign column is constant
pub(crate) vary_mask: [u64; 4],
/// Identifies all columns that are the identical to the lead column. The lead bit indicates the sign of the last row of data (only relevant if the sign is constant)
pub(crate) leading_bit_mask: [u64; 4],
}

impl BitDistribution {
pub fn new<S: Scalar, T: Into<S> + Clone>(data: &[T]) -> Self {
if data.is_empty() {
return Self {
or_all: [0; 4],
vary_mask: [0; 4],
};
}
let mut or_all = make_abs_bit_mask(data[0].clone().into());
let mut vary_mask = [0; 4];
for x in data.iter().skip(1) {
let mask = make_abs_bit_mask((*x).clone().into());
for i in 0..4 {
vary_mask[i] |= or_all[i] ^ mask[i];
or_all[i] |= mask[i];
}
let bit_masks = data.iter().cloned().map(Into::<S>::into).map(make_bit_mask);
let (sign_mask, inverse_sign_mask) =
bit_masks
.clone()
.fold((U256::MAX, U256::MAX), |acc, bit_mask| {
let bit_mask = if is_bit_mask_negative_representation(bit_mask) {
bit_mask ^ (U256::MAX >> 1)
} else {
bit_mask
};
(acc.0 & bit_mask, acc.1 & !bit_mask)
});
let vary_mask_bit = U256::from(
!bit_masks
.map(is_bit_mask_negative_representation)
.all_equal(),
) << 255;
let vary_mask: U256 = !(sign_mask | inverse_sign_mask) | vary_mask_bit;

Self {
leading_bit_mask: sign_mask.into(),
vary_mask: vary_mask.into(),
}
Self { or_all, vary_mask }
}

pub fn num_varying_bits(&self) -> usize {
let mut res = 0_usize;
for xi in &self.vary_mask {
res += xi.count_ones() as usize;
}
res
pub fn vary_mask(&self) -> U256 {
U256::from(self.vary_mask)
}

/// Identifies all columns that are the identical to the lead column.
pub fn leading_bit_mask(&self) -> U256 {
U256::from(self.leading_bit_mask) | (U256::ONE << 255)
}

/// Identifies all columns that are the identical to the inverse of the lead column.
pub fn leading_bit_inverse_mask(&self) -> U256 {
(!self.vary_mask() ^ self.leading_bit_mask()) & (U256::MAX >> 1)
}

pub fn has_varying_sign_bit(&self) -> bool {
self.vary_mask[3] & (1 << 63) != 0
/// # Panics
///
/// Panics if conversion from `ExpType` to `usize` fails
pub fn num_varying_bits(&self) -> usize {
self.vary_mask().count_ones() as usize
}

#[allow(clippy::missing_panics_doc)]
pub fn sign_bit(&self) -> bool {
assert!(!self.has_varying_sign_bit());
self.or_all[3] & (1 << 63) != 0
/// # Panics
///
/// Panics if lead bit varies but `bit_evals` is empty
pub fn leading_bit_eval<S: ScalarExt>(&self, bit_evals: &[S], one_eval: S) -> S {
if U256::from(self.vary_mask) & (U256::ONE << 255) != U256::ZERO {
*bit_evals.last().expect("bit_evals should be non-empty")
} else if U256::from(self.leading_bit_mask) & (U256::ONE << 255) == U256::ZERO {
S::ZERO
} else {
one_eval
}
}

/// Check if this instance represents a valid bit distribution. `is_valid`
/// can be used after deserializing a [`BitDistribution`] from an untrusted
/// source.
pub fn is_valid(&self) -> bool {
for (m, o) in self.vary_mask.iter().zip(self.or_all) {
if m & !o != 0 {
return false;
}
}
true
(self.vary_mask() & self.leading_bit_mask()) & (U256::MAX >> 1) == U256::ZERO
}

/// In order to avoid cases with large numbers where there can be both a positive and negative
Expand All @@ -73,92 +90,24 @@ impl BitDistribution {
/// Currently this is set to be the minimal value that will include the sum of two signed 128-bit
/// integers. The range will likely be expanded in the future as we support additional expressions.
pub fn is_within_acceptable_range(&self) -> bool {
// handle the case of everything zero
if self.num_varying_bits() == 0 && self.constant_part() == [0; 4] {
return true;
}

// signed 128 bit numbers range from
// -2^127 to 2^127-1
// the maximum absolute value of the sum of two signed 128-integers is
// then
// 2 * (2^127) = 2^128
self.most_significant_abs_bit() <= 128
}

/// If `{b_i}` represents the non-varying 1-bits of the absolute values, return the value
/// `sum_i b_i 2 ^ i`
pub fn constant_part(&self) -> [u64; 4] {
let mut val = [0; 4];
self.for_each_abs_constant_bit(|i: usize, bit: usize| {
val[i] |= 1u64 << bit;
});
val
}

/// Iterate over each constant 1-bit for the absolute values
pub fn for_each_abs_constant_bit<F>(&self, mut f: F)
where
F: FnMut(usize, usize),
{
for i in 0..4 {
let bitset = if i == 3 {
!(self.vary_mask[i] | (1 << 63))
} else {
!self.vary_mask[i]
};
let bitset = bitset & self.or_all[i];
for pos in BitIter::from(bitset) {
f(i, pos);
}
}
}

/// Iterate over each varying bit for the absolute values
pub fn for_each_abs_varying_bit<F>(&self, mut f: F)
where
F: FnMut(usize, usize),
{
for i in 0..4 {
let bitset = if i == 3 {
self.vary_mask[i] & !(1 << 63)
} else {
self.vary_mask[i]
};
for pos in BitIter::from(bitset) {
f(i, pos);
}
}
}

/// Iterate over each varying bit for the absolute values and the sign bit
/// if it varies
pub fn for_each_varying_bit<F>(&self, mut f: F)
where
F: FnMut(usize, usize),
{
for i in 0..4 {
let bitset = self.vary_mask[i];
for pos in BitIter::from(bitset) {
f(i, pos);
}
}
(self.leading_bit_inverse_mask() >> 128) == (U256::MAX >> 129)
}

/// Return the position of the most significant bit of the absolute values
/// Iterate over each varying bit
///
/// # Panics
/// Panics if no bits are set to 1 in the bit representation of `or_all`.
pub fn most_significant_abs_bit(&self) -> usize {
let mask = self.or_all[3] & !(1 << 63);
if mask != 0 {
return 64 - (mask.leading_zeros() as usize) - 1 + 3 * 64;
}
for i in (0..3).rev() {
let mask = self.or_all[i];
if mask != 0 {
return 64 - (mask.leading_zeros() as usize) - 1 + 64 * i;
}
}
panic!("no bits are set");
///
/// The panic shouldn't be mathematically possible
pub fn vary_mask_iter(&self) -> impl Iterator<Item = u8> + '_ {
(0..4).flat_map(|i| {
BitIter::from(self.vary_mask[i])
.iter()
.map(move |pos| u8::try_from(i * 64 + pos).expect("index greater than 255"))
})
}
}
Loading
Loading