From 7ca5813e69750419944466b2daf1004db6f48c52 Mon Sep 17 00:00:00 2001
From: Ben Savage
Date: Tue, 10 Oct 2023 15:34:56 +0800
Subject: [PATCH] comments from Alex

---
 Cargo.toml                                    |  1 -
 .../prf_sharding/feature_label_dot_product.rs | 24 +++++++++++--------
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index a72bfcfbf..e264d503f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -95,7 +95,6 @@ tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
 typenum = "1.16"
 # hpke is pinned to it
 x25519-dalek = "2.0.0-pre.0"
-stream-flatten-iters = "0.2.0"
 
 [target.'cfg(not(target_env = "msvc"))'.dependencies]
 tikv-jemallocator = "0.5.0"
diff --git a/src/protocol/prf_sharding/feature_label_dot_product.rs b/src/protocol/prf_sharding/feature_label_dot_product.rs
index a326c02ef..0dfc28f3a 100644
--- a/src/protocol/prf_sharding/feature_label_dot_product.rs
+++ b/src/protocol/prf_sharding/feature_label_dot_product.rs
@@ -3,7 +3,6 @@ use std::{iter::zip, pin::pin};
 use futures::{stream::iter as stream_iter, TryStreamExt};
 use futures_util::{future::try_join, stream::unfold, Stream, StreamExt};
 use ipa_macros::Step;
-use stream_flatten_iters::StreamExt as _;
 
 use crate::{
     error::Error,
@@ -171,8 +170,7 @@ where
     IS: Stream<Item = PrfShardedIpaInputRow<FV>> + Unpin,
 {
     unfold(Some((input_stream, first_row)), |state| async move {
-        state.as_ref()?;
-        let (mut s, last_row) = state.unwrap();
+        let (mut s, last_row) = state?;
         let last_row_prf = last_row.prf_of_match_key;
         let mut current_chunk = vec![last_row];
         while let Some(row) = s.next().await {
@@ -221,7 +219,6 @@ where
 /// Propagates errors from multiplications
 /// # Panics
 /// Propagates errors from multiplications
-#[allow(clippy::async_yields_async)]
 pub async fn compute_feature_label_dot_product<C, FV, F, S>(
     sh_ctx: C,
     input_rows: Vec<PrfShardedIpaInputRow<FV>>,
@@ -252,7 +249,11 @@ where
 
     // Chunk the incoming stream of records into stream of vectors of records with the same PRF
     let mut input_stream = stream_iter(input_rows);
-    let first_row = input_stream.next().await.unwrap();
+    let first_row = input_stream.next().await;
+    if first_row.is_none() {
+        return Ok(vec![]);
+    }
+    let first_row = first_row.unwrap();
     let rows_chunked_by_user = chunk_rows_by_user(input_stream, first_row);
 
     // Convert to a stream of async futures that represent the result of executing the per-user circuit
@@ -261,16 +262,19 @@ where
         let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned();
         let record_ids = record_id_for_row_depth[..num_user_rows].to_owned();
 
-        for count in record_id_for_row_depth.iter_mut().take(rows_for_user.len()) {
+        for count in &mut record_id_for_row_depth[..num_user_rows] {
             *count += 1;
         }
-        async move { evaluate_per_user_attribution_circuit(contexts, record_ids, rows_for_user) }
+        #[allow(clippy::async_yields_async)]
+        // this is ok, because seq join wants a stream of futures
+        async move {
+            evaluate_per_user_attribution_circuit(contexts, record_ids, rows_for_user)
+        }
     }));
 
     // Execute all of the async futures (sequentially), and flatten the result
     let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits)
-        .map(|x| x.unwrap().into_iter())
-        .flatten_iters();
+        .flat_map(|x| stream_iter(x.unwrap()));
 
     // modulus convert feature vector bits from shares in `Z_2` to shares in `Z_p`
     let converted_feature_vector_bits = convert_bits(
@@ -317,7 +321,7 @@ where
     for (i, (row, ctx)) in
         zip(rows_for_user.iter().skip(1), ctx_for_row_number.into_iter()).enumerate()
     {
-        let record_id_for_this_row_depth = RecordId(record_id_for_each_depth[i + 1]); // skip row 0
+        let record_id_for_this_row_depth = RecordId::from(record_id_for_each_depth[i + 1]); // skip row 0
         let capped_attribution_outputs = prev_row_inputs
             .compute_row_with_previous(ctx, record_id_for_this_row_depth, row)
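
A note on the flattening change above: dropping the `stream-flatten-iters` crate works because the `futures` crate's own `StreamExt::flat_map` can splice each `Vec` of per-user results into the output stream. A minimal, self-contained sketch of the pattern (the names and the `tokio` runtime are illustrative, not taken from the ipa codebase):

    use futures::stream::{self, StreamExt};

    #[tokio::main]
    async fn main() {
        // A stream of per-user output batches, standing in for the results
        // of the per-user circuits after `seq_join`.
        let per_user_outputs = stream::iter(vec![vec![1, 2], vec![3], vec![4, 5, 6]]);

        // `flat_map` turns each `Vec` into a sub-stream and chains them in
        // order, the same flattening that `.flatten_iters()` provided.
        let flattened: Vec<i32> = per_user_outputs
            .flat_map(stream::iter)
            .collect()
            .await;

        assert_eq!(flattened, vec![1, 2, 3, 4, 5, 6]);
    }

Ordering is preserved: each batch's sub-stream is drained fully before the next one starts, so this is a drop-in replacement for `.map(...).flatten_iters()`.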
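The `let (mut s, last_row) = state?;` simplification is also worth a gloss: the closure passed to `unfold` returns an `Option`, so `?` on a `None` state terminates the stream, replacing the two-step `state.as_ref()?; ... state.unwrap()` dance. A toy version of the same chunk-consecutive-rows shape (illustrative names and types, not the actual `chunk_rows_by_user` signature):

    use futures::stream::{self, Stream, StreamExt};

    // Group consecutive equal items into `Vec` chunks. `Some(state)` keeps
    // the unfold alive; `None` (reached via `?`) ends the stream.
    fn chunk_consecutive<S>(input: S, first: u32) -> impl Stream<Item = Vec<u32>>
    where
        S: Stream<Item = u32> + Unpin,
    {
        stream::unfold(Some((input, first)), |state| async move {
            let (mut s, last) = state?; // stream ends once state is None
            let mut chunk = vec![last];
            while let Some(x) = s.next().await {
                if x == *chunk.last().unwrap() {
                    chunk.push(x);
                } else {
                    // Emit the finished chunk; `x` seeds the next one.
                    return Some((chunk, Some((s, x))));
                }
            }
            // Input exhausted: emit the final chunk, then stop.
            Some((chunk, None))
        })
    }

    #[tokio::main]
    async fn main() {
        let mut input = stream::iter(vec![7, 7, 8, 9, 9]);
        let first = input.next().await.unwrap(); // mirrors the patched caller
        let chunks: Vec<Vec<u32>> = chunk_consecutive(input, first).collect().await;
        assert_eq!(chunks, vec![vec![7, 7], vec![8], vec![9, 9]]);
    }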