Skip to content

Commit

Permalink
Separated Xor table types.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Nov 19, 2024
1 parent 77de4d7 commit 5a1d139
Show file tree
Hide file tree
Showing 8 changed files with 507 additions and 414 deletions.
102 changes: 56 additions & 46 deletions crates/prover/src/examples/blake/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tracing::{span, Level};

use super::round::{blake_round_info, BlakeRoundComponent, BlakeRoundEval};
use super::scheduler::{BlakeSchedulerComponent, BlakeSchedulerEval};
use super::xor_table::{column_bits, XorTableComponent, XorTableEval};
use super::xor_table::{xor12, xor4, xor7, xor8, xor9};
use crate::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn};
use crate::constraint_framework::{TraceLocationAllocator, PREPROCESSED_TRACE_IDX};
use crate::core::air::{Component, ComponentProver};
Expand All @@ -30,23 +30,23 @@ const PREPROCESSED_XOR_COLUMNS: [PreprocessedColumn; 20] = [
PreprocessedColumn::XorTable(12, 4, 0),
PreprocessedColumn::XorTable(12, 4, 1),
PreprocessedColumn::XorTable(12, 4, 2),
PreprocessedColumn::IsFirst(column_bits::<12, 4>()),
PreprocessedColumn::IsFirst(xor12::column_bits::<12, 4>()),
PreprocessedColumn::XorTable(9, 2, 0),
PreprocessedColumn::XorTable(9, 2, 1),
PreprocessedColumn::XorTable(9, 2, 2),
PreprocessedColumn::IsFirst(column_bits::<9, 2>()),
PreprocessedColumn::IsFirst(xor9::column_bits::<9, 2>()),
PreprocessedColumn::XorTable(8, 2, 0),
PreprocessedColumn::XorTable(8, 2, 1),
PreprocessedColumn::XorTable(8, 2, 2),
PreprocessedColumn::IsFirst(column_bits::<8, 2>()),
PreprocessedColumn::IsFirst(xor8::column_bits::<8, 2>()),
PreprocessedColumn::XorTable(7, 2, 0),
PreprocessedColumn::XorTable(7, 2, 1),
PreprocessedColumn::XorTable(7, 2, 2),
PreprocessedColumn::IsFirst(column_bits::<7, 2>()),
PreprocessedColumn::IsFirst(xor7::column_bits::<7, 2>()),
PreprocessedColumn::XorTable(4, 0, 0),
PreprocessedColumn::XorTable(4, 0, 1),
PreprocessedColumn::XorTable(4, 0, 2),
PreprocessedColumn::IsFirst(column_bits::<4, 0>()),
PreprocessedColumn::IsFirst(xor4::column_bits::<4, 0>()),
];

#[derive(Serialize)]
Expand All @@ -70,11 +70,11 @@ impl BlakeStatement0 {
.map_cols(|_| self.log_size + l),
);
}
sizes.push(xor_table::trace_sizes::<12, 4>());
sizes.push(xor_table::trace_sizes::<9, 2>());
sizes.push(xor_table::trace_sizes::<8, 2>());
sizes.push(xor_table::trace_sizes::<7, 2>());
sizes.push(xor_table::trace_sizes::<4, 0>());
sizes.push(xor_table::xor12::trace_sizes::<12, 4>());
sizes.push(xor_table::xor9::trace_sizes::<9, 2>());
sizes.push(xor_table::xor8::trace_sizes::<8, 2>());
sizes.push(xor_table::xor7::trace_sizes::<7, 2>());
sizes.push(xor_table::xor4::trace_sizes::<4, 0>());

let mut log_sizes = TreeVec::concat_cols(sizes.into_iter());

Expand Down Expand Up @@ -154,11 +154,11 @@ pub struct BlakeProof<H: MerkleHasher> {
pub struct BlakeComponents {
scheduler_component: BlakeSchedulerComponent,
round_components: Vec<BlakeRoundComponent>,
xor12: XorTableComponent<12, 4>,
xor9: XorTableComponent<9, 2>,
xor8: XorTableComponent<8, 2>,
xor7: XorTableComponent<7, 2>,
xor4: XorTableComponent<4, 0>,
xor12: xor12::XorTableComponent<12, 4>,
xor9: xor9::XorTableComponent<9, 2>,
xor8: xor8::XorTableComponent<8, 2>,
xor7: xor7::XorTableComponent<7, 2>,
xor4: xor4::XorTableComponent<4, 0>,
}
impl BlakeComponents {
fn new(stmt0: &BlakeStatement0, all_elements: &AllElements, stmt1: &BlakeStatement1) -> Self {
Expand Down Expand Up @@ -205,41 +205,41 @@ impl BlakeComponents {
)
})
.collect(),
xor12: XorTableComponent::new(
xor12: xor12::XorTableComponent::new(
tree_span_provider,
XorTableEval {
xor12::XorTableEval {
lookup_elements: all_elements.xor_elements.xor12.clone(),
claimed_sum: stmt1.xor12_claimed_sum,
},
(stmt1.xor12_claimed_sum, None),
),
xor9: XorTableComponent::new(
xor9: xor9::XorTableComponent::new(
tree_span_provider,
XorTableEval {
xor9::XorTableEval {
lookup_elements: all_elements.xor_elements.xor9.clone(),
claimed_sum: stmt1.xor9_claimed_sum,
},
(stmt1.xor9_claimed_sum, None),
),
xor8: XorTableComponent::new(
xor8: xor8::XorTableComponent::new(
tree_span_provider,
XorTableEval {
xor8::XorTableEval {
lookup_elements: all_elements.xor_elements.xor8.clone(),
claimed_sum: stmt1.xor8_claimed_sum,
},
(stmt1.xor8_claimed_sum, None),
),
xor7: XorTableComponent::new(
xor7: xor7::XorTableComponent::new(
tree_span_provider,
XorTableEval {
xor7::XorTableEval {
lookup_elements: all_elements.xor_elements.xor7.clone(),
claimed_sum: stmt1.xor7_claimed_sum,
},
(stmt1.xor7_claimed_sum, None),
),
xor4: XorTableComponent::new(
xor4: xor4::XorTableComponent::new(
tree_span_provider,
XorTableEval {
xor4::XorTableEval {
lookup_elements: all_elements.xor_elements.xor4.clone(),
claimed_sum: stmt1.xor4_claimed_sum,
},
Expand Down Expand Up @@ -324,11 +324,11 @@ where
chain![
vec![gen_is_first(log_size)],
ROUND_LOG_SPLIT.iter().map(|l| gen_is_first(log_size + l)),
xor_table::generate_constant_trace::<12, 4>(),
xor_table::generate_constant_trace::<9, 2>(),
xor_table::generate_constant_trace::<8, 2>(),
xor_table::generate_constant_trace::<7, 2>(),
xor_table::generate_constant_trace::<4, 0>(),
xor_table::xor12::generate_constant_trace::<12, 4>(),
xor_table::xor9::generate_constant_trace::<9, 2>(),
xor_table::xor8::generate_constant_trace::<8, 2>(),
xor_table::xor7::generate_constant_trace::<7, 2>(),
xor_table::xor4::generate_constant_trace::<4, 0>(),
]
.collect_vec(),
);
Expand All @@ -353,11 +353,11 @@ where
}));

// Xor tables.
let (xor_trace12, xor_lookup_data12) = xor_table::generate_trace(xor_accums.xor12);
let (xor_trace9, xor_lookup_data9) = xor_table::generate_trace(xor_accums.xor9);
let (xor_trace8, xor_lookup_data8) = xor_table::generate_trace(xor_accums.xor8);
let (xor_trace7, xor_lookup_data7) = xor_table::generate_trace(xor_accums.xor7);
let (xor_trace4, xor_lookup_data4) = xor_table::generate_trace(xor_accums.xor4);
let (xor_trace12, xor_lookup_data12) = xor_table::xor12::generate_trace(xor_accums.xor12);
let (xor_trace9, xor_lookup_data9) = xor_table::xor9::generate_trace(xor_accums.xor9);
let (xor_trace8, xor_lookup_data8) = xor_table::xor8::generate_trace(xor_accums.xor8);
let (xor_trace7, xor_lookup_data7) = xor_table::xor7::generate_trace(xor_accums.xor7);
let (xor_trace4, xor_lookup_data4) = xor_table::xor4::generate_trace(xor_accums.xor4);

// Statement0.
let stmt0 = BlakeStatement0 { log_size };
Expand Down Expand Up @@ -406,16 +406,26 @@ where
}),
);

let (xor_trace12, xor12_claimed_sum) =
xor_table::generate_interaction_trace(xor_lookup_data12, &all_elements.xor_elements.xor12);
let (xor_trace9, xor9_claimed_sum) =
xor_table::generate_interaction_trace(xor_lookup_data9, &all_elements.xor_elements.xor9);
let (xor_trace8, xor8_claimed_sum) =
xor_table::generate_interaction_trace(xor_lookup_data8, &all_elements.xor_elements.xor8);
let (xor_trace7, xor7_claimed_sum) =
xor_table::generate_interaction_trace(xor_lookup_data7, &all_elements.xor_elements.xor7);
let (xor_trace4, xor4_claimed_sum) =
xor_table::generate_interaction_trace(xor_lookup_data4, &all_elements.xor_elements.xor4);
let (xor_trace12, xor12_claimed_sum) = xor_table::xor12::generate_interaction_trace(
xor_lookup_data12,
&all_elements.xor_elements.xor12,
);
let (xor_trace9, xor9_claimed_sum) = xor_table::xor9::generate_interaction_trace(
xor_lookup_data9,
&all_elements.xor_elements.xor9,
);
let (xor_trace8, xor8_claimed_sum) = xor_table::xor8::generate_interaction_trace(
xor_lookup_data8,
&all_elements.xor_elements.xor8,
);
let (xor_trace7, xor7_claimed_sum) = xor_table::xor7::generate_interaction_trace(
xor_lookup_data7,
&all_elements.xor_elements.xor7,
);
let (xor_trace4, xor4_claimed_sum) = xor_table::xor4::generate_interaction_trace(
xor_lookup_data4,
&all_elements.xor_elements.xor4,
);

let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(
Expand Down
92 changes: 64 additions & 28 deletions crates/prover/src/examples/blake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Sub};
use std::simd::u32x16;

use xor_table::{XorAccumulator, XorElements};
use num_traits::One;
use xor_table::{xor12, xor4, xor7, xor8, xor9};

use crate::constraint_framework::{relation, EvalAtRow, Relation, RelationEntry};
use crate::core::backend::simd::m31::PackedBaseField;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::FieldExpOps;
Expand All @@ -29,11 +32,11 @@ const ROUND_LOG_SPLIT: [u32; 2] = [3, 1];

#[derive(Default)]
struct XorAccums {
xor12: XorAccumulator<12, 4>,
xor9: XorAccumulator<9, 2>,
xor8: XorAccumulator<8, 2>,
xor7: XorAccumulator<7, 2>,
xor4: XorAccumulator<4, 0>,
xor12: xor12::XorAccumulator<12, 4>,
xor9: xor9::XorAccumulator<9, 2>,
xor8: xor8::XorAccumulator<8, 2>,
xor7: xor7::XorAccumulator<7, 2>,
xor4: xor4::XorAccumulator<4, 0>,
}
impl XorAccums {
fn add_input(&mut self, w: u32, a: u32x16, b: u32x16) {
Expand All @@ -48,41 +51,74 @@ impl XorAccums {
}
}

// TODO(alont): Get these out of the struct and give them names.
relation!(XorElements12, 3);
relation!(XorElements9, 3);
relation!(XorElements8, 3);
relation!(XorElements7, 3);
relation!(XorElements4, 3);

#[derive(Clone)]
pub struct BlakeXorElements {
xor12: XorElements,
xor9: XorElements,
xor8: XorElements,
xor7: XorElements,
xor4: XorElements,
xor12: XorElements12,
xor9: XorElements9,
xor8: XorElements8,
xor7: XorElements7,
xor4: XorElements4,
}
impl BlakeXorElements {
fn draw(channel: &mut impl Channel) -> Self {
Self {
xor12: XorElements::draw(channel),
xor9: XorElements::draw(channel),
xor8: XorElements::draw(channel),
xor7: XorElements::draw(channel),
xor4: XorElements::draw(channel),
xor12: XorElements12::draw(channel),
xor9: XorElements9::draw(channel),
xor8: XorElements8::draw(channel),
xor7: XorElements7::draw(channel),
xor4: XorElements4::draw(channel),
}
}
fn dummy() -> Self {
Self {
xor12: XorElements::dummy(),
xor9: XorElements::dummy(),
xor8: XorElements::dummy(),
xor7: XorElements::dummy(),
xor4: XorElements::dummy(),
xor12: XorElements12::dummy(),
xor9: XorElements9::dummy(),
xor8: XorElements8::dummy(),
xor7: XorElements7::dummy(),
xor4: XorElements4::dummy(),
}
}
fn get(&self, w: u32) -> &XorElements {

// TODO(alont): Generalize this to variable sizes batches if ever used.
fn use_relation<E: EvalAtRow>(&self, eval: &mut E, w: u32, values: [&[E::F]; 2]) {
match w {
12 => eval.add_to_relation(&[
RelationEntry::new(&self.xor12, E::EF::one(), values[0]),
RelationEntry::new(&self.xor12, E::EF::one(), values[1]),
]),
9 => eval.add_to_relation(&[
RelationEntry::new(&self.xor9, E::EF::one(), values[0]),
RelationEntry::new(&self.xor9, E::EF::one(), values[1]),
]),
8 => eval.add_to_relation(&[
RelationEntry::new(&self.xor8, E::EF::one(), values[0]),
RelationEntry::new(&self.xor8, E::EF::one(), values[1]),
]),
7 => eval.add_to_relation(&[
RelationEntry::new(&self.xor7, E::EF::one(), values[0]),
RelationEntry::new(&self.xor7, E::EF::one(), values[1]),
]),
4 => eval.add_to_relation(&[
RelationEntry::new(&self.xor4, E::EF::one(), values[0]),
RelationEntry::new(&self.xor4, E::EF::one(), values[1]),
]),
_ => panic!("Invalid w"),
};
}

fn combine(&self, w: u32, values: &[PackedBaseField]) -> PackedSecureField {
match w {
12 => &self.xor12,
9 => &self.xor9,
8 => &self.xor8,
7 => &self.xor7,
4 => &self.xor4,
12 => self.xor12.combine(values),
9 => self.xor9.combine(values),
8 => self.xor8.combine(values),
7 => self.xor7.combine(values),
4 => self.xor4.combine(values),
_ => panic!("Invalid w"),
}
}
Expand Down
19 changes: 11 additions & 8 deletions crates/prover/src/examples/blake/round/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use super::{BlakeXorElements, RoundElements};
use crate::constraint_framework::{EvalAtRow, RelationEntry};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::Reciprocal;
use crate::examples::blake::{Fu32, STATE_SIZE};

const INV16: BaseField = BaseField::from_u32_unchecked(1 << 15);
Expand Down Expand Up @@ -189,14 +188,18 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
fn xor2(&mut self, w: u32, a: [E::F; 2], b: [E::F; 2]) -> [E::F; 2] {
// TODO: Separate lookups by w.
let c = [self.eval.next_trace_mask(), self.eval.next_trace_mask()];
let lookup_elements = self.xor_lookup_elements.get(w);
let comb0 =
lookup_elements.combine::<E::F, E::EF>(&[a[0].clone(), b[0].clone(), c[0].clone()]);
let comb1 =
lookup_elements.combine::<E::F, E::EF>(&[a[1].clone(), b[1].clone(), c[1].clone()]);

self.eval
.write_logup_frac(Reciprocal::new(comb0) + Reciprocal::new(comb1));
let xor_lookup_elements = self.xor_lookup_elements;

xor_lookup_elements.use_relation(
&mut self.eval,
w,
[
&[a[0].clone(), b[0].clone(), c[0].clone()],
&[a[1].clone(), b[1].clone(), c[1].clone()],
],
);

c
}
}
10 changes: 4 additions & 6 deletions crates/prover/src/examples/blake/round/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,10 @@ pub fn generate_interaction_trace(

#[allow(clippy::needless_range_loop)]
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
let p0: PackedSecureField = xor_lookup_elements
.get(*w0)
.combine(&l0.each_ref().map(|l| l.data[vec_row]));
let p1: PackedSecureField = xor_lookup_elements
.get(*w1)
.combine(&l1.each_ref().map(|l| l.data[vec_row]));
let p0: PackedSecureField =
xor_lookup_elements.combine(*w0, &l0.each_ref().map(|l| l.data[vec_row]));
let p1: PackedSecureField =
xor_lookup_elements.combine(*w1, &l1.each_ref().map(|l| l.data[vec_row]));
col_gen.write_frac(vec_row, p0 + p1, p0 * p1);
}

Expand Down
Loading

0 comments on commit 5a1d139

Please sign in to comment.