Skip to content

Commit

Permalink
Exported shared code to temp variables in ExprEvaluator.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Nov 19, 2024
1 parent 8750fad commit 22d6538
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 211 deletions.
240 changes: 62 additions & 178 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl Expr {
Expr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()),
Expr::Mul(a, b) => format!("({}) * ({})", a.format_expr(), b.format_expr()),
Expr::Neg(a) => format!("-({})", a.format_expr()),
Expr::Inv(a) => format!("1/({})", a.format_expr()),
Expr::Inv(a) => format!("1 / ({})", a.format_expr()),
}
}
}
Expand Down Expand Up @@ -241,6 +241,8 @@ pub struct ExprEvaluator {
pub cur_var_index: usize,
pub constraints: Vec<Expr>,
pub logup: FormalLogupAtRow,
pub temp_vars: Vec<(String, Expr)>,
cur_temp_var_index: usize,
}

impl ExprEvaluator {
Expand All @@ -250,8 +252,37 @@ impl ExprEvaluator {
cur_var_index: Default::default(),
constraints: Default::default(),
logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, has_partial_sum, log_size),
temp_vars: vec![],
cur_temp_var_index: 0,
}
}

pub fn add_temp_var(&mut self, expr: Expr) -> Expr {
let name = format!("temp_{}", self.cur_temp_var_index);
let temp_var = Expr::Param(name.clone());
self.temp_vars.push((name, expr));
self.cur_temp_var_index += 1;
temp_var
}

pub fn format_constraints(&self) -> String {
let lets_string = self
.temp_vars
.iter()
.map(|(name, expr)| format!("let {} = {};", name, expr.format_expr()))
.collect::<Vec<String>>()
.join("\n");

let constraints_str = self
.constraints
.iter()
.enumerate()
.map(|(i, c)| format!("let constraint_{i} = ") + &c.format_expr() + ";")
.collect::<Vec<String>>()
.join("\n\n");

lets_string + "\n\n" + &constraints_str
}
}

impl EvalAtRow for ExprEvaluator {
Expand Down Expand Up @@ -279,7 +310,12 @@ impl EvalAtRow for ExprEvaluator {
where
Self::EF: std::ops::Mul<G, Output = Self::EF>,
{
self.constraints.push(Expr::one() * constraint);
if let Expr::Mul(one, constraint) = Expr::one() * constraint {
assert_eq!(*one, Expr::one());
self.constraints.push(*constraint);
} else {
unreachable!();
}
}

fn combine_ef(values: [Self::F; 4]) -> Self::EF {
Expand All @@ -303,7 +339,8 @@ impl EvalAtRow for ExprEvaluator {
multiplicity,
values,
}| {
Fraction::new(multiplicity.clone(), combine_formal(*relation, values))
let temp_var = self.add_temp_var(combine_formal(*relation, values));
Fraction::new(multiplicity.clone(), temp_var)
},
)
.collect();
Expand All @@ -317,187 +354,34 @@ impl EvalAtRow for ExprEvaluator {
mod tests {
use num_traits::One;

use crate::constraint_framework::expr::{ColumnExpr, Expr, ExprEvaluator};
use crate::constraint_framework::{
relation, EvalAtRow, FrameworkEval, RelationEntry, ORIGINAL_TRACE_IDX,
};
use crate::core::fields::m31::M31;
use crate::constraint_framework::expr::ExprEvaluator;
use crate::constraint_framework::{relation, EvalAtRow, FrameworkEval, RelationEntry};
use crate::core::fields::FieldExpOps;

#[test]
fn test_expr_eval() {
let test_struct = TestStruct {};
let eval = test_struct.evaluate(ExprEvaluator::new(16, false));
assert_eq!(eval.constraints.len(), 2);
assert_eq!(
eval.constraints[0],
Expr::Mul(
Box::new(Expr::one()),
Box::new(Expr::Mul(
Box::new(Expr::Mul(
Box::new(Expr::Mul(
Box::new(Expr::Col(ColumnExpr {
interaction: ORIGINAL_TRACE_IDX,
idx: 0,
offset: 0
})),
Box::new(Expr::Col(ColumnExpr {
interaction: ORIGINAL_TRACE_IDX,
idx: 1,
offset: 0
}))
)),
Box::new(Expr::Col(ColumnExpr {
interaction: ORIGINAL_TRACE_IDX,
idx: 2,
offset: 0
}))
)),
Box::new(Expr::Inv(Box::new(Expr::Add(
Box::new(Expr::Col(ColumnExpr {
interaction: ORIGINAL_TRACE_IDX,
idx: 0,
offset: 0
})),
Box::new(Expr::Col(ColumnExpr {
interaction: ORIGINAL_TRACE_IDX,
idx: 1,
offset: 0
}))
))))
))
)
);

assert_eq!(
eval.constraints[1],
Expr::Mul(
Box::new(Expr::Const(M31(1))),
Box::new(Expr::Sub(
Box::new(Expr::Mul(
Box::new(Expr::Sub(
Box::new(Expr::Sub(
Box::new(Expr::SecureCol([
Box::new(Expr::Col(ColumnExpr {
interaction: 2,
idx: 4,
offset: 0
})),
Box::new(Expr::Col(ColumnExpr {
interaction: 2,
idx: 6,
offset: 0
})),
Box::new(Expr::Col(ColumnExpr {
interaction: 2,
idx: 8,
offset: 0
})),
Box::new(Expr::Col(ColumnExpr {
interaction: 2,
idx: 10,
offset: 0
}))
])),
Box::new(Expr::Sub(
Box::new(Expr::SecureCol([
Box::new(Expr::Col(ColumnExpr {
interaction: 2,
idx: 5,
offset: -1
})),
Box::new(Expr::Col(ColumnExpr {
interaction: 2,
idx: 7,
offset: -1
})),
Box::new(Expr::Col(ColumnExpr {
interaction: 2,
idx: 9,
offset: -1
})),
Box::new(Expr::Col(ColumnExpr {
interaction: 2,
idx: 11,
offset: -1
}))
])),
Box::new(Expr::Mul(
Box::new(Expr::Col(ColumnExpr {
interaction: 0,
idx: 3,
offset: 0
})),
Box::new(Expr::Param("total_sum".into()))
))
))
)),
Box::new(Expr::Const(M31(0)))
)),
Box::new(Expr::Sub(
Box::new(Expr::Add(
Box::new(Expr::Add(
Box::new(Expr::Add(
Box::new(Expr::Const(M31(0))),
Box::new(Expr::Mul(
Box::new(Expr::Param(
"TestRelation_alpha0".to_string()
)),
Box::new(Expr::Col(ColumnExpr {
interaction: 1,
idx: 0,
offset: 0
}))
))
)),
Box::new(Expr::Mul(
Box::new(Expr::Param("TestRelation_alpha1".to_string())),
Box::new(Expr::Col(ColumnExpr {
interaction: 1,
idx: 1,
offset: 0
}))
))
)),
Box::new(Expr::Mul(
Box::new(Expr::Param("TestRelation_alpha2".to_string())),
Box::new(Expr::Col(ColumnExpr {
interaction: 1,
idx: 2,
offset: 0
}))
))
)),
Box::new(Expr::Param("TestRelation_z".to_string()))
))
)),
Box::new(Expr::Const(M31(1)))
))
)
);
}

#[test]
fn test_format_expr() {
let test_struct = TestStruct {};
let eval = test_struct.evaluate(ExprEvaluator::new(16, false));
let constraint0_str = "(1) * ((((col_1_0[0]) * (col_1_1[0])) * (col_1_2[0])) * (1/(col_1_0[0] + col_1_1[0])))";
assert_eq!(eval.constraints[0].format_expr(), constraint0_str);
let constraint1_str = "(1) \
* ((SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \
- (SecureCol(\
col_2_5[-1], \
col_2_7[-1], \
col_2_9[-1], \
col_2_11[-1]\
) - ((col_0_3[0]) * (total_sum))) \
- (0)) \
* (0 + (TestRelation_alpha0) * (col_1_0[0]) \
+ (TestRelation_alpha1) * (col_1_1[0]) \
+ (TestRelation_alpha2) * (col_1_2[0]) \
- (TestRelation_z)) \
- (1))";
assert_eq!(eval.constraints[1].format_expr(), constraint1_str);
let expected = "let temp_0 = 0 \
+ (TestRelation_alpha0) * (col_1_0[0]) \
+ (TestRelation_alpha1) * (col_1_1[0]) \
+ (TestRelation_alpha2) * (col_1_2[0]) \
- (TestRelation_z);
\
let constraint_0 = \
(((col_1_0[0]) * (col_1_1[0])) * (col_1_2[0])) * (1 / (col_1_0[0] + col_1_1[0]));
\
let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \
- (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \
- ((col_0_3[0]) * (total_sum))) \
- (0)) \
* (temp_0) \
- (1);"
.to_string();

assert_eq!(eval.format_constraints(), expected);
}

relation!(TestRelation, 3);
Expand Down
62 changes: 29 additions & 33 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,38 +293,34 @@ mod tests {
);

let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true));

assert_eq!(eval.constraints.len(), 2);
let constraint0_str = "(1) \
* ((SecureCol(\
col_2_5[claimed_sum_offset], \
col_2_8[claimed_sum_offset], \
col_2_11[claimed_sum_offset], \
col_2_14[claimed_sum_offset]\
) - (claimed_sum)) \
* (col_0_2[0]))";
assert_eq!(eval.constraints[0].format_expr(), constraint0_str);
let constraint1_str = "(1) \
* ((SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \
- (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \
- ((col_0_2[0]) * (total_sum))) \
- (0)) \
* ((0 \
+ (StateMachineElements_alpha0) * (col_1_0[0]) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
- (StateMachineElements_z)) \
* (0 + (StateMachineElements_alpha0) * (col_1_0[0] + 1) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
- (StateMachineElements_z))) \
- ((0 \
+ (StateMachineElements_alpha0) * (col_1_0[0] + 1) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
- (StateMachineElements_z)) \
* (1) \
+ (0 + (StateMachineElements_alpha0) * (col_1_0[0]) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
- (StateMachineElements_z)) \
* (-(1))))";
assert_eq!(eval.constraints[1].format_expr(), constraint1_str);
let expected = "let temp_0 = 0 \
+ (StateMachineElements_alpha0) * (col_1_0[0]) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
- (StateMachineElements_z);
\
let temp_1 = 0 \
+ (StateMachineElements_alpha0) * (col_1_0[0] + 1) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
- (StateMachineElements_z);
\
let constraint_0 = (SecureCol(\
col_2_5[claimed_sum_offset], \
col_2_8[claimed_sum_offset], \
col_2_11[claimed_sum_offset], \
col_2_14[claimed_sum_offset]\
) - (claimed_sum)) \
* (col_0_2[0]);
\
let constraint_1 = (SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \
- (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \
- ((col_0_2[0]) * (total_sum))) \
- (0)) \
* ((temp_0) * (temp_1)) \
- ((temp_1) * (1) + (temp_0) * (-(1)));"
.to_string();

assert_eq!(eval.format_constraints(), expected);
}
}

0 comments on commit 22d6538

Please sign in to comment.