Skip to content

Commit

Permalink
Separated Xor table types.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Nov 19, 2024
1 parent 77de4d7 commit 89bac9f
Show file tree
Hide file tree
Showing 9 changed files with 585 additions and 414 deletions.
78 changes: 78 additions & 0 deletions \
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use itertools::Itertools;

use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::preprocessed_columns::PreprocessedColumn;
use crate::constraint_framework::{EvalAtRow, Relation, RelationEntry};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::Fraction;
use crate::examples::blake::{
XorElements12, XorElements4, XorElements7, XorElements8, XorElements9,
};

#[macro_export]
macro_rules! xor_table_eval {
($modname:tt, $elements:tt, $elem_bits:literal, $expand_bits:literal) => {
pub mod $modname {
use super::*;
use $crate::examples::blake::xor_table::$modname::limb_bits;

/// Constraints for the xor table.
pub struct XorTableEval<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> {
pub eval: E,
pub lookup_elements: &'a $elements,
pub claimed_sum: SecureField,
pub log_size: u32,
}
impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32>
XorTableEval<'a, E, ELEM_BITS, EXPAND_BITS>
{
pub fn eval(mut self) -> E {
// al, bl are the constant columns for the inputs: All pairs of elements in [0,
// 2^LIMB_BITS).
// cl is the constant column for the xor: al ^ bl.
let al = self
.eval
.get_preprocessed_column(PreprocessedColumn::XorTable(ELEM_BITS, EXPAND_BITS, 0));

let bl = self
.eval
.get_preprocessed_column(PreprocessedColumn::XorTable(ELEM_BITS, EXPAND_BITS, 1));

let cl = self
.eval
.get_preprocessed_column(PreprocessedColumn::XorTable(ELEM_BITS, EXPAND_BITS, 2));

let entry_chunks = (0..(1 << (2 * EXPAND_BITS)))
.map(|i| {
let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32);
let multiplicity = self.eval.next_trace_mask();

let a = al.clone()
+ E::F::from(BaseField::from_u32_unchecked(
i << limb_bits::<ELEM_BITS, EXPAND_BITS>(),
));
let b = bl.clone()
+ E::F::from(BaseField::from_u32_unchecked(
j << limb_bits::<ELEM_BITS, EXPAND_BITS>(),
));
let c = cl.clone()
+ E::F::from(BaseField::from_u32_unchecked(
(i ^ j) << limb_bits::<ELEM_BITS, EXPAND_BITS>(),
));

(self.lookup_elements, -multiplicity, [a, b, c])
})
.collect_vec();

for entry_chunk in entry_chunks.chunks(2) {
self.eval.add_to_relation(&[
RelationEntry::new(entry_chunk[0].0, entry_chunk[0].1.clone().into(), &entry_chunk[0].2),
RelationEntry::new(entry_chunk[1].0, entry_chunk[1].1.clone().into(), &entry_chunk[1].2),
]);
}
self.eval.finalize_logup();
self.eval
}
}
}}}
102 changes: 56 additions & 46 deletions crates/prover/src/examples/blake/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tracing::{span, Level};

use super::round::{blake_round_info, BlakeRoundComponent, BlakeRoundEval};
use super::scheduler::{BlakeSchedulerComponent, BlakeSchedulerEval};
use super::xor_table::{column_bits, XorTableComponent, XorTableEval};
use super::xor_table::{xor12, xor4, xor7, xor8, xor9};
use crate::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn};
use crate::constraint_framework::{TraceLocationAllocator, PREPROCESSED_TRACE_IDX};
use crate::core::air::{Component, ComponentProver};
Expand All @@ -30,23 +30,23 @@ const PREPROCESSED_XOR_COLUMNS: [PreprocessedColumn; 20] = [
PreprocessedColumn::XorTable(12, 4, 0),
PreprocessedColumn::XorTable(12, 4, 1),
PreprocessedColumn::XorTable(12, 4, 2),
PreprocessedColumn::IsFirst(column_bits::<12, 4>()),
PreprocessedColumn::IsFirst(xor12::column_bits::<12, 4>()),
PreprocessedColumn::XorTable(9, 2, 0),
PreprocessedColumn::XorTable(9, 2, 1),
PreprocessedColumn::XorTable(9, 2, 2),
PreprocessedColumn::IsFirst(column_bits::<9, 2>()),
PreprocessedColumn::IsFirst(xor9::column_bits::<9, 2>()),
PreprocessedColumn::XorTable(8, 2, 0),
PreprocessedColumn::XorTable(8, 2, 1),
PreprocessedColumn::XorTable(8, 2, 2),
PreprocessedColumn::IsFirst(column_bits::<8, 2>()),
PreprocessedColumn::IsFirst(xor8::column_bits::<8, 2>()),
PreprocessedColumn::XorTable(7, 2, 0),
PreprocessedColumn::XorTable(7, 2, 1),
PreprocessedColumn::XorTable(7, 2, 2),
PreprocessedColumn::IsFirst(column_bits::<7, 2>()),
PreprocessedColumn::IsFirst(xor7::column_bits::<7, 2>()),
PreprocessedColumn::XorTable(4, 0, 0),
PreprocessedColumn::XorTable(4, 0, 1),
PreprocessedColumn::XorTable(4, 0, 2),
PreprocessedColumn::IsFirst(column_bits::<4, 0>()),
PreprocessedColumn::IsFirst(xor4::column_bits::<4, 0>()),
];

#[derive(Serialize)]
Expand All @@ -70,11 +70,11 @@ impl BlakeStatement0 {
.map_cols(|_| self.log_size + l),
);
}
sizes.push(xor_table::trace_sizes::<12, 4>());
sizes.push(xor_table::trace_sizes::<9, 2>());
sizes.push(xor_table::trace_sizes::<8, 2>());
sizes.push(xor_table::trace_sizes::<7, 2>());
sizes.push(xor_table::trace_sizes::<4, 0>());
sizes.push(xor_table::xor12::trace_sizes::<12, 4>());
sizes.push(xor_table::xor9::trace_sizes::<9, 2>());
sizes.push(xor_table::xor8::trace_sizes::<8, 2>());
sizes.push(xor_table::xor7::trace_sizes::<7, 2>());
sizes.push(xor_table::xor4::trace_sizes::<4, 0>());

let mut log_sizes = TreeVec::concat_cols(sizes.into_iter());

Expand Down Expand Up @@ -154,11 +154,11 @@ pub struct BlakeProof<H: MerkleHasher> {
pub struct BlakeComponents {
scheduler_component: BlakeSchedulerComponent,
round_components: Vec<BlakeRoundComponent>,
xor12: XorTableComponent<12, 4>,
xor9: XorTableComponent<9, 2>,
xor8: XorTableComponent<8, 2>,
xor7: XorTableComponent<7, 2>,
xor4: XorTableComponent<4, 0>,
xor12: xor12::XorTableComponent<12, 4>,
xor9: xor9::XorTableComponent<9, 2>,
xor8: xor8::XorTableComponent<8, 2>,
xor7: xor7::XorTableComponent<7, 2>,
xor4: xor4::XorTableComponent<4, 0>,
}
impl BlakeComponents {
fn new(stmt0: &BlakeStatement0, all_elements: &AllElements, stmt1: &BlakeStatement1) -> Self {
Expand Down Expand Up @@ -205,41 +205,41 @@ impl BlakeComponents {
)
})
.collect(),
xor12: XorTableComponent::new(
xor12: xor12::XorTableComponent::new(
tree_span_provider,
XorTableEval {
xor12::XorTableEval {
lookup_elements: all_elements.xor_elements.xor12.clone(),
claimed_sum: stmt1.xor12_claimed_sum,
},
(stmt1.xor12_claimed_sum, None),
),
xor9: XorTableComponent::new(
xor9: xor9::XorTableComponent::new(
tree_span_provider,
XorTableEval {
xor9::XorTableEval {
lookup_elements: all_elements.xor_elements.xor9.clone(),
claimed_sum: stmt1.xor9_claimed_sum,
},
(stmt1.xor9_claimed_sum, None),
),
xor8: XorTableComponent::new(
xor8: xor8::XorTableComponent::new(
tree_span_provider,
XorTableEval {
xor8::XorTableEval {
lookup_elements: all_elements.xor_elements.xor8.clone(),
claimed_sum: stmt1.xor8_claimed_sum,
},
(stmt1.xor8_claimed_sum, None),
),
xor7: XorTableComponent::new(
xor7: xor7::XorTableComponent::new(
tree_span_provider,
XorTableEval {
xor7::XorTableEval {
lookup_elements: all_elements.xor_elements.xor7.clone(),
claimed_sum: stmt1.xor7_claimed_sum,
},
(stmt1.xor7_claimed_sum, None),
),
xor4: XorTableComponent::new(
xor4: xor4::XorTableComponent::new(
tree_span_provider,
XorTableEval {
xor4::XorTableEval {
lookup_elements: all_elements.xor_elements.xor4.clone(),
claimed_sum: stmt1.xor4_claimed_sum,
},
Expand Down Expand Up @@ -324,11 +324,11 @@ where
chain![
vec![gen_is_first(log_size)],
ROUND_LOG_SPLIT.iter().map(|l| gen_is_first(log_size + l)),
xor_table::generate_constant_trace::<12, 4>(),
xor_table::generate_constant_trace::<9, 2>(),
xor_table::generate_constant_trace::<8, 2>(),
xor_table::generate_constant_trace::<7, 2>(),
xor_table::generate_constant_trace::<4, 0>(),
xor_table::xor12::generate_constant_trace::<12, 4>(),
xor_table::xor9::generate_constant_trace::<9, 2>(),
xor_table::xor8::generate_constant_trace::<8, 2>(),
xor_table::xor7::generate_constant_trace::<7, 2>(),
xor_table::xor4::generate_constant_trace::<4, 0>(),
]
.collect_vec(),
);
Expand All @@ -353,11 +353,11 @@ where
}));

// Xor tables.
let (xor_trace12, xor_lookup_data12) = xor_table::generate_trace(xor_accums.xor12);
let (xor_trace9, xor_lookup_data9) = xor_table::generate_trace(xor_accums.xor9);
let (xor_trace8, xor_lookup_data8) = xor_table::generate_trace(xor_accums.xor8);
let (xor_trace7, xor_lookup_data7) = xor_table::generate_trace(xor_accums.xor7);
let (xor_trace4, xor_lookup_data4) = xor_table::generate_trace(xor_accums.xor4);
let (xor_trace12, xor_lookup_data12) = xor_table::xor12::generate_trace(xor_accums.xor12);
let (xor_trace9, xor_lookup_data9) = xor_table::xor9::generate_trace(xor_accums.xor9);
let (xor_trace8, xor_lookup_data8) = xor_table::xor8::generate_trace(xor_accums.xor8);
let (xor_trace7, xor_lookup_data7) = xor_table::xor7::generate_trace(xor_accums.xor7);
let (xor_trace4, xor_lookup_data4) = xor_table::xor4::generate_trace(xor_accums.xor4);

// Statement0.
let stmt0 = BlakeStatement0 { log_size };
Expand Down Expand Up @@ -406,16 +406,26 @@ where
}),
);

let (xor_trace12, xor12_claimed_sum) =
xor_table::generate_interaction_trace(xor_lookup_data12, &all_elements.xor_elements.xor12);
let (xor_trace9, xor9_claimed_sum) =
xor_table::generate_interaction_trace(xor_lookup_data9, &all_elements.xor_elements.xor9);
let (xor_trace8, xor8_claimed_sum) =
xor_table::generate_interaction_trace(xor_lookup_data8, &all_elements.xor_elements.xor8);
let (xor_trace7, xor7_claimed_sum) =
xor_table::generate_interaction_trace(xor_lookup_data7, &all_elements.xor_elements.xor7);
let (xor_trace4, xor4_claimed_sum) =
xor_table::generate_interaction_trace(xor_lookup_data4, &all_elements.xor_elements.xor4);
let (xor_trace12, xor12_claimed_sum) = xor_table::xor12::generate_interaction_trace(
xor_lookup_data12,
&all_elements.xor_elements.xor12,
);
let (xor_trace9, xor9_claimed_sum) = xor_table::xor9::generate_interaction_trace(
xor_lookup_data9,
&all_elements.xor_elements.xor9,
);
let (xor_trace8, xor8_claimed_sum) = xor_table::xor8::generate_interaction_trace(
xor_lookup_data8,
&all_elements.xor_elements.xor8,
);
let (xor_trace7, xor7_claimed_sum) = xor_table::xor7::generate_interaction_trace(
xor_lookup_data7,
&all_elements.xor_elements.xor7,
);
let (xor_trace4, xor4_claimed_sum) = xor_table::xor4::generate_interaction_trace(
xor_lookup_data4,
&all_elements.xor_elements.xor4,
);

let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(
Expand Down
Loading

0 comments on commit 89bac9f

Please sign in to comment.