Skip to content

Commit

Permalink
feat: allow FirstRoundBuilder to produce MLEs
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Dec 10, 2024
1 parent cb4a93c commit edc345d
Show file tree
Hide file tree
Showing 16 changed files with 211 additions and 46 deletions.
69 changes: 65 additions & 4 deletions crates/proof-of-sql/src/sql/proof/first_round_builder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
use alloc::vec::Vec;
use crate::base::{
commitment::{Commitment, CommittableColumn, VecCommitmentExt},
polynomial::MultilinearExtension,
scalar::Scalar,
};
use alloc::{boxed::Box, vec::Vec};
/// Track the result created by a query
pub struct FirstRoundBuilder {
pub struct FirstRoundBuilder<'a, S> {
commitment_descriptor: Vec<CommittableColumn<'a>>,
pcs_proof_mles: Vec<Box<dyn MultilinearExtension<S> + 'a>>,
/// The number of challenges used in the proof.
/// Specifically, these are the challenges that the verifier sends to
/// the prover after the prover sends the result, but before the prover
Expand All @@ -10,15 +17,17 @@ pub struct FirstRoundBuilder {
one_evaluation_lengths: Vec<usize>,
}

impl Default for FirstRoundBuilder {
impl<'a, S: Scalar> Default for FirstRoundBuilder<'a, S> {
fn default() -> Self {
Self::new()
}
}

impl FirstRoundBuilder {
impl<'a, S: Scalar> FirstRoundBuilder<'a, S> {
pub fn new() -> Self {
Self {
commitment_descriptor: Vec::new(),
pcs_proof_mles: Vec::new(),
num_post_result_challenges: 0,
one_evaluation_lengths: Vec::new(),
}
Expand All @@ -34,6 +43,58 @@ impl FirstRoundBuilder {
self.one_evaluation_lengths.push(length);
}

/// Produce an anchored MLE that we can reference in sumcheck.
///
/// An anchored MLE is an MLE where the verifier has access to the commitment.
pub fn produce_anchored_mle(&mut self, data: impl MultilinearExtension<S> + 'a) {
self.pcs_proof_mles.push(Box::new(data));
}

/// Produce an MLE for a intermediate computed column that we can reference in sumcheck.
///
/// Because the verifier doesn't have access to the MLE's commitment, we will need to
/// commit to the MLE before we form the sumcheck polynomial.
pub fn produce_intermediate_mle(
&mut self,
data: impl MultilinearExtension<S> + Into<CommittableColumn<'a>> + Copy + 'a,
) {
self.commitment_descriptor.push(data.into());
self.produce_anchored_mle(data);
}

/// Compute commitments of all the interemdiate MLEs used in sumcheck
#[tracing::instrument(
name = "FinalRoundBuilder::commit_intermediate_mles",
level = "debug",
skip_all
)]
pub fn commit_intermediate_mles<C: Commitment>(
&self,
offset_generators: usize,
setup: &C::PublicSetup<'_>,
) -> Vec<C> {
Vec::from_commitable_columns_with_offset(
&self.commitment_descriptor,
offset_generators,
setup,
)
}

/// Given the evaluation vector, compute evaluations of all the MLEs used in sumcheck except
/// for those that correspond to result columns sent to the verifier.
#[tracing::instrument(
name = "FinalRoundBuilder::evaluate_pcs_proof_mles",
level = "debug",
skip_all
)]
pub fn evaluate_pcs_proof_mles(&self, evaluation_vec: &[S]) -> Vec<S> {
let mut res = Vec::with_capacity(self.pcs_proof_mles.len());
for evaluator in &self.pcs_proof_mles {
res.push(evaluator.inner_product(evaluation_vec));
}
res
}

/// The number of challenges used in the proof.
/// Specifically, these are the challenges that the verifier sends to
/// the prover after the prover sends the result, but before the prover
Expand Down
73 changes: 73 additions & 0 deletions crates/proof-of-sql/src/sql/proof/first_round_builder_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use super::FirstRoundBuilder;
use crate::base::{
commitment::{Commitment, CommittableColumn},
scalar::Curve25519Scalar,
};
use curve25519_dalek::RistrettoPoint;

#[test]
fn we_can_compute_commitments_for_intermediate_mles_using_a_zero_offset() {
let mle1 = [1, 2];
let mle2 = [10i64, 20];
let mut builder = FirstRoundBuilder::<Curve25519Scalar>::new();
builder.produce_anchored_mle(&mle1);
builder.produce_intermediate_mle(&mle2[..]);
let offset_generators = 0_usize;
let commitments: Vec<RistrettoPoint> = builder.commit_intermediate_mles(offset_generators, &());
assert_eq!(
commitments,
[RistrettoPoint::compute_commitments(
&[CommittableColumn::from(&mle2[..])],
offset_generators,
&()
)[0]]
);
}

#[test]
fn we_can_compute_commitments_for_intermediate_mles_using_a_non_zero_offset() {
let mle1 = [1, 2];
let mle2 = [10i64, 20];
let mut builder = FirstRoundBuilder::<Curve25519Scalar>::new();
builder.produce_anchored_mle(&mle1);
builder.produce_intermediate_mle(&mle2[..]);
let offset_generators = 123_usize;
let commitments: Vec<RistrettoPoint> = builder.commit_intermediate_mles(offset_generators, &());
assert_eq!(
commitments,
[RistrettoPoint::compute_commitments(
&[CommittableColumn::from(&mle2[..])],
offset_generators,
&()
)[0]]
);
}

#[test]
fn we_can_evaluate_pcs_proof_mles() {
let mle1 = [1, 2];
let mle2 = [10i64, 20];
let mut builder = FirstRoundBuilder::<Curve25519Scalar>::new();
builder.produce_anchored_mle(&mle1);
builder.produce_intermediate_mle(&mle2[..]);
let evaluation_vec = [
Curve25519Scalar::from(100u64),
Curve25519Scalar::from(10u64),
];
let evals = builder.evaluate_pcs_proof_mles(&evaluation_vec);
let expected_evals = [
Curve25519Scalar::from(120u64),
Curve25519Scalar::from(1200u64),
];
assert_eq!(evals, expected_evals);
}

#[test]
fn we_can_add_post_result_challenges() {
let mut builder = FirstRoundBuilder::<Curve25519Scalar>::new();
assert_eq!(builder.num_post_result_challenges(), 0);
builder.request_post_result_challenges(1);
assert_eq!(builder.num_post_result_challenges(), 1);
builder.request_post_result_challenges(2);
assert_eq!(builder.num_post_result_challenges(), 3);
}
2 changes: 2 additions & 0 deletions crates/proof-of-sql/src/sql/proof/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ pub(crate) use result_element_serialization::{

mod first_round_builder;
pub(crate) use first_round_builder::FirstRoundBuilder;
#[cfg(all(test, feature = "blitzar"))]
mod first_round_builder_test;

#[cfg(all(test, feature = "arrow"))]
mod provable_query_result_test;
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/proof/proof_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub trait ProverEvaluate {
/// Evaluate the query, modify `FirstRoundBuilder` and return the result.
fn first_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FirstRoundBuilder,
builder: &mut FirstRoundBuilder<'a, S>,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> Table<'a, S>;
Expand Down
73 changes: 49 additions & 24 deletions crates/proof-of-sql/src/sql/proof/query_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ pub(super) struct QueryProof<CP: CommitmentEvaluationProof> {
pub bit_distributions: Vec<BitDistribution>,
/// One evaluation lengths
pub one_evaluation_lengths: Vec<usize>,
/// Commitments
pub commitments: Vec<CP::Commitment>,
/// First Round Commitments
pub first_round_commitments: Vec<CP::Commitment>,
/// Final Round Commitments
pub final_round_commitments: Vec<CP::Commitment>,
/// Sumcheck Proof
pub sumcheck_proof: SumcheckProof<CP::Scalar>,
/// MLEs used in sumcheck except for the result columns
Expand Down Expand Up @@ -95,7 +97,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
.collect();

// Prover First Round: Evaluate the query && get the right number of post result challenges
let mut first_round_builder = FirstRoundBuilder::new();
let mut first_round_builder = FirstRoundBuilder::<CP::Scalar>::new();
let query_result = expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map);
let owned_table_result = OwnedTable::from(&query_result);
let provable_result = query_result.into();
Expand All @@ -111,13 +113,18 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
assert!(num_sumcheck_variables > 0);

// commit to any intermediate MLEs
let first_round_commitments =
first_round_builder.commit_intermediate_mles(min_row_num, setup);

// construct a transcript for the proof
let mut transcript: Keccak256Transcript = make_transcript(
expr,
&owned_table_result,
range_length,
min_row_num,
one_evaluation_lengths,
&first_round_commitments,
);

// These are the challenges that will be consumed by the proof
Expand All @@ -130,34 +137,37 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
.take(first_round_builder.num_post_result_challenges())
.collect();

let mut builder = FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
let mut final_round_builder =
FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);

for col_ref in total_col_refs {
builder.produce_anchored_mle(accessor.get_column(col_ref));
final_round_builder.produce_anchored_mle(accessor.get_column(col_ref));
}

expr.final_round_evaluate(&mut builder, &alloc, &table_map);
expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map);

let num_sumcheck_variables = builder.num_sumcheck_variables();
let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();

// commit to any intermediate MLEs
let commitments = builder.commit_intermediate_mles(min_row_num, setup);
let final_round_commitments =
final_round_builder.commit_intermediate_mles(min_row_num, setup);

// add the commitments, bit distributions and one evaluation lengths to the proof
extend_transcript_with_commitments(
&mut transcript,
&commitments,
builder.bit_distributions(),
&final_round_commitments,
final_round_builder.bit_distributions(),
);

// construct the sumcheck polynomial
let num_random_scalars = num_sumcheck_variables + builder.num_sumcheck_subpolynomials();
let num_random_scalars =
num_sumcheck_variables + final_round_builder.num_sumcheck_subpolynomials();
let random_scalars: Vec<_> =
core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
.take(num_random_scalars)
.collect();
let state = make_sumcheck_prover_state(
builder.sumcheck_subpolynomials(),
final_round_builder.sumcheck_subpolynomials(),
num_sumcheck_variables,
&SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
);
Expand All @@ -169,7 +179,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
// evaluate the MLEs used in sumcheck except for the result columns
let mut evaluation_vec = vec![Zero::zero(); range_length];
compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
let pcs_proof_evaluations = builder.evaluate_pcs_proof_mles(&evaluation_vec);
let pcs_proof_evaluations = final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);

// commit to the MLE evaluations
transcript.extend_canonical_serialize_as_le(&pcs_proof_evaluations);
Expand All @@ -181,9 +191,15 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
.take(pcs_proof_evaluations.len())
.collect();

assert_eq!(random_scalars.len(), builder.pcs_proof_mles().len());
assert_eq!(
random_scalars.len(),
final_round_builder.pcs_proof_mles().len()
);
let mut folded_mle = vec![Zero::zero(); range_length];
for (multiplier, evaluator) in random_scalars.iter().zip(builder.pcs_proof_mles().iter()) {
for (multiplier, evaluator) in random_scalars
.iter()
.zip(final_round_builder.pcs_proof_mles().iter())
{
evaluator.mul_add(&mut folded_mle, multiplier);
}

Expand All @@ -197,9 +213,10 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
);

let proof = Self {
bit_distributions: builder.bit_distributions().to_vec(),
bit_distributions: final_round_builder.bit_distributions().to_vec(),
one_evaluation_lengths: one_evaluation_lengths.to_vec(),
commitments,
first_round_commitments,
final_round_commitments,
sumcheck_proof,
pcs_proof_evaluations,
evaluation_proof,
Expand Down Expand Up @@ -253,6 +270,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
self.range_length,
min_row_num,
&self.one_evaluation_lengths,
&self.first_round_commitments,
);

// These are the challenges that will be consumed by the proof
Expand All @@ -268,7 +286,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
// add the commitments and bit disctibutions to the proof
extend_transcript_with_commitments(
&mut transcript,
&self.commitments,
&self.final_round_commitments,
&self.bit_distributions,
);

Expand Down Expand Up @@ -337,7 +355,8 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
let pcs_proof_commitments: Vec<_> = column_references
.iter()
.map(|col| accessor.get_commitment(*col))
.chain(self.commitments.iter().cloned())
.chain(self.first_round_commitments.iter().cloned())
.chain(self.final_round_commitments.iter().cloned())
.collect();
let evaluation_accessor: IndexMap<_, _> = column_references
.into_iter()
Expand Down Expand Up @@ -391,7 +410,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
}

fn validate_sizes(&self, counts: &ProofCounts) -> bool {
self.commitments.len() == counts.intermediate_mles
self.final_round_commitments.len() == counts.intermediate_mles
&& self.pcs_proof_evaluations.len() == counts.intermediate_mles + counts.anchored_mles
}
}
Expand All @@ -410,22 +429,28 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
/// * `min_row_num` - The minimum row number in the index range of the tables referenced by the query.
/// * `one_evaluation_lengths` - The lengths of the one evaluations.
///
/// * `first_round_commitments` - A slice of commitments produced before post-result challenges that are part of the proof.
///
/// # Returns
///
/// A transcript initialized with the provided data.
fn make_transcript<S: Scalar, T: Transcript>(
fn make_transcript<C: Commitment, T: Transcript>(
expr: &(impl ProofPlan + Serialize),
result: &OwnedTable<S>,
result: &OwnedTable<C::Scalar>,
range_length: usize,
min_row_num: usize,
one_evaluation_lengths: &[usize],
first_round_commitments: &[C],
) -> T {
let mut transcript = T::new();
extend_transcript_with_owned_table(&mut transcript, result);
transcript.extend_serialize_as_le(expr);
transcript.extend_serialize_as_le(&range_length);
transcript.extend_serialize_as_le(&min_row_num);
transcript.extend_serialize_as_le(one_evaluation_lengths);
for commitment in first_round_commitments {
commitment.append_to_transcript(&mut transcript);
}
transcript
}

Expand Down Expand Up @@ -481,10 +506,10 @@ fn extend_transcript_with_owned_table<S: Scalar, T: Transcript>(
/// * `bit_distributions` - The bit distributions to add to the transcript.
fn extend_transcript_with_commitments<C: Commitment>(
transcript: &mut impl Transcript,
commitments: &[C],
final_round_commitments: &[C],
bit_distributions: &[BitDistribution],
) {
for commitment in commitments {
for commitment in final_round_commitments {
commitment.append_to_transcript(transcript);
}
transcript.extend_serialize_as_le(bit_distributions);
Expand Down
Loading

0 comments on commit edc345d

Please sign in to comment.