Skip to content

Commit

Permalink
Added LogupAtRow functionality to EvalAtRow. (#875)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti authored Nov 17, 2024
1 parent 884d161 commit 7d3591f
Show file tree
Hide file tree
Showing 17 changed files with 189 additions and 120 deletions.
6 changes: 6 additions & 0 deletions crates/prover/src/constraint_framework/assert.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use num_traits::{One, Zero};

use super::logup::LogupAtRow;
use super::EvalAtRow;
use crate::core::backend::{Backend, Column};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CirclePoly};
use crate::core::utils::circle_domain_order_to_coset_order;
Expand All @@ -14,13 +16,15 @@ pub struct AssertEvaluator<'a> {
pub trace: &'a TreeVec<Vec<Vec<BaseField>>>,
pub col_index: TreeVec<usize>,
pub row: usize,
pub logup: LogupAtRow<Self>,
}
impl<'a> AssertEvaluator<'a> {
pub fn new(trace: &'a TreeVec<Vec<Vec<BaseField>>>, row: usize) -> Self {
Self {
trace,
col_index: TreeVec::new(vec![0; trace.len()]),
row,
logup: LogupAtRow::dummy(),
}
}
}
Expand Down Expand Up @@ -57,6 +61,8 @@ impl<'a> EvalAtRow for AssertEvaluator<'a> {
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
SecureField::from_m31_array(values)
}

super::logup_proxy!();
}

pub fn assert_constraints<B: Backend>(
Expand Down
6 changes: 6 additions & 0 deletions crates/prover/src/constraint_framework/cpu_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ use std::ops::Mul;

use num_traits::Zero;

use super::logup::LogupAtRow;
use super::EvalAtRow;
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
Expand All @@ -22,6 +24,7 @@ pub struct CpuDomainEvaluator<'a> {
pub constraint_index: usize,
pub domain_log_size: u32,
pub eval_domain_log_size: u32,
pub logup: LogupAtRow<Self>,
}

impl<'a> CpuDomainEvaluator<'a> {
Expand All @@ -42,6 +45,7 @@ impl<'a> CpuDomainEvaluator<'a> {
constraint_index: 0,
domain_log_size,
eval_domain_log_size: eval_log_size,
logup: LogupAtRow::dummy(),
}
}
}
Expand Down Expand Up @@ -88,4 +92,6 @@ impl<'a> EvalAtRow for CpuDomainEvaluator<'a> {
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
SecureField::from_m31_array(values)
}

super::logup_proxy!();
}
9 changes: 7 additions & 2 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};

use num_traits::{One, Zero};

use super::logup::LogupAtRow;
use super::EvalAtRow;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

/// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`.
#[derive(Clone, Debug, PartialEq)]
Expand Down Expand Up @@ -149,8 +151,9 @@ impl AddAssign<BaseField> for Expr {
/// An Evaluator that saves all constraint expressions.
#[derive(Default)]
struct ExprEvaluator {
cur_var_index: usize,
constraints: Vec<Expr>,
pub cur_var_index: usize,
pub constraints: Vec<Expr>,
pub logup: LogupAtRow<Self>,
}

impl EvalAtRow for ExprEvaluator {
Expand Down Expand Up @@ -189,6 +192,8 @@ impl EvalAtRow for ExprEvaluator {
Box::new(values[3].clone()),
])
}

super::logup_proxy!();
}

#[cfg(test)]
Expand Down
5 changes: 5 additions & 0 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ use std::ops::Mul;

use num_traits::One;

use super::logup::LogupAtRow;
use super::preprocessed_columns::PreprocessedColumn;
use super::EvalAtRow;
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;

/// Collects information about the constraints.
Expand All @@ -16,6 +18,7 @@ pub struct InfoEvaluator {
pub mask_offsets: TreeVec<Vec<Vec<isize>>>,
pub n_constraints: usize,
pub preprocessed_columns: Vec<PreprocessedColumn>,
pub logup: LogupAtRow<Self>,
}
impl InfoEvaluator {
pub fn new() -> Self {
Expand Down Expand Up @@ -60,4 +63,6 @@ impl EvalAtRow for InfoEvaluator {
fn combine_ef(_values: [Self::F; 4]) -> Self::EF {
SecureField::one()
}

super::logup_proxy!();
}
21 changes: 20 additions & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ pub struct LogupAtRow<E: EvalAtRow> {
/// See [`super::preprocessed_columns::gen_is_first()`].
pub is_first: E::F,
}

impl<E: EvalAtRow> Default for LogupAtRow<E> {
fn default() -> Self {
Self::dummy()
}
}
impl<E: EvalAtRow> LogupAtRow<E> {
pub fn new(
interaction: usize,
Expand All @@ -62,6 +68,19 @@ impl<E: EvalAtRow> LogupAtRow<E> {
}
}

// TODO(alont): Remove this once unnecessary LogupAtRows are gone.
pub fn dummy() -> Self {
Self {
interaction: 100,
total_sum: SecureField::one(),
claimed_sum: None,
prev_col_cumsum: E::EF::zero(),
cur_frac: None,
is_finalized: true,
is_first: E::F::zero(),
}
}

pub fn write_frac(&mut self, eval: &mut E, fraction: Fraction<E::EF, E::EF>) {
// Add a constraint that num / denom = diff.
if let Some(cur_frac) = self.cur_frac.clone() {
Expand All @@ -73,7 +92,7 @@ impl<E: EvalAtRow> LogupAtRow<E> {
self.cur_frac = Some(fraction);
}

pub fn finalize(mut self, eval: &mut E) {
pub fn finalize(&mut self, eval: &mut E) {
assert!(!self.is_finalized, "LogupAtRow was already finalized");

let frac = self.cur_frac.clone().unwrap();
Expand Down
56 changes: 56 additions & 0 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

pub const PREPROCESSED_TRACE_IDX: usize = 0;
pub const ORIGINAL_TRACE_IDX: usize = 1;
Expand Down Expand Up @@ -109,4 +110,59 @@ pub trait EvalAtRow {

/// Combines 4 base field values into a single extension field value.
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF;

// TODO(alont): Remove these once LogupAtRow is no longer used.
fn init_logup(
&mut self,
_total_sum: SecureField,
_claimed_sum: Option<crate::constraint_framework::logup::ClaimedPrefixSum>,
_log_size: u32,
) {
unimplemented!()
}
fn write_frac(&mut self, _fraction: Fraction<Self::EF, Self::EF>) {
unimplemented!()
}
fn finalize_logup(&mut self) {
unimplemented!()
}
}

/// Default implementation for evaluators that have an element called "logup" that works like a
/// LogupAtRow, where the logup functionality can be proxied.
/// TODO(alont): Remove once LogupAtRow is no longer used.
macro_rules! logup_proxy {
() => {
fn init_logup(
&mut self,
total_sum: SecureField,
claimed_sum: Option<crate::constraint_framework::logup::ClaimedPrefixSum>,
log_size: u32,
) {
let is_first = self.get_preprocessed_column(
crate::constraint_framework::preprocessed_columns::PreprocessedColumn::IsFirst(
log_size,
),
);
self.logup = crate::constraint_framework::logup::LogupAtRow::new(
crate::constraint_framework::INTERACTION_TRACE_IDX,
total_sum,
claimed_sum,
is_first,
);
}

fn write_frac(&mut self, fraction: Fraction<Self::EF, Self::EF>) {
let mut logup = std::mem::take(&mut self.logup);
logup.write_frac(self, fraction);
self.logup = logup;
}

fn finalize_logup(&mut self) {
let mut logup = std::mem::take(&mut self.logup);
logup.finalize(self);
self.logup = logup;
}
};
}
pub(crate) use logup_proxy;
6 changes: 6 additions & 0 deletions crates/prover/src/constraint_framework/point.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::ops::Mul;

use super::logup::LogupAtRow;
use super::EvalAtRow;
use crate::core::air::accumulation::PointEvaluationAccumulator;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;
use crate::core::ColumnVec;

Expand All @@ -13,6 +15,7 @@ pub struct PointEvaluator<'a> {
pub evaluation_accumulator: &'a mut PointEvaluationAccumulator,
pub col_index: Vec<usize>,
pub denom_inverse: SecureField,
pub logup: LogupAtRow<Self>,
}
impl<'a> PointEvaluator<'a> {
pub fn new(
Expand All @@ -26,6 +29,7 @@ impl<'a> PointEvaluator<'a> {
evaluation_accumulator,
col_index,
denom_inverse,
logup: LogupAtRow::dummy(),
}
}
}
Expand Down Expand Up @@ -54,4 +58,6 @@ impl<'a> EvalAtRow for PointEvaluator<'a> {
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
SecureField::from_partial_evals(values)
}

super::logup_proxy!();
}
6 changes: 6 additions & 0 deletions crates/prover/src/constraint_framework/simd_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::ops::Mul;

use num_traits::Zero;

use super::logup::LogupAtRow;
use super::EvalAtRow;
use crate::core::backend::simd::column::VeryPackedBaseColumn;
use crate::core::backend::simd::m31::LOG_N_LANES;
Expand All @@ -13,6 +14,7 @@ use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
Expand All @@ -30,6 +32,7 @@ pub struct SimdDomainEvaluator<'a> {
pub constraint_index: usize,
pub domain_log_size: u32,
pub eval_domain_log_size: u32,
pub logup: LogupAtRow<Self>,
}
impl<'a> SimdDomainEvaluator<'a> {
pub fn new(
Expand All @@ -48,6 +51,7 @@ impl<'a> SimdDomainEvaluator<'a> {
constraint_index: 0,
domain_log_size,
eval_domain_log_size: eval_log_size,
logup: LogupAtRow::dummy(),
}
}
}
Expand Down Expand Up @@ -103,4 +107,6 @@ impl<'a> EvalAtRow for SimdDomainEvaluator<'a> {
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
VeryPackedSecureField::from_very_packed_m31s(values)
}

super::logup_proxy!();
}
35 changes: 17 additions & 18 deletions crates/prover/src/examples/blake/round/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use itertools::{chain, Itertools};
use num_traits::One;

use super::{BlakeXorElements, RoundElements};
use crate::constraint_framework::logup::LogupAtRow;
use crate::constraint_framework::EvalAtRow;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::{Fraction, Reciprocal};
use crate::examples::blake::{Fu32, STATE_SIZE};

Expand All @@ -15,10 +15,12 @@ pub struct BlakeRoundEval<'a, E: EvalAtRow> {
pub eval: E,
pub xor_lookup_elements: &'a BlakeXorElements,
pub round_lookup_elements: &'a RoundElements,
pub logup: LogupAtRow<E>,
pub total_sum: SecureField,
pub log_size: u32,
}
impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
pub fn eval(mut self) -> E {
self.eval.init_logup(self.total_sum, None, self.log_size);
let mut v: [Fu32<E::F>; STATE_SIZE] = std::array::from_fn(|_| self.next_u32());
let input_v = v.clone();
let m: [Fu32<E::F>; STATE_SIZE] = std::array::from_fn(|_| self.next_u32());
Expand Down Expand Up @@ -65,22 +67,19 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
);

// Yield `Round(input_v, output_v, message)`.
self.logup.write_frac(
&mut self.eval,
Fraction::new(
-E::EF::one(),
self.round_lookup_elements.combine(
&chain![
input_v.iter().cloned().flat_map(Fu32::to_felts),
v.iter().cloned().flat_map(Fu32::to_felts),
m.iter().cloned().flat_map(Fu32::to_felts)
]
.collect_vec(),
),
self.eval.write_frac(Fraction::new(
-E::EF::one(),
self.round_lookup_elements.combine(
&chain![
input_v.iter().cloned().flat_map(Fu32::to_felts),
v.iter().cloned().flat_map(Fu32::to_felts),
m.iter().cloned().flat_map(Fu32::to_felts)
]
.collect_vec(),
),
);
));

self.logup.finalize(&mut self.eval);
self.eval.finalize_logup();
self.eval
}
fn next_u32(&mut self) -> Fu32<E::F> {
Expand Down Expand Up @@ -197,9 +196,9 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
lookup_elements.combine::<E::F, E::EF>(&[a[0].clone(), b[0].clone(), c[0].clone()]);
let comb1 =
lookup_elements.combine::<E::F, E::EF>(&[a[1].clone(), b[1].clone(), c[1].clone()]);
let frac = Reciprocal::new(comb0) + Reciprocal::new(comb1);

self.logup.write_frac(&mut self.eval, frac);
self.eval
.write_frac(Reciprocal::new(comb0) + Reciprocal::new(comb1));
c
}
}
Loading

0 comments on commit 7d3591f

Please sign in to comment.