From fe6397bc12fa65729db2d3a2d3c8984a9f2ea6ab Mon Sep 17 00:00:00 2001 From: GopherJ Date: Tue, 19 Nov 2024 15:52:18 +0800 Subject: [PATCH] migrate to next release of winterfell Signed-off-by: GopherJ --- prover/src/gpu/webgpu/mod.rs | 66 ++++-------------------------------- 1 file changed, 6 insertions(+), 60 deletions(-) diff --git a/prover/src/gpu/webgpu/mod.rs b/prover/src/gpu/webgpu/mod.rs index d74a048b6..4e9e34b25 100644 --- a/prover/src/gpu/webgpu/mod.rs +++ b/prover/src/gpu/webgpu/mod.rs @@ -89,7 +89,7 @@ where } } - fn build_aligned_segement( + fn build_aligned_segment( polys: &ColMatrix, poly_offset: usize, offsets: &[Felt], @@ -119,7 +119,7 @@ where Segment::new_with_buffer(data, polys, poly_offset, offsets, twiddles) } - fn build_aligned_segements( + fn build_aligned_segments( polys: &ColMatrix, twiddles: &[Felt], offsets: &[Felt], @@ -138,7 +138,7 @@ where }; (0..num_segments) - .map(|i| Self::build_aligned_segement(polys, i * N, offsets, twiddles)) + .map(|i| Self::build_aligned_segment(polys, i * N, offsets, twiddles)) .collect() } } @@ -241,7 +241,7 @@ where let blowup = domain.trace_to_lde_blowup(); let offsets = get_evaluation_offsets::(composition_poly.column_len(), blowup, domain.offset()); - let segments = Self::build_aligned_segements( + let segments = Self::build_aligned_segments( composition_poly.data(), domain.trace_twiddles(), &offsets, @@ -262,36 +262,9 @@ where let lde_domain_size = domain.lde_domain_size(); let num_base_columns = composition_poly.num_columns() * ::EXTENSION_DEGREE; - let rpo_requires_padding = num_base_columns % RATE != 0; - let is_rpo = self.webgpu_hash_fn == HashFn::Rpo256; - let rpo_padded_segment_idx = rpo_requires_padding.then_some(num_base_columns / RATE); let mut row_hasher = RowHasher::::new(helper, lde_domain_size, num_base_columns, self.webgpu_hash_fn); - let rpo_padded_segment: Vec<[Felt; RATE]>; - for (segment_idx, segment) in segments.iter().enumerate() { - // check if the segment requires padding - if rpo_padded_segment_idx.map_or(false, |pad_idx| pad_idx == segment_idx) { - // duplicate and modify the last segment with Rpo256's padding - // rule ("1" followed by "0"s). Our segments are already - // padded with "0"s we only need to add the "1"s. - let rpo_pad_column = num_base_columns % RATE; - - rpo_padded_segment = if is_rpo { - segment - .iter() - .map(|x| { - let mut s = *x; - s[rpo_pad_column] = ONE; - s - }) - .collect() - } else { - segment.iter().map(|x| *x).collect() - }; - row_hasher.update(&rpo_padded_segment); - assert_eq!(segments.len() - 1, segment_idx, "padded segment should be the last"); - break; - } + for segment in segments.iter() { row_hasher.update(segment); } let row_hashes = maybe_await!(row_hasher.finish()); @@ -636,37 +609,10 @@ async fn build_trace_commitment< let lde_segments = FrozenVec::new(); let lde_domain_size = domain.lde_domain_size(); let num_base_columns = trace.num_base_cols(); - let rpo_requires_padding = num_base_columns % RATE != 0; - let is_rpo = hash_fn == HashFn::Rpo256; - let rpo_padded_segment_idx = rpo_requires_padding.then_some(num_base_columns / RATE); let mut row_hasher = RowHasher::::new(&helper, lde_domain_size, num_base_columns, hash_fn); - let rpo_padded_segment: Vec<[Felt; RATE]>; let mut lde_segment_generator = SegmentGenerator::new(trace_polys, domain); - let mut lde_segment_iter = lde_segment_generator.gen_segment_iter().enumerate(); - for (segment_idx, segment) in &mut lde_segment_iter { + for segment in lde_segment_generator.gen_segment_iter() { let segment = lde_segments.push_get(Box::new(segment)); - // check if the segment requires padding - if rpo_padded_segment_idx.map_or(false, |pad_idx| pad_idx == segment_idx) { - // duplicate and modify the last segment with Rpo256's padding - // rule ("1" followed by "0"s). Our segments are already - // padded with "0"s we only need to add the "1"s. - let rpo_pad_column = num_base_columns % RATE; - rpo_padded_segment = if is_rpo { - segment - .iter() - .map(|x| { - let mut s = *x; - s[rpo_pad_column] = ONE; - s - }) - .collect() - } else { - segment.iter().map(|x| *x).collect() - }; - row_hasher.update(&rpo_padded_segment); - assert!(lde_segment_iter.next().is_none(), "padded segment should be the last"); - break; - } row_hasher.update(segment); } let row_hashes = row_hasher.finish().await;