Skip to content

Commit

Permalink
Initialize logups automatically when first used.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Nov 17, 2024
1 parent 85bc3d8 commit 92fe59a
Show file tree
Hide file tree
Showing 23 changed files with 305 additions and 228 deletions.
17 changes: 12 additions & 5 deletions crates/prover/src/constraint_framework/assert.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use num_traits::{One, Zero};

use super::logup::LogupAtRow;
use super::EvalAtRow;
use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::core::backend::{Backend, Column};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand All @@ -19,12 +19,17 @@ pub struct AssertEvaluator<'a> {
pub logup: LogupAtRow<Self>,
}
impl<'a> AssertEvaluator<'a> {
pub fn new(trace: &'a TreeVec<Vec<Vec<BaseField>>>, row: usize) -> Self {
pub fn new(
trace: &'a TreeVec<Vec<Vec<BaseField>>>,
row: usize,
log_size: u32,
logup_sums: LogupSums,
) -> Self {
Self {
trace,
col_index: TreeVec::new(vec![0; trace.len()]),
row,
logup: LogupAtRow::dummy(),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
}
}
}
Expand Down Expand Up @@ -69,6 +74,7 @@ pub fn assert_constraints<B: Backend>(
trace_polys: &TreeVec<Vec<CirclePoly<B>>>,
trace_domain: CanonicCoset,
assert_func: impl Fn(AssertEvaluator<'_>),
logup_sums: LogupSums,
) {
let traces = trace_polys.as_ref().map(|tree| {
tree.iter()
Expand All @@ -84,7 +90,8 @@ pub fn assert_constraints<B: Backend>(
.collect()
});
for row in 0..trace_domain.size() {
let eval = AssertEvaluator::new(&traces, row);
let eval = AssertEvaluator::new(&traces, row, trace_domain.log_size(), logup_sums);

assert_func(eval);
}
}
17 changes: 15 additions & 2 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use rayon::prelude::*;
use tracing::{span, Level};

use super::cpu_domain::CpuDomainEvaluator;
use super::logup::LogupSums;
use super::preprocessed_columns::PreprocessedColumn;
use super::{
EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX,
Expand Down Expand Up @@ -113,11 +114,16 @@ pub struct FrameworkComponent<C: FrameworkEval> {
trace_locations: TreeVec<TreeSubspan>,
info: InfoEvaluator,
preprocessed_column_indices: Vec<usize>,
logup_sums: LogupSums,
}

impl<E: FrameworkEval> FrameworkComponent<E> {
pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self {
let info = eval.evaluate(InfoEvaluator::default());
pub fn new(
location_allocator: &mut TraceLocationAllocator,
eval: E,
logup_sums: LogupSums,
) -> Self {
let info = eval.evaluate(InfoEvaluator::new(eval.log_size(), vec![], logup_sums));
let trace_locations = location_allocator.next_for_structure(&info.mask_offsets);

let preprocessed_column_indices = info
Expand Down Expand Up @@ -148,6 +154,7 @@ impl<E: FrameworkEval> FrameworkComponent<E> {
trace_locations,
info,
preprocessed_column_indices,
logup_sums,
}
}

Expand Down Expand Up @@ -217,6 +224,8 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
mask_points,
evaluation_accumulator,
coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(),
self.eval.log_size(),
self.logup_sums,
));
}
}
Expand Down Expand Up @@ -296,6 +305,8 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
&accum.random_coeff_powers,
trace_domain.log_size(),
eval_domain.log_size(),
self.eval.log_size(),
self.logup_sums,
);
let row_res = self.eval.evaluate(eval).row_res;

Expand Down Expand Up @@ -333,6 +344,8 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
&accum.random_coeff_powers,
trace_domain.log_size(),
eval_domain.log_size(),
self.eval.log_size(),
self.logup_sums,
);
let row_res = self.eval.evaluate(eval).row_res;

Expand Down
8 changes: 5 additions & 3 deletions crates/prover/src/constraint_framework/cpu_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::ops::Mul;

use num_traits::Zero;

use super::logup::LogupAtRow;
use super::EvalAtRow;
use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand Down Expand Up @@ -35,6 +35,8 @@ impl<'a> CpuDomainEvaluator<'a> {
random_coeff_powers: &'a [SecureField],
domain_log_size: u32,
eval_log_size: u32,
log_size: u32,
logup_sums: LogupSums,
) -> Self {
Self {
trace_eval,
Expand All @@ -45,7 +47,7 @@ impl<'a> CpuDomainEvaluator<'a> {
constraint_index: 0,
domain_log_size,
eval_domain_log_size: eval_log_size,
logup: LogupAtRow::dummy(),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
}
}
}
Expand Down
14 changes: 12 additions & 2 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};

use num_traits::{One, Zero};

use super::logup::LogupAtRow;
use super::EvalAtRow;
use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
Expand Down Expand Up @@ -156,6 +156,16 @@ struct ExprEvaluator {
pub logup: LogupAtRow<Self>,
}

impl ExprEvaluator {
pub fn _new(log_size: u32, logup_sums: LogupSums) -> Self {
Self {
cur_var_index: Default::default(),
constraints: Default::default(),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
}
}
}

impl EvalAtRow for ExprEvaluator {
// TODO(alont): Should there be a version of this that disallows Secure fields for F?
type F = Expr;
Expand Down
23 changes: 19 additions & 4 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use std::ops::Mul;

use num_traits::One;

use super::logup::LogupAtRow;
use super::logup::{LogupAtRow, LogupSums};
use super::preprocessed_columns::PreprocessedColumn;
use super::EvalAtRow;
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand All @@ -21,8 +21,23 @@ pub struct InfoEvaluator {
pub logup: LogupAtRow<Self>,
}
impl InfoEvaluator {
pub fn new() -> Self {
Self::default()
pub fn new(
log_size: u32,
preprocessed_columns: Vec<PreprocessedColumn>,
logup_sums: LogupSums,
) -> Self {
Self {
mask_offsets: Default::default(),
n_constraints: Default::default(),
preprocessed_columns,
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
}
}

/// Create an empty `InfoEvaluator`, to measure components before their size and logup sums are
/// available.
pub fn empty() -> Self {
Self::new(16, vec![], (SecureField::default(), None))
}
}
impl EvalAtRow for InfoEvaluator {
Expand Down
81 changes: 11 additions & 70 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ use crate::core::ColumnVec;
/// Represents the value of the prefix sum column at some index.
/// Should be used to eliminate padded rows for the logup sum.
pub type ClaimedPrefixSum = (SecureField, usize);
// (total_sum, claimed_sum)
pub type LogupSums = (SecureField, Option<ClaimedPrefixSum>);

/// Evaluates constraints for batched logups.
/// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum.
Expand All @@ -38,11 +40,12 @@ pub struct LogupAtRow<E: EvalAtRow> {
pub claimed_sum: Option<ClaimedPrefixSum>,
/// The evaluation of the last cumulative sum column.
pub prev_col_cumsum: E::EF,
cur_frac: Option<Fraction<E::EF, E::EF>>,
is_finalized: bool,
pub cur_frac: Option<Fraction<E::EF, E::EF>>,
pub is_finalized: bool,
/// The value of the `is_first` constant column at current row.
/// See [`super::preprocessed_columns::gen_is_first()`].
pub is_first: E::F,
pub log_size: u32,
}

impl<E: EvalAtRow> Default for LogupAtRow<E> {
Expand All @@ -55,16 +58,17 @@ impl<E: EvalAtRow> LogupAtRow<E> {
interaction: usize,
total_sum: SecureField,
claimed_sum: Option<ClaimedPrefixSum>,
is_first: E::F,
log_size: u32,
) -> Self {
Self {
interaction,
total_sum,
claimed_sum,
prev_col_cumsum: E::EF::zero(),
cur_frac: None,
is_finalized: false,
is_first,
is_finalized: true,
is_first: E::F::zero(),
log_size,
}
}

Expand All @@ -78,53 +82,9 @@ impl<E: EvalAtRow> LogupAtRow<E> {
cur_frac: None,
is_finalized: true,
is_first: E::F::zero(),
log_size: 10,
}
}

pub fn write_frac(&mut self, eval: &mut E, fraction: Fraction<E::EF, E::EF>) {
// Add a constraint that num / denom = diff.
if let Some(cur_frac) = self.cur_frac.clone() {
let [cur_cumsum] = eval.next_extension_interaction_mask(self.interaction, [0]);
let diff = cur_cumsum.clone() - self.prev_col_cumsum.clone();
self.prev_col_cumsum = cur_cumsum;
eval.add_constraint(diff * cur_frac.denominator - cur_frac.numerator);
}
self.cur_frac = Some(fraction);
}

pub fn finalize(&mut self, eval: &mut E) {
assert!(!self.is_finalized, "LogupAtRow was already finalized");

let frac = self.cur_frac.clone().unwrap();

// TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted offset
// from the is_first column when constant columns are supported.
let (cur_cumsum, prev_row_cumsum) = match self.claimed_sum {
Some((claimed_sum, claimed_row_index)) => {
let [cur_cumsum, prev_row_cumsum, claimed_cumsum] = eval
.next_extension_interaction_mask(
self.interaction,
[0, -1, claimed_row_index as isize],
);

// Constrain that the claimed_sum in case that it is not equal to the total_sum.
eval.add_constraint((claimed_cumsum - claimed_sum) * self.is_first.clone());
(cur_cumsum, prev_row_cumsum)
}
None => {
let [cur_cumsum, prev_row_cumsum] =
eval.next_extension_interaction_mask(self.interaction, [0, -1]);
(cur_cumsum, prev_row_cumsum)
}
};
// Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row.
let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first.clone() * self.total_sum;
let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum.clone();

eval.add_constraint(diff * frac.denominator - frac.numerator);

self.is_finalized = true;
}
}

/// Ensures that the LogupAtRow is finalized.
Expand Down Expand Up @@ -314,30 +274,11 @@ impl<'a> LogupColGenerator<'a> {

#[cfg(test)]
mod tests {
use num_traits::One;

use super::{LogupAtRow, LookupElements};
use crate::constraint_framework::{InfoEvaluator, INTERACTION_TRACE_IDX};
use super::LookupElements;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

#[test]
#[should_panic]
fn test_logup_not_finalized_panic() {
let mut logup = LogupAtRow::<InfoEvaluator>::new(
INTERACTION_TRACE_IDX,
SecureField::one(),
None,
BaseField::one(),
);
logup.write_frac(
&mut InfoEvaluator::default(),
Fraction::new(SecureField::one(), SecureField::one()),
);
}

#[test]
fn test_lookup_elements_combine() {
Expand Down
Loading

0 comments on commit 92fe59a

Please sign in to comment.