Skip to content

Commit

Permalink
fix: update metal codebase to adjust it to 0.9 version of crates (#1357)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheMenko authored Jun 20, 2024
1 parent 978c142 commit 0c19d50
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 42 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Added error codes support for the `mtree_verify` instruction (#1328).
- Added support for immediate values for `lt`, `lte`, `gt`, `gte` comparison instructions (#1346).
- Changed MAST to a table-based representation (#1349).
- Adjusted the prover's Metal acceleration code to work with the 0.9 versions of the crates (#1357).

## 0.9.2 (2024-05-22) - `stdlib` crate only
- Skip writing MASM documentation to file when building on docs.rs (#1341).
Expand Down
2 changes: 1 addition & 1 deletion air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub use vm_core::{
utils::{DeserializationError, ToElements},
Felt, FieldElement, StarkField,
};
pub use winter_air::{AuxRandElements, FieldExtension};
pub use winter_air::{AuxRandElements, FieldExtension, LagrangeKernelEvaluationFrame};

// PROCESSOR AIR
// ================================================================================================
Expand Down
4 changes: 3 additions & 1 deletion miden/src/examples/fibonacci.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use super::{Example, ONE, ZERO};
use miden_vm::{math::Felt, Assembler, DefaultHost, MemAdviceProvider, Program, ProvingOptions, StackInputs};
use miden_vm::{
math::Felt, Assembler, DefaultHost, MemAdviceProvider, Program, ProvingOptions, StackInputs,

Check warning on line 3 in miden/src/examples/fibonacci.rs

View workflow job for this annotation

GitHub Actions / Check Rust nightly on ubuntu with --all-targets --all-features

unused import: `ProvingOptions`

Check warning on line 3 in miden/src/examples/fibonacci.rs

View workflow job for this annotation

GitHub Actions / Check Rust stable on ubuntu with --all-targets --all-features

unused import: `ProvingOptions`
};

// EXAMPLE BUILDER
// ================================================================================================
Expand Down
1 change: 0 additions & 1 deletion miden/src/examples/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ where
}
}


#[cfg(test)]
pub fn test_example<H>(example: Example<H>, fail: bool)
where
Expand Down
2 changes: 1 addition & 1 deletion prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ winter-prover = { package = "winter-prover", version = "0.9", default-features =

[target.'cfg(all(target_arch = "aarch64", target_os = "macos"))'.dependencies]
elsa = { version = "1.9", optional = true }
miden-gpu = { version = "0.1", optional = true }
miden-gpu = { version = "0.2", optional = true }
pollster = { version = "0.3", optional = true }
97 changes: 64 additions & 33 deletions prover/src/gpu/metal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,26 @@ use crate::{
WinterProofOptions,
};

use air::{AuxRandElements, LagrangeKernelEvaluationFrame};
use elsa::FrozenVec;
use miden_gpu::{
metal::{build_merkle_tree, utils::page_aligned_uninit_vector, RowHasher},
HashFn,
};
use pollster::block_on;
use processor::{utils::group_vector_elements, ONE};
use std::time::Instant;
use processor::{
crypto::{ElementHasher, Hasher},
ONE,
};
use std::{boxed::Box, marker::PhantomData, time::Instant, vec::Vec};
use tracing::{event, Level};
use winter_prover::{
crypto::{Digest, MerkleTree},
matrix::{get_evaluation_offsets, ColMatrix, RowMatrix, Segment},
proof::Queries,
AuxTraceRandElements, CompositionPoly, CompositionPolyTrace, ConstraintCommitment,
ConstraintCompositionCoefficients, DefaultConstraintEvaluator, EvaluationFrame, Prover,
StarkDomain, TraceInfo, TraceLayout, TraceLde, TracePolyTable,
CompositionPoly, CompositionPolyTrace, ConstraintCommitment, ConstraintCompositionCoefficients,
DefaultConstraintEvaluator, EvaluationFrame, Prover, StarkDomain, TraceInfo, TraceLde,
TracePolyTable,
};

#[cfg(test)]
Expand All @@ -44,7 +49,7 @@ pub(crate) struct MetalExecutionProver<H, D, R>
where
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField>,
D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>,
R: RandomCoin<BaseField = Felt, Hasher = H>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
{
pub execution_prover: ExecutionProver<H, R>,
pub metal_hash_fn: HashFn,
Expand All @@ -55,7 +60,7 @@ impl<H, D, R> MetalExecutionProver<H, D, R>
where
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField>,
D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>,
R: RandomCoin<BaseField = Felt, Hasher = H>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
{
pub fn new(execution_prover: ExecutionProver<H, R>, hash_fn: HashFn) -> Self {
MetalExecutionProver {
Expand Down Expand Up @@ -88,7 +93,7 @@ where
} else {
// but if some columns in the segment will remain unfilled, we allocate memory initialized
// to zeros to make sure we don't end up with memory with undefined values
group_vector_elements(Felt::zeroed_vector(N * domain_size))
vec![[E::BaseField::ZERO; N]; domain_size]
};

Segment::new_with_buffer(data, polys, poly_offset, offsets, twiddles)
Expand Down Expand Up @@ -122,7 +127,7 @@ impl<H, D, R> Prover for MetalExecutionProver<H, D, R>
where
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField>,
D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>,
R: RandomCoin<BaseField = Felt, Hasher = H>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
{
type BaseField = Felt;
type Air = ProcessorAir;
Expand Down Expand Up @@ -153,7 +158,7 @@ where
fn new_evaluator<'a, E: FieldElement<BaseField = Felt>>(
&self,
air: &'a ProcessorAir,
aux_rand_elements: AuxTraceRandElements<E>,
aux_rand_elements: Option<AuxRandElements<E>>,
composition_coefficients: ConstraintCompositionCoefficients<E>,
) -> Self::ConstraintEvaluator<'a, E> {
self.execution_prover
Expand Down Expand Up @@ -197,7 +202,11 @@ where
let blowup = domain.trace_to_lde_blowup();
let offsets =
get_evaluation_offsets::<E>(composition_poly.column_len(), blowup, domain.offset());
let segments = Self::build_aligned_segements(composition_poly.data(), domain.trace_twiddles(), &offsets);
let segments = Self::build_aligned_segements(
composition_poly.data(),
domain.trace_twiddles(),
&offsets,
);
event!(
Level::INFO,
"Evaluated {} composition polynomial columns over LDE domain (2^{} elements) in {} ms",
Expand Down Expand Up @@ -269,9 +278,9 @@ pub struct MetalTraceLde<E: FieldElement<BaseField = Felt>, H: Hasher> {
// commitment to the main segment of the trace
main_segment_tree: MerkleTree<H>,
// low-degree extensions of the auxiliary segments of the trace
aux_segment_ldes: Vec<RowMatrix<E>>,
aux_segment_lde: Option<RowMatrix<E>>,
// commitment to the auxiliary segments of the trace
aux_segment_trees: Vec<MerkleTree<H>>,
aux_segment_tree: Option<MerkleTree<H>>,
blowup: usize,
trace_info: TraceInfo,
metal_hash_fn: HashFn,
Expand Down Expand Up @@ -304,8 +313,8 @@ impl<
let trace_lde = MetalTraceLde {
main_segment_lde,
main_segment_tree,
aux_segment_ldes: Vec::new(),
aux_segment_trees: Vec::new(),
aux_segment_lde: None,
aux_segment_tree: None,
blowup: domain.trace_to_lde_blowup(),
trace_info: trace_info.clone(),
metal_hash_fn,
Expand Down Expand Up @@ -367,7 +376,7 @@ impl<
/// This function will panic if any of the following are true:
/// - the number of rows in the provided `aux_trace` does not match the main trace.
/// - this segment would exceed the number of segments specified by the trace layout.
fn add_aux_segment(
fn set_aux_trace(
&mut self,
aux_trace: &ColMatrix<E>,
domain: &StarkDomain<Felt>,
Expand All @@ -376,21 +385,16 @@ impl<
let (aux_segment_lde, aux_segment_tree, aux_segment_polys) =
build_trace_commitment::<E, H, D>(aux_trace, domain, self.metal_hash_fn);

// check errors
assert!(
self.aux_segment_ldes.len() < self.trace_info.layout().num_aux_segments(),
"the specified number of auxiliary segments has already been added"
);
assert_eq!(
self.main_segment_lde.num_rows(),
aux_segment_lde.num_rows(),
"the number of rows in the auxiliary segment must be the same as in the main segment"
);

// save the lde and commitment
self.aux_segment_ldes.push(aux_segment_lde);
self.aux_segment_lde = Some(aux_segment_lde);
let root_hash = *aux_segment_tree.root();
self.aux_segment_trees.push(aux_segment_tree);
self.aux_segment_tree = Some(aux_segment_tree);

(aux_segment_polys, root_hash)
}
Expand All @@ -415,9 +419,10 @@ impl<
let next_lde_step = (lde_step + self.blowup()) % self.trace_len();

// copy auxiliary trace segment values into the frame
let segment = &self.aux_segment_ldes[0];
frame.current_mut().copy_from_slice(segment.row(lde_step));
frame.next_mut().copy_from_slice(segment.row(next_lde_step));
self.aux_segment_lde.as_ref().map(|mat| {
frame.current_mut().copy_from_slice(mat.row(lde_step));
frame.next_mut().copy_from_slice(mat.row(next_lde_step));
});
}

/// Returns trace table rows at the specified positions along with Merkle authentication paths
Expand All @@ -430,10 +435,10 @@ impl<
positions,
)];

// build queries for auxiliary trace segments
for (i, segment_tree) in self.aux_segment_trees.iter().enumerate() {
let segment_lde = &self.aux_segment_ldes[i];
result.push(build_segment_queries(segment_lde, segment_tree, positions));
if let (Some(aux_segment_lde), Some(aux_segment_tree)) =
(&self.aux_segment_lde, &self.aux_segment_tree)
{
result.push(build_segment_queries(aux_segment_lde, aux_segment_tree, positions));
}

result
Expand All @@ -449,9 +454,35 @@ impl<
self.blowup
}

/// Returns the trace layout of the execution trace.
///
/// This is a thin accessor: it forwards to `layout()` on the `TraceInfo` stored in
/// `self.trace_info` and returns the resulting reference.
fn trace_layout(&self) -> &TraceLayout {
self.trace_info.layout()
/// Populates the provided Lagrange kernel frame starting at the current row (as defined by
/// `lde_step`).
///
/// Note that unlike `EvaluationFrame`, the Lagrange kernel frame includes only the Lagrange
/// kernel column (as opposed to all columns). The column is read from the auxiliary segment
/// LDE at column index `col_idx`.
///
/// The frame is first cleared and then filled with `log2(trace_info.length()) + 1` values:
/// the value at row `lde_step`, followed by the values at rows `lde_step + blowup * 2^i` for
/// `i = 0, 1, ..., log2(trace_info.length()) - 1`, where every shifted row index is reduced
/// modulo `self.trace_len()`.
///
/// If no auxiliary segment has been set, the frame is left unchanged.
fn read_lagrange_kernel_frame_into(
    &self,
    lde_step: usize,
    col_idx: usize,
    frame: &mut LagrangeKernelEvaluationFrame<E>,
) {
    // `if let` instead of `Option::map`: the body runs only for its side effects on `frame`,
    // so there is no value to map to.
    if let Some(aux_segment) = self.aux_segment_lde.as_ref() {
        let frame = frame.frame_mut();
        frame.clear();

        // the first entry is the value at the current row
        frame.push(aux_segment.get(col_idx, lde_step));

        // the remaining entries are the values at rows shifted by blowup * 2^i
        let num_shifts = self.trace_info.length().ilog2() as usize;
        for i in 0..num_shifts {
            let shift = self.blowup() * (1 << i);
            let next_lde_step = (lde_step + shift) % self.trace_len();

            frame.push(aux_segment.get(col_idx, next_lde_step));
        }
    }
}

/// Returns a reference to the [`TraceInfo`] stored in this trace LDE.
fn trace_info(&self) -> &TraceInfo {
&self.trace_info
}
}

Expand Down
10 changes: 5 additions & 5 deletions prover/src/gpu/metal/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use winter_prover::{crypto::Digest, math::fields::CubeExtension, CompositionPoly
type CubeFelt = CubeExtension<Felt>;

fn build_trace_commitment_on_gpu_with_padding_matches_cpu<
R: RandomCoin<BaseField = Felt, Hasher = H>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
H: ElementHasher<BaseField = Felt> + Hasher<Digest = D>,
D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>,
>(
Expand Down Expand Up @@ -43,7 +43,7 @@ fn build_trace_commitment_on_gpu_with_padding_matches_cpu<
}

fn build_trace_commitment_on_gpu_without_padding_matches_cpu<
R: RandomCoin<BaseField = Felt, Hasher = H>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
H: ElementHasher<BaseField = Felt> + Hasher<Digest = D>,
D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>,
>(
Expand Down Expand Up @@ -74,7 +74,7 @@ fn build_trace_commitment_on_gpu_without_padding_matches_cpu<
}

fn build_constraint_commitment_on_gpu_with_padding_matches_cpu<
R: RandomCoin<BaseField = Felt, Hasher = H>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
H: ElementHasher<BaseField = Felt> + Hasher<Digest = D>,
D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>,
>(
Expand Down Expand Up @@ -103,7 +103,7 @@ fn build_constraint_commitment_on_gpu_with_padding_matches_cpu<
}

fn build_constraint_commitment_on_gpu_without_padding_matches_cpu<
R: RandomCoin<BaseField = Felt, Hasher = H>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
H: ElementHasher<BaseField = Felt> + Hasher<Digest = D>,
D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>,
>(
Expand Down Expand Up @@ -204,7 +204,7 @@ fn get_trace_info(num_cols: usize, num_rows: usize) -> TraceInfo {
}

fn create_test_prover<
R: RandomCoin<BaseField = Felt, Hasher = H>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
H: ElementHasher<BaseField = Felt>,
>(
use_rpx: bool,
Expand Down

0 comments on commit 0c19d50

Please sign in to comment.