Skip to content

Commit

Permalink
Simplify expressions.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Nov 19, 2024
1 parent 22d6538 commit 1091abc
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 13 deletions.
97 changes: 91 additions & 6 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,90 @@ impl AddAssign<BaseField> for Expr {
}
}

pub fn simplify(expr: Expr) -> Expr {
match expr {
Expr::Add(a, b) => {
let a = simplify(*a);
let b = simplify(*b);
if let (Expr::Const(a), Expr::Const(b)) = (a.clone(), b.clone()) {
Expr::Const(a + b)
} else if a == Expr::zero() {
b
} else if b == Expr::zero() {
a
} else if let Expr::Neg(a) = a {
if let Expr::Neg(b) = b {
-(*a + *b)
} else {
b - *a
}
} else if let Expr::Neg(b) = b {
a - *b
} else {
a + b
}
}
Expr::Sub(a, b) => {
let a = simplify(*a);
let b = simplify(*b);
if a == Expr::zero() {
-b
} else if b == Expr::zero() {
a
} else if a == b {
Expr::zero()
} else {
a - b
}
}
Expr::Mul(a, b) => {
let a = simplify(*a);
let b = simplify(*b);
if let (Expr::Const(a), Expr::Const(b)) = (a.clone(), b.clone()) {
Expr::Const(a - b)
} else if a == Expr::zero() || b == Expr::zero() {
Expr::zero()
} else if a == Expr::one() {
b
} else if b == Expr::one() {
a
} else if a == -Expr::one() {
-b
} else if b == -Expr::one() {
-a
} else {
a * b
}
}
Expr::Col(colexpr) => Expr::Col(colexpr),
Expr::SecureCol([a, b, c, d]) => Expr::SecureCol([
Box::new(simplify(*a)),
Box::new(simplify(*b)),
Box::new(simplify(*c)),
Box::new(simplify(*d)),
]),
Expr::Const(c) => Expr::Const(c),
Expr::Param(x) => Expr::Param(x),
Expr::Neg(a) => {
let a = simplify(*a);
match a {
Expr::Neg(b) => *b,
Expr::Const(c) => Expr::Const(-c),
Expr::Sub(a, b) => Expr::Sub(b, a),
_ => -a,
}
}
Expr::Inv(a) => {
let a = simplify(*a);
match a {
Expr::Inv(b) => *b,
Expr::Const(c) => Expr::Const(c.inverse()),
_ => Expr::Inv(Box::new(a)),
}
}
}
}

/// Returns the expression
/// `value[0] * <relation>_alpha0 + value[1] * <relation>_alpha1 + ... - <relation>_z.`
fn combine_formal<R: Relation<Expr, Expr>>(relation: &R, values: &[Expr]) -> Expr {
Expand Down Expand Up @@ -269,15 +353,17 @@ impl ExprEvaluator {
let lets_string = self
.temp_vars
.iter()
.map(|(name, expr)| format!("let {} = {};", name, expr.format_expr()))
.map(|(name, expr)| format!("let {} = {};", name, simplify(expr.clone()).format_expr()))
.collect::<Vec<String>>()
.join("\n");

let constraints_str = self
.constraints
.iter()
.enumerate()
.map(|(i, c)| format!("let constraint_{i} = ") + &c.format_expr() + ";")
.map(|(i, c)| {
format!("let constraint_{i} = ") + &simplify(c.clone()).format_expr() + ";"
})
.collect::<Vec<String>>()
.join("\n\n");

Expand Down Expand Up @@ -362,8 +448,7 @@ mod tests {
fn test_format_expr() {
let test_struct = TestStruct {};
let eval = test_struct.evaluate(ExprEvaluator::new(16, false));
let expected = "let temp_0 = 0 \
+ (TestRelation_alpha0) * (col_1_0[0]) \
let expected = "let temp_0 = (TestRelation_alpha0) * (col_1_0[0]) \
+ (TestRelation_alpha1) * (col_1_1[0]) \
+ (TestRelation_alpha2) * (col_1_2[0]) \
- (TestRelation_z);
Expand All @@ -375,8 +460,8 @@ mod tests {
\
let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \
- (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \
- ((col_0_3[0]) * (total_sum))) \
- (0)) \
- ((col_0_3[0]) * (total_sum)))\
) \
* (temp_0) \
- (1);"
.to_string();
Expand Down
12 changes: 5 additions & 7 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,11 @@ mod tests {
);

let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true));
let expected = "let temp_0 = 0 \
+ (StateMachineElements_alpha0) * (col_1_0[0]) \
let expected = "let temp_0 = (StateMachineElements_alpha0) * (col_1_0[0]) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
- (StateMachineElements_z);
\
let temp_1 = 0 \
+ (StateMachineElements_alpha0) * (col_1_0[0] + 1) \
let temp_1 = (StateMachineElements_alpha0) * (col_1_0[0] + 1) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
- (StateMachineElements_z);
Expand All @@ -315,10 +313,10 @@ mod tests {
\
let constraint_1 = (SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \
- (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \
- ((col_0_2[0]) * (total_sum))) \
- (0)) \
- ((col_0_2[0]) * (total_sum)))\
) \
* ((temp_0) * (temp_1)) \
- ((temp_1) * (1) + (temp_0) * (-(1)));"
- (temp_1 + (temp_0) * (2147483646));"
.to_string();

assert_eq!(eval.format_constraints(), expected);
Expand Down

0 comments on commit 1091abc

Please sign in to comment.