Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialize logups automatically when first used. #879

Merged
merged 1 commit into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions crates/prover/src/constraint_framework/assert.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use num_traits::{One, Zero};

use super::logup::LogupAtRow;
use super::EvalAtRow;
use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::core::backend::{Backend, Column};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand All @@ -19,12 +19,17 @@ pub struct AssertEvaluator<'a> {
pub logup: LogupAtRow<Self>,
}
impl<'a> AssertEvaluator<'a> {
pub fn new(trace: &'a TreeVec<Vec<Vec<BaseField>>>, row: usize) -> Self {
pub fn new(
trace: &'a TreeVec<Vec<Vec<BaseField>>>,
row: usize,
log_size: u32,
logup_sums: LogupSums,
) -> Self {
Self {
trace,
col_index: TreeVec::new(vec![0; trace.len()]),
row,
logup: LogupAtRow::dummy(),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
}
}
}
Expand Down Expand Up @@ -69,6 +74,7 @@ pub fn assert_constraints<B: Backend>(
trace_polys: &TreeVec<Vec<CirclePoly<B>>>,
trace_domain: CanonicCoset,
assert_func: impl Fn(AssertEvaluator<'_>),
logup_sums: LogupSums,
) {
let traces = trace_polys.as_ref().map(|tree| {
tree.iter()
Expand All @@ -84,7 +90,8 @@ pub fn assert_constraints<B: Backend>(
.collect()
});
for row in 0..trace_domain.size() {
let eval = AssertEvaluator::new(&traces, row);
let eval = AssertEvaluator::new(&traces, row, trace_domain.log_size(), logup_sums);

assert_func(eval);
}
}
17 changes: 15 additions & 2 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use rayon::prelude::*;
use tracing::{span, Level};

use super::cpu_domain::CpuDomainEvaluator;
use super::logup::LogupSums;
use super::preprocessed_columns::PreprocessedColumn;
use super::{
EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX,
Expand Down Expand Up @@ -113,11 +114,16 @@ pub struct FrameworkComponent<C: FrameworkEval> {
trace_locations: TreeVec<TreeSubspan>,
info: InfoEvaluator,
preprocessed_column_indices: Vec<usize>,
logup_sums: LogupSums,
}

impl<E: FrameworkEval> FrameworkComponent<E> {
pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self {
let info = eval.evaluate(InfoEvaluator::default());
pub fn new(
location_allocator: &mut TraceLocationAllocator,
eval: E,
logup_sums: LogupSums,
) -> Self {
let info = eval.evaluate(InfoEvaluator::new(eval.log_size(), vec![], logup_sums));
let trace_locations = location_allocator.next_for_structure(&info.mask_offsets);

let preprocessed_column_indices = info
Expand Down Expand Up @@ -148,6 +154,7 @@ impl<E: FrameworkEval> FrameworkComponent<E> {
trace_locations,
info,
preprocessed_column_indices,
logup_sums,
}
}

Expand Down Expand Up @@ -217,6 +224,8 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
mask_points,
evaluation_accumulator,
coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(),
self.eval.log_size(),
self.logup_sums,
));
}
}
Expand Down Expand Up @@ -296,6 +305,8 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
&accum.random_coeff_powers,
trace_domain.log_size(),
eval_domain.log_size(),
self.eval.log_size(),
self.logup_sums,
);
let row_res = self.eval.evaluate(eval).row_res;

Expand Down Expand Up @@ -333,6 +344,8 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
&accum.random_coeff_powers,
trace_domain.log_size(),
eval_domain.log_size(),
self.eval.log_size(),
self.logup_sums,
);
let row_res = self.eval.evaluate(eval).row_res;

Expand Down
8 changes: 5 additions & 3 deletions crates/prover/src/constraint_framework/cpu_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::ops::Mul;

use num_traits::Zero;

use super::logup::LogupAtRow;
use super::EvalAtRow;
use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand Down Expand Up @@ -35,6 +35,8 @@ impl<'a> CpuDomainEvaluator<'a> {
random_coeff_powers: &'a [SecureField],
domain_log_size: u32,
eval_log_size: u32,
log_size: u32,
logup_sums: LogupSums,
) -> Self {
Self {
trace_eval,
Expand All @@ -45,7 +47,7 @@ impl<'a> CpuDomainEvaluator<'a> {
constraint_index: 0,
domain_log_size,
eval_domain_log_size: eval_log_size,
logup: LogupAtRow::dummy(),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
}
}
}
Expand Down
14 changes: 12 additions & 2 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};

use num_traits::{One, Zero};

use super::logup::LogupAtRow;
use super::EvalAtRow;
use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
Expand Down Expand Up @@ -156,6 +156,16 @@ struct ExprEvaluator {
pub logup: LogupAtRow<Self>,
}

impl ExprEvaluator {
pub fn _new(log_size: u32, logup_sums: LogupSums) -> Self {
Self {
cur_var_index: Default::default(),
constraints: Default::default(),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
}
}
}

impl EvalAtRow for ExprEvaluator {
// TODO(alont): Should there be a version of this that disallows Secure fields for F?
type F = Expr;
Expand Down
23 changes: 19 additions & 4 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use std::ops::Mul;

use num_traits::One;

use super::logup::LogupAtRow;
use super::logup::{LogupAtRow, LogupSums};
use super::preprocessed_columns::PreprocessedColumn;
use super::EvalAtRow;
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand All @@ -21,8 +21,23 @@ pub struct InfoEvaluator {
pub logup: LogupAtRow<Self>,
}
impl InfoEvaluator {
pub fn new() -> Self {
Self::default()
pub fn new(
log_size: u32,
preprocessed_columns: Vec<PreprocessedColumn>,
logup_sums: LogupSums,
) -> Self {
Self {
mask_offsets: Default::default(),
n_constraints: Default::default(),
preprocessed_columns,
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
}
}

/// Create an empty `InfoEvaluator`, to measure components before their size and logup sums are
/// available.
pub fn empty() -> Self {
Self::new(16, vec![], (SecureField::default(), None))
}
}
impl EvalAtRow for InfoEvaluator {
Expand Down
81 changes: 11 additions & 70 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ use crate::core::ColumnVec;
/// Represents the value of the prefix sum column at some index.
/// Should be used to eliminate padded rows for the logup sum.
pub type ClaimedPrefixSum = (SecureField, usize);
// (total_sum, claimed_sum)
pub type LogupSums = (SecureField, Option<ClaimedPrefixSum>);

/// Evaluates constraints for batched logups.
/// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum.
Expand All @@ -38,11 +40,12 @@ pub struct LogupAtRow<E: EvalAtRow> {
pub claimed_sum: Option<ClaimedPrefixSum>,
/// The evaluation of the last cumulative sum column.
pub prev_col_cumsum: E::EF,
cur_frac: Option<Fraction<E::EF, E::EF>>,
is_finalized: bool,
pub cur_frac: Option<Fraction<E::EF, E::EF>>,
pub is_finalized: bool,
/// The value of the `is_first` constant column at current row.
/// See [`super::preprocessed_columns::gen_is_first()`].
pub is_first: E::F,
pub log_size: u32,
}

impl<E: EvalAtRow> Default for LogupAtRow<E> {
Expand All @@ -55,16 +58,17 @@ impl<E: EvalAtRow> LogupAtRow<E> {
interaction: usize,
total_sum: SecureField,
claimed_sum: Option<ClaimedPrefixSum>,
is_first: E::F,
log_size: u32,
) -> Self {
Self {
interaction,
total_sum,
claimed_sum,
prev_col_cumsum: E::EF::zero(),
cur_frac: None,
is_finalized: false,
is_first,
is_finalized: true,
is_first: E::F::zero(),
log_size,
}
}

Expand All @@ -78,53 +82,9 @@ impl<E: EvalAtRow> LogupAtRow<E> {
cur_frac: None,
is_finalized: true,
is_first: E::F::zero(),
log_size: 10,
}
}

pub fn write_frac(&mut self, eval: &mut E, fraction: Fraction<E::EF, E::EF>) {
// Add a constraint that num / denom = diff.
if let Some(cur_frac) = self.cur_frac.clone() {
let [cur_cumsum] = eval.next_extension_interaction_mask(self.interaction, [0]);
let diff = cur_cumsum.clone() - self.prev_col_cumsum.clone();
self.prev_col_cumsum = cur_cumsum;
eval.add_constraint(diff * cur_frac.denominator - cur_frac.numerator);
}
self.cur_frac = Some(fraction);
}

pub fn finalize(&mut self, eval: &mut E) {
assert!(!self.is_finalized, "LogupAtRow was already finalized");

let frac = self.cur_frac.clone().unwrap();

// TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted offset
// from the is_first column when constant columns are supported.
let (cur_cumsum, prev_row_cumsum) = match self.claimed_sum {
Some((claimed_sum, claimed_row_index)) => {
let [cur_cumsum, prev_row_cumsum, claimed_cumsum] = eval
.next_extension_interaction_mask(
self.interaction,
[0, -1, claimed_row_index as isize],
);

// Constrain that the claimed_sum in case that it is not equal to the total_sum.
eval.add_constraint((claimed_cumsum - claimed_sum) * self.is_first.clone());
(cur_cumsum, prev_row_cumsum)
}
None => {
let [cur_cumsum, prev_row_cumsum] =
eval.next_extension_interaction_mask(self.interaction, [0, -1]);
(cur_cumsum, prev_row_cumsum)
}
};
// Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row.
let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first.clone() * self.total_sum;
let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum.clone();

eval.add_constraint(diff * frac.denominator - frac.numerator);

self.is_finalized = true;
}
}

/// Ensures that the LogupAtRow is finalized.
Expand Down Expand Up @@ -314,30 +274,11 @@ impl<'a> LogupColGenerator<'a> {

#[cfg(test)]
mod tests {
use num_traits::One;

use super::{LogupAtRow, LookupElements};
use crate::constraint_framework::{InfoEvaluator, INTERACTION_TRACE_IDX};
use super::LookupElements;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

#[test]
#[should_panic]
fn test_logup_not_finalized_panic() {
let mut logup = LogupAtRow::<InfoEvaluator>::new(
INTERACTION_TRACE_IDX,
SecureField::one(),
None,
BaseField::one(),
);
logup.write_frac(
&mut InfoEvaluator::default(),
Fraction::new(SecureField::one(), SecureField::one()),
);
}

#[test]
fn test_lookup_elements_combine() {
Expand Down
Loading
Loading