Skip to content

Commit

Permalink
Abstracted eval.write_fracs to eval.add_to_relations.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Nov 13, 2024
1 parent cf4bb59 commit 670e588
Show file tree
Hide file tree
Showing 14 changed files with 191 additions and 68 deletions.
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl<E: EvalAtRow> Drop for LogupAtRow<E> {
pub struct LookupElements<const N: usize> {
pub z: SecureField,
pub alpha: SecureField,
alpha_powers: [SecureField; N],
pub alpha_powers: [SecureField; N],
}
impl<const N: usize> LookupElements<N> {
pub fn draw(channel: &mut impl Channel) -> Self {
Expand Down
110 changes: 110 additions & 0 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,29 @@ pub trait EvalAtRow {
/// Combines 4 base field values into a single extension field value.
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF;

/// Adds `entry.elems` to `entry.relation` with `entry.multiplicity` for all 'entry' in
/// 'entries', batched together.
/// Constraint degree increases with number of batched constraints as the denominators are
/// multiplied.
fn add_to_relation<Relation: RelationType<Self::F, Self::EF>>(
&mut self,
entries: &[RelationEntry<'_, Self::F, Self::EF, Relation>],
) {
let fracs: Vec<Fraction<Self::EF, Self::EF>> = entries
.iter()
.map(
|RelationEntry {
relation,
multiplicity,
values: elems,
}| {
Fraction::new(multiplicity.clone(), relation.combine(elems))
},
)
.collect();
self.write_frac(fracs.into_iter().sum());
}

// TODO(alont): Remove these once LogupAtRow is no longer used.
fn init_logup(
&mut self,
Expand Down Expand Up @@ -166,3 +189,90 @@ macro_rules! logup_proxy {
};
}
pub(crate) use logup_proxy;

pub trait RelationEFTraitBound<F: Clone>:
Clone + Zero + From<F> + From<SecureField> + Mul<F, Output = Self> + Sub<Self, Output = Self>
{
}

impl<F, EF> RelationEFTraitBound<F> for EF
where
F: Clone,
EF: Clone + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
{
}

/// A trait for defining a logup relation type.
pub trait RelationType<F: Clone, EF: RelationEFTraitBound<F>>: Sized {
fn combine(&self, values: &[F]) -> EF {
values
.iter()
.zip(self.get_alpha_powers())
.fold(EF::zero(), |acc, (value, power)| {
acc + EF::from(*power) * value.clone()
})
- self.get_z().into()
}

fn get_z(&self) -> SecureField;
fn get_alpha_powers(&self) -> &[SecureField];
fn get_name(&self) -> &str;
}

/// A struct representing a relation entry.
/// `relation` is the relation into which elements are entered.
/// `multiplicity` is the multiplicity of the elements.
/// A positive multiplicity is used to signify a "use", while a negative multiplicity
/// signifies a "yield".
/// `values` are elements in the base field that are entered into the relation.
pub struct RelationEntry<'a, F: Clone, EF: RelationEFTraitBound<F>, Relation: RelationType<F, EF>> {
relation: &'a Relation,
multiplicity: EF,
values: &'a [F],
}
impl<'a, F: Clone, EF: RelationEFTraitBound<F>, Relation: RelationType<F, EF>>
RelationEntry<'a, F, EF, Relation>
{
pub fn new(relation: &'a Relation, multiplicity: EF, elems: &'a [F]) -> Self {
Self {
relation,
multiplicity,
values: elems,
}
}
}

macro_rules! relation {
($name:tt, $size:tt) => {
#[derive(Clone, Debug, PartialEq)]
pub struct $name(crate::constraint_framework::logup::LookupElements<$size>);

impl $name {
pub fn dummy() -> Self {
Self(crate::constraint_framework::logup::LookupElements::dummy())
}
pub fn draw(channel: &mut impl crate::core::channel::Channel) -> Self {
Self(crate::constraint_framework::logup::LookupElements::draw(
channel,
))
}
}

impl<F: Clone, EF: crate::constraint_framework::RelationEFTraitBound<F>>
crate::constraint_framework::RelationType<F, EF> for $name
{
fn get_z(&self) -> crate::core::fields::qm31::SecureField {
self.0.z
}

fn get_alpha_powers(&self) -> &[crate::core::fields::qm31::SecureField] {
&self.0.alpha_powers
}

fn get_name(&self) -> &str {
stringify!($name)
}
}
};
}
pub(crate) use relation;
1 change: 1 addition & 0 deletions crates/prover/src/examples/blake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ impl XorAccums {
}
}

// TODO(alont): Get these out of the struct and give them names.
#[derive(Clone)]
pub struct BlakeXorElements {
xor12: XorElements,
Expand Down
23 changes: 11 additions & 12 deletions crates/prover/src/examples/blake/round/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use itertools::{chain, Itertools};
use num_traits::One;

use super::{BlakeXorElements, RoundElements};
use crate::constraint_framework::EvalAtRow;
use crate::constraint_framework::{EvalAtRow, RelationEntry};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::{Fraction, Reciprocal};
use crate::core::lookups::utils::Reciprocal;
use crate::examples::blake::{Fu32, STATE_SIZE};

const INV16: BaseField = BaseField::from_u32_unchecked(1 << 15);
Expand Down Expand Up @@ -67,17 +67,16 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
);

// Yield `Round(input_v, output_v, message)`.
self.eval.write_frac(Fraction::new(
self.eval.add_to_relation(&[RelationEntry::new(
self.round_lookup_elements,
-E::EF::one(),
self.round_lookup_elements.combine(
&chain![
input_v.iter().cloned().flat_map(Fu32::to_felts),
v.iter().cloned().flat_map(Fu32::to_felts),
m.iter().cloned().flat_map(Fu32::to_felts)
]
.collect_vec(),
),
));
&chain![
input_v.iter().cloned().flat_map(Fu32::to_felts),
v.iter().cloned().flat_map(Fu32::to_felts),
m.iter().cloned().flat_map(Fu32::to_felts)
]
.collect_vec(),
)]);

self.eval.finalize_logup();
self.eval
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/blake/round/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tracing::{span, Level};

use super::{BlakeXorElements, RoundElements};
use crate::constraint_framework::logup::LogupTraceGenerator;
use crate::constraint_framework::ORIGINAL_TRACE_IDX;
use crate::constraint_framework::{RelationType, ORIGINAL_TRACE_IDX};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
Expand Down
7 changes: 4 additions & 3 deletions crates/prover/src/examples/blake/round/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ pub use gen::{generate_interaction_trace, generate_trace, BlakeRoundInput};
use num_traits::Zero;

use super::{BlakeXorElements, N_ROUND_INPUT_FELTS};
use crate::constraint_framework::logup::LookupElements;
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator};
use crate::constraint_framework::{
relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator,
};
use crate::core::fields::qm31::SecureField;

pub type BlakeRoundComponent = FrameworkComponent<BlakeRoundEval>;

pub type RoundElements = LookupElements<N_ROUND_INPUT_FELTS>;
relation!(RoundElements, N_ROUND_INPUT_FELTS);

pub struct BlakeRoundEval {
pub log_size: u32,
Expand Down
27 changes: 14 additions & 13 deletions crates/prover/src/examples/blake/scheduler/constraints.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use itertools::{chain, Itertools};
use num_traits::Zero;
use num_traits::{One, Zero};

use super::BlakeElements;
use crate::constraint_framework::EvalAtRow;
use crate::constraint_framework::{EvalAtRow, RelationEntry, RelationType};
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::{Fraction, Reciprocal};
use crate::core::lookups::utils::Fraction;
use crate::core::vcs::blake2s_ref::SIGMA;
use crate::examples::blake::round::RoundElements;
use crate::examples::blake::{Fu32, N_ROUNDS, STATE_SIZE};
Expand All @@ -24,20 +24,21 @@ pub fn eval_blake_scheduler_constraints<E: EvalAtRow>(
// Schedule.
for [i, j] in (0..N_ROUNDS).array_chunks::<2>() {
// Use triplet in round lookup.
let [denom_i, denom_j] = [i, j].map(|idx| {
let [elems_i, elems_j] = [i, j].map(|idx| {
let input_state = &states[idx];
let output_state = &states[idx + 1];
let round_messages = SIGMA[idx].map(|k| messages[k as usize].clone());
round_lookup_elements.combine::<E::F, E::EF>(
&chain![
input_state.iter().cloned().flat_map(Fu32::to_felts),
output_state.iter().cloned().flat_map(Fu32::to_felts),
round_messages.iter().cloned().flat_map(Fu32::to_felts)
]
.collect_vec(),
)
chain![
input_state.iter().cloned().flat_map(Fu32::to_felts),
output_state.iter().cloned().flat_map(Fu32::to_felts),
round_messages.iter().cloned().flat_map(Fu32::to_felts)
]
.collect_vec()
});
eval.write_frac(Reciprocal::new(denom_i) + Reciprocal::new(denom_j));
eval.add_to_relation(&[
RelationEntry::new(round_lookup_elements, E::EF::one(), &elems_i),
RelationEntry::new(round_lookup_elements, E::EF::one(), &elems_j),
]);
}

let input_state = &states[0];
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/blake/scheduler/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tracing::{span, Level};

use super::{blake_scheduler_info, BlakeElements};
use crate::constraint_framework::logup::LogupTraceGenerator;
use crate::constraint_framework::ORIGINAL_TRACE_IDX;
use crate::constraint_framework::{RelationType, ORIGINAL_TRACE_IDX};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::qm31::PackedSecureField;
Expand Down
7 changes: 4 additions & 3 deletions crates/prover/src/examples/blake/scheduler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ use num_traits::Zero;

use super::round::RoundElements;
use super::N_ROUND_INPUT_FELTS;
use crate::constraint_framework::logup::LookupElements;
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator};
use crate::constraint_framework::{
relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator,
};
use crate::core::fields::qm31::SecureField;

pub type BlakeSchedulerComponent = FrameworkComponent<BlakeSchedulerEval>;

pub type BlakeElements = LookupElements<N_ROUND_INPUT_FELTS>;
relation!(BlakeElements, N_ROUND_INPUT_FELTS);

pub struct BlakeSchedulerEval {
pub log_size: u32,
Expand Down
35 changes: 18 additions & 17 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use tracing::{span, Level};
use crate::constraint_framework::logup::{ClaimedPrefixSum, LogupTraceGenerator, LookupElements};
use crate::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn};
use crate::constraint_framework::{
assert_constraints, EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator,
assert_constraints, relation, EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
TraceLocationAllocator,
};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::LOG_N_LANES;
Expand All @@ -15,7 +16,6 @@ use crate::core::backend::Column;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::{CommitmentSchemeProver, PcsConfig, TreeSubspan};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::BitReversedOrder;
Expand All @@ -25,10 +25,13 @@ use crate::core::ColumnVec;

pub type PlonkComponent = FrameworkComponent<PlonkEval>;

// TODO(alont): Rename this and all other `LookupElements` types to `Relation`.
relation!(PlonkLookupElements, 2);

#[derive(Clone)]
pub struct PlonkEval {
pub log_n_rows: u32,
pub lookup_elements: LookupElements<2>,
pub lookup_elements: PlonkLookupElements,
pub claimed_sum: ClaimedPrefixSum,
pub total_sum: SecureField,
pub base_trace_location: TreeSubspan,
Expand Down Expand Up @@ -65,17 +68,16 @@ impl FrameworkEval for PlonkEval {
+ (E::F::one() - op) * a_val.clone() * b_val.clone(),
);

let denom_a: E::EF = self.lookup_elements.combine(&[a_wire, a_val]);
let denom_b: E::EF = self.lookup_elements.combine(&[b_wire, b_val]);
eval.add_to_relation(&[
RelationEntry::new(&self.lookup_elements, E::EF::one(), &[a_wire, a_val]),
RelationEntry::new(&self.lookup_elements, E::EF::one(), &[b_wire, b_val]),
]);

eval.write_frac(Fraction::new(
denom_a.clone() + denom_b.clone(),
denom_a * denom_b,
));
eval.write_frac(Fraction::new(
eval.add_to_relation(&[RelationEntry::new(
&self.lookup_elements,
(-mult).into(),
self.lookup_elements.combine(&[c_wire, c_val]),
));
&[c_wire, c_val],
)]);

eval.finalize_logup();
eval
Expand Down Expand Up @@ -218,12 +220,12 @@ pub fn prove_fibonacci_plonk(
span.exit();

// Draw lookup element.
let lookup_elements = LookupElements::draw(channel);
let lookup_elements = PlonkLookupElements::draw(channel);

// Interaction trace.
let span = span!(Level::INFO, "Interaction").entered();
let (trace, [total_sum, claimed_sum]) =
gen_interaction_trace(log_n_rows, padding_offset, &circuit, &lookup_elements);
gen_interaction_trace(log_n_rows, padding_offset, &circuit, &lookup_elements.0);
let mut tree_builder = commitment_scheme.tree_builder();
let interaction_trace_location = tree_builder.extend_evals(trace);
tree_builder.commit(channel);
Expand Down Expand Up @@ -261,14 +263,13 @@ pub fn prove_fibonacci_plonk(
mod tests {
use std::env;

use crate::constraint_framework::logup::LookupElements;
use crate::core::air::Component;
use crate::core::channel::Blake2sChannel;
use crate::core::fri::FriConfig;
use crate::core::pcs::{CommitmentSchemeVerifier, PcsConfig};
use crate::core::prover::verify;
use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel;
use crate::examples::plonk::prove_fibonacci_plonk;
use crate::examples::plonk::{prove_fibonacci_plonk, PlonkLookupElements};

#[test_log::test]
fn test_simd_plonk_prove() {
Expand Down Expand Up @@ -300,7 +301,7 @@ mod tests {
// Trace columns.
commitment_scheme.commit(proof.commitments[1], &sizes[1], channel);
// Draw lookup element.
let lookup_elements = LookupElements::<2>::draw(channel);
let lookup_elements = PlonkLookupElements::draw(channel);
assert_eq!(lookup_elements, component.lookup_elements);
// Interaction columns.
commitment_scheme.commit(proof.commitments[2], &sizes[2], channel);
Expand Down
Loading

0 comments on commit 670e588

Please sign in to comment.