Skip to content

Commit

Permalink
migrate to next release of winterfell
Browse files Browse the repository at this point in the history
Signed-off-by: GopherJ <[email protected]>
  • Loading branch information
GopherJ committed Nov 19, 2024
1 parent dee76c6 commit fe6397b
Showing 1 changed file with 6 additions and 60 deletions.
66 changes: 6 additions & 60 deletions prover/src/gpu/webgpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ where
}
}

fn build_aligned_segement<E, const N: usize>(
fn build_aligned_segment<E, const N: usize>(
polys: &ColMatrix<E>,
poly_offset: usize,
offsets: &[Felt],
Expand Down Expand Up @@ -119,7 +119,7 @@ where
Segment::new_with_buffer(data, polys, poly_offset, offsets, twiddles)
}

fn build_aligned_segements<E, const N: usize>(
fn build_aligned_segments<E, const N: usize>(
polys: &ColMatrix<E>,
twiddles: &[Felt],
offsets: &[Felt],
Expand All @@ -138,7 +138,7 @@ where
};

(0..num_segments)
.map(|i| Self::build_aligned_segement(polys, i * N, offsets, twiddles))
.map(|i| Self::build_aligned_segment(polys, i * N, offsets, twiddles))
.collect()
}
}
Expand Down Expand Up @@ -241,7 +241,7 @@ where
let blowup = domain.trace_to_lde_blowup();
let offsets =
get_evaluation_offsets::<E>(composition_poly.column_len(), blowup, domain.offset());
let segments = Self::build_aligned_segements(
let segments = Self::build_aligned_segments(
composition_poly.data(),
domain.trace_twiddles(),
&offsets,
Expand All @@ -262,36 +262,9 @@ where
let lde_domain_size = domain.lde_domain_size();
let num_base_columns =
composition_poly.num_columns() * <E as FieldElement>::EXTENSION_DEGREE;
let rpo_requires_padding = num_base_columns % RATE != 0;
let is_rpo = self.webgpu_hash_fn == HashFn::Rpo256;
let rpo_padded_segment_idx = rpo_requires_padding.then_some(num_base_columns / RATE);
let mut row_hasher =
RowHasher::<H>::new(helper, lde_domain_size, num_base_columns, self.webgpu_hash_fn);
let rpo_padded_segment: Vec<[Felt; RATE]>;
for (segment_idx, segment) in segments.iter().enumerate() {
// check if the segment requires padding
if rpo_padded_segment_idx.map_or(false, |pad_idx| pad_idx == segment_idx) {
// duplicate and modify the last segment with Rpo256's padding
// rule ("1" followed by "0"s). Our segments are already
// padded with "0"s we only need to add the "1"s.
let rpo_pad_column = num_base_columns % RATE;

rpo_padded_segment = if is_rpo {
segment
.iter()
.map(|x| {
let mut s = *x;
s[rpo_pad_column] = ONE;
s
})
.collect()
} else {
segment.iter().map(|x| *x).collect()
};
row_hasher.update(&rpo_padded_segment);
assert_eq!(segments.len() - 1, segment_idx, "padded segment should be the last");
break;
}
for segment in segments.iter() {
row_hasher.update(segment);
}
let row_hashes = maybe_await!(row_hasher.finish());
Expand Down Expand Up @@ -636,37 +609,10 @@ async fn build_trace_commitment<
let lde_segments = FrozenVec::new();
let lde_domain_size = domain.lde_domain_size();
let num_base_columns = trace.num_base_cols();
let rpo_requires_padding = num_base_columns % RATE != 0;
let is_rpo = hash_fn == HashFn::Rpo256;
let rpo_padded_segment_idx = rpo_requires_padding.then_some(num_base_columns / RATE);
let mut row_hasher = RowHasher::<H>::new(&helper, lde_domain_size, num_base_columns, hash_fn);
let rpo_padded_segment: Vec<[Felt; RATE]>;
let mut lde_segment_generator = SegmentGenerator::new(trace_polys, domain);
let mut lde_segment_iter = lde_segment_generator.gen_segment_iter().enumerate();
for (segment_idx, segment) in &mut lde_segment_iter {
for segment in lde_segment_generator.gen_segment_iter() {
let segment = lde_segments.push_get(Box::new(segment));
// check if the segment requires padding
if rpo_padded_segment_idx.map_or(false, |pad_idx| pad_idx == segment_idx) {
// duplicate and modify the last segment with Rpo256's padding
// rule ("1" followed by "0"s). Our segments are already
// padded with "0"s we only need to add the "1"s.
let rpo_pad_column = num_base_columns % RATE;
rpo_padded_segment = if is_rpo {
segment
.iter()
.map(|x| {
let mut s = *x;
s[rpo_pad_column] = ONE;
s
})
.collect()
} else {
segment.iter().map(|x| *x).collect()
};
row_hasher.update(&rpo_padded_segment);
assert!(lde_segment_iter.next().is_none(), "padded segment should be the last");
break;
}
row_hasher.update(segment);
}
let row_hashes = row_hasher.finish().await;
Expand Down

0 comments on commit fe6397b

Please sign in to comment.