Skip to content

Commit

Permalink
Added total_ and claimed_sums as formal variables in ExprEvaluator.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Nov 19, 2024
1 parent 77de4d7 commit 57cde32
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 25 deletions.
65 changes: 49 additions & 16 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};

use num_traits::{One, Zero};

use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand Down Expand Up @@ -45,7 +44,12 @@ impl Expr {
idx,
offset,
}) => {
format!("col_{interaction}_{idx}[{offset}]")
let offset_str = if *offset == CLAIMED_SUM_DUMMY_OFFSET.try_into().unwrap() {
"claimed_sum_offset".to_string()
} else {
offset.to_string()
};
format!("col_{interaction}_{idx}[{offset_str}]")
}
Expr::SecureCol([a, b, c, d]) => format!(
"SecureCol({}, {}, {}, {})",
Expand Down Expand Up @@ -197,20 +201,55 @@ fn combine_formal<R: Relation<Expr, Expr>>(relation: &R, values: &[Expr]) -> Exp
- z
}

pub struct FormalLogupAtRow {
pub interaction: usize,
pub total_sum: Expr,
pub claimed_sum: Option<(Expr, usize)>,
pub prev_col_cumsum: Expr,
pub cur_frac: Option<Fraction<Expr, Expr>>,
pub is_finalized: bool,
pub is_first: Expr,
pub log_size: u32,
}

// (1 << 31) - 1 is an offset no column can reach, it signifies the variable
// offset, which is an input to the verifier.
const CLAIMED_SUM_DUMMY_OFFSET: usize = (1 << 31) - 1;

impl FormalLogupAtRow {
pub fn new(interaction: usize, has_partial_sum: bool, log_size: u32) -> Self {
let total_sum_name = "total_sum".to_string();
let claimed_sum_name = "claimed_sum".to_string();

Self {
interaction,
// TODO(alont): Should these be Expr::SecureField?
total_sum: Expr::Param(total_sum_name),
claimed_sum: has_partial_sum
.then_some((Expr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)),
prev_col_cumsum: Expr::zero(),
cur_frac: None,
is_finalized: true,
is_first: Expr::zero(),
log_size,
}
}
}

/// An Evaluator that saves all constraint expressions.
pub struct ExprEvaluator {
pub cur_var_index: usize,
pub constraints: Vec<Expr>,
pub logup: LogupAtRow<Self>,
pub logup: FormalLogupAtRow,
}

impl ExprEvaluator {
#[allow(dead_code)]
pub fn new(log_size: u32, logup_sums: LogupSums) -> Self {
pub fn new(log_size: u32, has_partial_sum: bool) -> Self {
Self {
cur_var_index: Default::default(),
constraints: Default::default(),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, has_partial_sum, log_size),
}
}
}
Expand Down Expand Up @@ -276,20 +315,19 @@ impl EvalAtRow for ExprEvaluator {

#[cfg(test)]
mod tests {
use num_traits::{One, Zero};
use num_traits::One;

use crate::constraint_framework::expr::{ColumnExpr, Expr, ExprEvaluator};
use crate::constraint_framework::{
relation, EvalAtRow, FrameworkEval, RelationEntry, ORIGINAL_TRACE_IDX,
};
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;

#[test]
fn test_expr_eval() {
let test_struct = TestStruct {};
let eval = test_struct.evaluate(ExprEvaluator::new(16, (SecureField::zero(), None)));
let eval = test_struct.evaluate(ExprEvaluator::new(16, false));
assert_eq!(eval.constraints.len(), 2);
assert_eq!(
eval.constraints[0],
Expand Down Expand Up @@ -390,12 +428,7 @@ mod tests {
idx: 3,
offset: 0
})),
Box::new(Expr::SecureCol([
Box::new(Expr::Const(M31(0))),
Box::new(Expr::Const(M31(0))),
Box::new(Expr::Const(M31(0))),
Box::new(Expr::Const(M31(0)))
]))
Box::new(Expr::Param("total_sum".into()))
))
))
)),
Expand Down Expand Up @@ -447,7 +480,7 @@ mod tests {
#[test]
fn test_format_expr() {
let test_struct = TestStruct {};
let eval = test_struct.evaluate(ExprEvaluator::new(16, (SecureField::zero(), None)));
let eval = test_struct.evaluate(ExprEvaluator::new(16, false));
let constraint0_str = "(1) * ((((col_1_0[0]) * (col_1_1[0])) * (col_1_2[0])) * (1/(col_1_0[0] + col_1_1[0])))";
assert_eq!(eval.constraints[0].format_expr(), constraint0_str);
let constraint1_str = "(1) \
Expand All @@ -457,7 +490,7 @@ mod tests {
col_2_7[-1], \
col_2_9[-1], \
col_2_11[-1]\
) - ((col_0_3[0]) * (SecureCol(0, 0, 0, 0)))) \
) - ((col_0_3[0]) * (total_sum))) \
- (0)) \
* (0 + (TestRelation_alpha0) * (col_1_0[0]) \
+ (TestRelation_alpha1) * (col_1_1[0]) \
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ macro_rules! logup_proxy {

// TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted
// offset from the is_first column when constant columns are supported.
let (cur_cumsum, prev_row_cumsum) = match self.logup.claimed_sum {
let (cur_cumsum, prev_row_cumsum) = match self.logup.claimed_sum.clone() {
Some((claimed_sum, claimed_row_index)) => {
let [cur_cumsum, prev_row_cumsum, claimed_cumsum] = self
.next_extension_interaction_mask(
Expand All @@ -194,7 +194,7 @@ macro_rules! logup_proxy {
};
// Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row.
let fixed_prev_row_cumsum =
prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum;
prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum.clone();
let diff = cur_cumsum - fixed_prev_row_cumsum - self.logup.prev_col_cumsum.clone();

self.add_constraint(diff * frac.denominator - frac.numerator);
Expand Down
15 changes: 8 additions & 7 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,21 +292,22 @@ mod tests {
(total_sum, Some((total_sum, (1 << log_n_rows) - 1))),
);

let eval = component.evaluate(ExprEvaluator::new(
log_n_rows,
(total_sum, Some((total_sum, (1 << log_n_rows) - 1))),
));
let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true));

assert_eq!(eval.constraints.len(), 2);
let constraint0_str = "(1) \
* ((SecureCol(col_2_5[255], col_2_8[255], col_2_11[255], col_2_14[255]) \
- (SecureCol(223732908, 22408442, 1020999916, 2109866192))) \
* ((SecureCol(\
col_2_5[claimed_sum_offset], \
col_2_8[claimed_sum_offset], \
col_2_11[claimed_sum_offset], \
col_2_14[claimed_sum_offset]\
) - (claimed_sum)) \
* (col_0_2[0]))";
assert_eq!(eval.constraints[0].format_expr(), constraint0_str);
let constraint1_str = "(1) \
* ((SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \
- (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \
- ((col_0_2[0]) * (SecureCol(223732908, 22408442, 1020999916, 2109866192)))) \
- ((col_0_2[0]) * (total_sum))) \
- (0)) \
* ((0 \
+ (StateMachineElements_alpha0) * (col_1_0[0]) \
Expand Down

0 comments on commit 57cde32

Please sign in to comment.