Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added LogupAtRow functionality to EvalAtRow. #875

Merged
merged 1 commit into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/prover/src/constraint_framework/assert.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use num_traits::{One, Zero};

use super::logup::LogupAtRow;
use super::EvalAtRow;
use crate::core::backend::{Backend, Column};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CirclePoly};
use crate::core::utils::circle_domain_order_to_coset_order;
Expand All @@ -14,13 +16,15 @@ pub struct AssertEvaluator<'a> {
pub trace: &'a TreeVec<Vec<Vec<BaseField>>>,
pub col_index: TreeVec<usize>,
pub row: usize,
pub logup: LogupAtRow<Self>,
}
impl<'a> AssertEvaluator<'a> {
pub fn new(trace: &'a TreeVec<Vec<Vec<BaseField>>>, row: usize) -> Self {
Self {
trace,
col_index: TreeVec::new(vec![0; trace.len()]),
row,
logup: LogupAtRow::dummy(),
}
}
}
Expand Down Expand Up @@ -57,6 +61,8 @@ impl<'a> EvalAtRow for AssertEvaluator<'a> {
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
SecureField::from_m31_array(values)
}

super::logup_proxy!();
}

pub fn assert_constraints<B: Backend>(
Expand Down
6 changes: 6 additions & 0 deletions crates/prover/src/constraint_framework/cpu_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ use std::ops::Mul;

use num_traits::Zero;

use super::logup::LogupAtRow;
use super::EvalAtRow;
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
Expand All @@ -22,6 +24,7 @@ pub struct CpuDomainEvaluator<'a> {
pub constraint_index: usize,
pub domain_log_size: u32,
pub eval_domain_log_size: u32,
pub logup: LogupAtRow<Self>,
}

impl<'a> CpuDomainEvaluator<'a> {
Expand All @@ -42,6 +45,7 @@ impl<'a> CpuDomainEvaluator<'a> {
constraint_index: 0,
domain_log_size,
eval_domain_log_size: eval_log_size,
logup: LogupAtRow::dummy(),
}
}
}
Expand Down Expand Up @@ -88,4 +92,6 @@ impl<'a> EvalAtRow for CpuDomainEvaluator<'a> {
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
SecureField::from_m31_array(values)
}

super::logup_proxy!();
}
9 changes: 7 additions & 2 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};

use num_traits::{One, Zero};

use super::logup::LogupAtRow;
use super::EvalAtRow;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

/// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`.
#[derive(Clone, Debug, PartialEq)]
Expand Down Expand Up @@ -149,8 +151,9 @@ impl AddAssign<BaseField> for Expr {
/// An Evaluator that saves all constraint expressions.
#[derive(Default)]
struct ExprEvaluator {
cur_var_index: usize,
constraints: Vec<Expr>,
pub cur_var_index: usize,
pub constraints: Vec<Expr>,
pub logup: LogupAtRow<Self>,
}

impl EvalAtRow for ExprEvaluator {
Expand Down Expand Up @@ -189,6 +192,8 @@ impl EvalAtRow for ExprEvaluator {
Box::new(values[3].clone()),
])
}

super::logup_proxy!();
}

#[cfg(test)]
Expand Down
5 changes: 5 additions & 0 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ use std::ops::Mul;

use num_traits::One;

use super::logup::LogupAtRow;
use super::preprocessed_columns::PreprocessedColumn;
use super::EvalAtRow;
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;

/// Collects information about the constraints.
Expand All @@ -16,6 +18,7 @@ pub struct InfoEvaluator {
pub mask_offsets: TreeVec<Vec<Vec<isize>>>,
pub n_constraints: usize,
pub preprocessed_columns: Vec<PreprocessedColumn>,
pub logup: LogupAtRow<Self>,
}
impl InfoEvaluator {
pub fn new() -> Self {
Expand Down Expand Up @@ -60,4 +63,6 @@ impl EvalAtRow for InfoEvaluator {
fn combine_ef(_values: [Self::F; 4]) -> Self::EF {
SecureField::one()
}

super::logup_proxy!();
}
21 changes: 20 additions & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ pub struct LogupAtRow<E: EvalAtRow> {
/// See [`super::preprocessed_columns::gen_is_first()`].
pub is_first: E::F,
}

impl<E: EvalAtRow> Default for LogupAtRow<E> {
fn default() -> Self {
Self::dummy()
}
}
impl<E: EvalAtRow> LogupAtRow<E> {
pub fn new(
interaction: usize,
Expand All @@ -62,6 +68,19 @@ impl<E: EvalAtRow> LogupAtRow<E> {
}
}

// TODO(alont): Remove this once unnecessary LogupAtRows are gone.
pub fn dummy() -> Self {
Self {
interaction: 100,
total_sum: SecureField::one(),
claimed_sum: None,
prev_col_cumsum: E::EF::zero(),
cur_frac: None,
is_finalized: true,
is_first: E::F::zero(),
}
}

pub fn write_frac(&mut self, eval: &mut E, fraction: Fraction<E::EF, E::EF>) {
// Add a constraint that num / denom = diff.
if let Some(cur_frac) = self.cur_frac.clone() {
Expand All @@ -73,7 +92,7 @@ impl<E: EvalAtRow> LogupAtRow<E> {
self.cur_frac = Some(fraction);
}

pub fn finalize(mut self, eval: &mut E) {
pub fn finalize(&mut self, eval: &mut E) {
assert!(!self.is_finalized, "LogupAtRow was already finalized");

let frac = self.cur_frac.clone().unwrap();
Expand Down
56 changes: 56 additions & 0 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

pub const PREPROCESSED_TRACE_IDX: usize = 0;
pub const ORIGINAL_TRACE_IDX: usize = 1;
Expand Down Expand Up @@ -109,4 +110,59 @@ pub trait EvalAtRow {

/// Combines 4 base field values into a single extension field value.
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF;

// TODO(alont): Remove these once LogupAtRow is no longer used.
fn init_logup(
&mut self,
_total_sum: SecureField,
_claimed_sum: Option<crate::constraint_framework::logup::ClaimedPrefixSum>,
_log_size: u32,
) {
unimplemented!()
}
fn write_frac(&mut self, _fraction: Fraction<Self::EF, Self::EF>) {
unimplemented!()
}
fn finalize_logup(&mut self) {
unimplemented!()
}
}

/// Default implementation for evaluators that have an element called "logup" that works like a
/// LogupAtRow, where the logup functionality can be proxied.
/// TODO(alont): Remove once LogupAtRow is no longer used.
macro_rules! logup_proxy {
() => {
fn init_logup(
&mut self,
total_sum: SecureField,
claimed_sum: Option<crate::constraint_framework::logup::ClaimedPrefixSum>,
log_size: u32,
) {
let is_first = self.get_preprocessed_column(
crate::constraint_framework::preprocessed_columns::PreprocessedColumn::IsFirst(
log_size,
),
);
self.logup = crate::constraint_framework::logup::LogupAtRow::new(
crate::constraint_framework::INTERACTION_TRACE_IDX,
total_sum,
claimed_sum,
is_first,
);
}

fn write_frac(&mut self, fraction: Fraction<Self::EF, Self::EF>) {
let mut logup = std::mem::take(&mut self.logup);
logup.write_frac(self, fraction);
self.logup = logup;
}

fn finalize_logup(&mut self) {
let mut logup = std::mem::take(&mut self.logup);
logup.finalize(self);
self.logup = logup;
}
};
}
pub(crate) use logup_proxy;
6 changes: 6 additions & 0 deletions crates/prover/src/constraint_framework/point.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::ops::Mul;

use super::logup::LogupAtRow;
use super::EvalAtRow;
use crate::core::air::accumulation::PointEvaluationAccumulator;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;
use crate::core::ColumnVec;

Expand All @@ -13,6 +15,7 @@ pub struct PointEvaluator<'a> {
pub evaluation_accumulator: &'a mut PointEvaluationAccumulator,
pub col_index: Vec<usize>,
pub denom_inverse: SecureField,
pub logup: LogupAtRow<Self>,
}
impl<'a> PointEvaluator<'a> {
pub fn new(
Expand All @@ -26,6 +29,7 @@ impl<'a> PointEvaluator<'a> {
evaluation_accumulator,
col_index,
denom_inverse,
logup: LogupAtRow::dummy(),
}
}
}
Expand Down Expand Up @@ -54,4 +58,6 @@ impl<'a> EvalAtRow for PointEvaluator<'a> {
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
SecureField::from_partial_evals(values)
}

super::logup_proxy!();
}
6 changes: 6 additions & 0 deletions crates/prover/src/constraint_framework/simd_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::ops::Mul;

use num_traits::Zero;

use super::logup::LogupAtRow;
use super::EvalAtRow;
use crate::core::backend::simd::column::VeryPackedBaseColumn;
use crate::core::backend::simd::m31::LOG_N_LANES;
Expand All @@ -13,6 +14,7 @@ use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
Expand All @@ -30,6 +32,7 @@ pub struct SimdDomainEvaluator<'a> {
pub constraint_index: usize,
pub domain_log_size: u32,
pub eval_domain_log_size: u32,
pub logup: LogupAtRow<Self>,
}
impl<'a> SimdDomainEvaluator<'a> {
pub fn new(
Expand All @@ -48,6 +51,7 @@ impl<'a> SimdDomainEvaluator<'a> {
constraint_index: 0,
domain_log_size,
eval_domain_log_size: eval_log_size,
logup: LogupAtRow::dummy(),
}
}
}
Expand Down Expand Up @@ -103,4 +107,6 @@ impl<'a> EvalAtRow for SimdDomainEvaluator<'a> {
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
VeryPackedSecureField::from_very_packed_m31s(values)
}

super::logup_proxy!();
}
35 changes: 17 additions & 18 deletions crates/prover/src/examples/blake/round/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use itertools::{chain, Itertools};
use num_traits::One;

use super::{BlakeXorElements, RoundElements};
use crate::constraint_framework::logup::LogupAtRow;
use crate::constraint_framework::EvalAtRow;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::{Fraction, Reciprocal};
use crate::examples::blake::{Fu32, STATE_SIZE};

Expand All @@ -15,10 +15,12 @@ pub struct BlakeRoundEval<'a, E: EvalAtRow> {
pub eval: E,
pub xor_lookup_elements: &'a BlakeXorElements,
pub round_lookup_elements: &'a RoundElements,
pub logup: LogupAtRow<E>,
pub total_sum: SecureField,
pub log_size: u32,
}
impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
pub fn eval(mut self) -> E {
self.eval.init_logup(self.total_sum, None, self.log_size);
let mut v: [Fu32<E::F>; STATE_SIZE] = std::array::from_fn(|_| self.next_u32());
let input_v = v.clone();
let m: [Fu32<E::F>; STATE_SIZE] = std::array::from_fn(|_| self.next_u32());
Expand Down Expand Up @@ -65,22 +67,19 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
);

// Yield `Round(input_v, output_v, message)`.
self.logup.write_frac(
&mut self.eval,
Fraction::new(
-E::EF::one(),
self.round_lookup_elements.combine(
&chain![
input_v.iter().cloned().flat_map(Fu32::to_felts),
v.iter().cloned().flat_map(Fu32::to_felts),
m.iter().cloned().flat_map(Fu32::to_felts)
]
.collect_vec(),
),
self.eval.write_frac(Fraction::new(
-E::EF::one(),
self.round_lookup_elements.combine(
&chain![
input_v.iter().cloned().flat_map(Fu32::to_felts),
v.iter().cloned().flat_map(Fu32::to_felts),
m.iter().cloned().flat_map(Fu32::to_felts)
]
.collect_vec(),
),
);
));

self.logup.finalize(&mut self.eval);
self.eval.finalize_logup();
self.eval
}
fn next_u32(&mut self) -> Fu32<E::F> {
Expand Down Expand Up @@ -197,9 +196,9 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
lookup_elements.combine::<E::F, E::EF>(&[a[0].clone(), b[0].clone(), c[0].clone()]);
let comb1 =
lookup_elements.combine::<E::F, E::EF>(&[a[1].clone(), b[1].clone(), c[1].clone()]);
let frac = Reciprocal::new(comb0) + Reciprocal::new(comb1);

self.logup.write_frac(&mut self.eval, frac);
self.eval
.write_frac(Reciprocal::new(comb0) + Reciprocal::new(comb1));
c
}
}
Loading
Loading