From ddbf2b1f8be18210c045b6862e2e348aad8e705d Mon Sep 17 00:00:00 2001
From: 0xKitsune <0xKitsune@protonmail.com>
Date: Wed, 31 Jan 2024 10:04:18 -0500
Subject: [PATCH] updated batch process shares

---
 src/coordinator.rs | 175 +++++++++++++++++++--------------------
 1 file changed, 74 insertions(+), 101 deletions(-)

diff --git a/src/coordinator.rs b/src/coordinator.rs
index 01ee038..e5cce81 100644
--- a/src/coordinator.rs
+++ b/src/coordinator.rs
@@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
 use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
 use tokio::net::TcpStream;
 use tokio::sync::mpsc::Receiver;
-use tokio::sync::{self, mpsc};
+use tokio::sync::{self, mpsc, Mutex};
 use tokio::task::JoinHandle;
 
 use crate::bits::Bits;
@@ -27,7 +27,7 @@ pub struct Coordinator {
     aws_client: aws_sdk_sqs::Client,
     shares_queue_url: String,
     distances_queue_url: String,
-    participants: Vec<BufReader<TcpStream>>,
+    participants: Arc<Mutex<Vec<BufReader<TcpStream>>>>,
 }
 
 impl Coordinator {
@@ -54,7 +54,7 @@ impl Coordinator {
             aws_client,
             shares_queue_url: shares_queue_url.to_string(),
             distances_queue_url: distances_queue_url.to_string(),
-            participants: streams,
+            participants: Arc::new(Mutex::new(streams)),
         })
     }
 
@@ -80,10 +80,10 @@ impl Coordinator {
         handles.push(denominator_handle);
 
         //TODO: process_participant shares / collect batches of shares and denoms
-        let (processed_shares_rx, process_shares_handle) =
-            self.process_participant_shares(denominator_rx);
+        let (batch_process_shares_rx, batch_process_shares_handle) =
+            self.batch_process_participant_shares(denominator_rx);
 
-        handles.push(process_shares_handle);
+        handles.push(batch_process_shares_handle);
 
         //TODO: process results Handle each that comes through and calc the min distance and min index
 
@@ -102,7 +102,7 @@ impl Coordinator {
         query: &Template,
     ) -> eyre::Result<()> {
         // Write each share to the corresponding participant
-        future::try_join_all(self.participants.iter_mut().map(
+        future::try_join_all(self.participants.lock().await.iter_mut().map(
             |stream| async move {
                 // Send query
                 stream.write_all(bytemuck::bytes_of(query)).await
@@ -132,107 +132,80 @@ impl Coordinator {
         (denom_receiver, denominator_handle)
     }
 
-    pub fn process_participant_shares(
+    pub fn batch_process_participant_shares(
         &mut self,
-        denominator_rx: Receiver<Vec<[u16; 31]>>,
+        mut denominator_rx: Receiver<Vec<[u16; 31]>>,
     ) -> (
         Receiver<(Vec<[u16; 31]>, Vec<Vec<[u16; 31]>>)>,
         JoinHandle<eyre::Result<()>>,
     ) {
         // Collect batches of shares
-        // let (processed_shares_tx, mut processed_shares_rx) = mpsc::channel(4);
-
-        let streams_future =
-            future::try_join_all(self.participants.iter_mut().enumerate().map(
-                |(i, stream)| async move {
-                    let mut batch = vec![[0_u16; 31]; BATCH_SIZE];
-                    let mut buffer: &mut [u8] =
-                        bytemuck::cast_slice_mut(batch.as_mut_slice());
-
-                    // We can not use read_exact here as we might get EOF before the
-                    // buffer is full But we should
-                    // still try to fill the entire buffer.
-                    // If nothing else, this guarantees that we read batches at a
-                    // [u16;31] boundary.
-                    while !buffer.is_empty() {
-                        let bytes_read = stream.read_buf(&mut buffer).await?;
-                        if bytes_read == 0 {
-                            let n_incomplete = (buffer.len()
-                                + std::mem::size_of::<[u16; 31]>() //TODO: make this a const
-                                - 1)
-                                / std::mem::size_of::<[u16; 31]>(); //TODO: make this a const
-                            batch.truncate(batch.len() - n_incomplete);
-                            break;
-                        }
-                    }
-
-                    Ok::<_, eyre::Report>(batch)
-                },
-            ));
-
-        // let batch_worker = tokio::task::spawn(async move {
-        //     loop {
-        //         // Collect futures of denominator and share batches
-        //         let streams_future = future::try_join_all(
-        //             self.participants.iter_mut().enumerate().map(
-        //                 |(i, stream)| async move {
-        //                     let mut batch = vec![[0_u16; 31]; BATCH_SIZE];
-        //                     let mut buffer: &mut [u8] =
-        //                         bytemuck::cast_slice_mut(batch.as_mut_slice());
-
-        //                     // We can not use read_exact here as we might get EOF before the
-        //                     // buffer is full But we should
-        //                     // still try to fill the entire buffer.
-        //                     // If nothing else, this guarantees that we read batches at a
-        //                     // [u16;31] boundary.
-        //                     while !buffer.is_empty() {
-        //                         let bytes_read =
-        //                             stream.read_buf(&mut buffer).await?;
-        //                         if bytes_read == 0 {
-        //                             let n_incomplete = (buffer.len()
-        //                                 + std::mem::size_of::<[u16; 31]>() //TODO: make this a const
-        //                                 - 1)
-        //                                 / std::mem::size_of::<[u16; 31]>(); //TODO: make this a const
-        //                             batch.truncate(batch.len() - n_incomplete);
-        //                             break;
-        //                         }
-        //                     }
-
-        //                     Ok::<_, eyre::Report>(batch)
-        //                 },
-        //             ),
-        //         );
-
-        //         // Wait on all parts concurrently
-        //         let (denom, shares) =
-        //             tokio::join!(denominator_rx.recv(), streams_future);
-
-        //         let mut denom = denom.unwrap_or_default();
-        //         let mut shares = shares?;
-
-        //         // Find the shortest prefix
-        //         let batch_size = shares
-        //             .iter()
-        //             .map(Vec::len)
-        //             .fold(denom.len(), core::cmp::min);
-
-        //         denom.truncate(batch_size);
-        //         shares
-        //             .iter_mut()
-        //             .for_each(|batch| batch.truncate(batch_size));
-
-        //         // Send batches
-        //         processed_shares_tx.send((denom, shares)).await?;
-        //         if batch_size == 0 {
-        //             break;
-        //         }
-        //     }
-        //     Ok(())
-        // });
-
-        // (processed_shares_rx, batch_worker)
+        let (processed_shares_tx, processed_shares_rx) = mpsc::channel(4);
+
+        let participants = self.participants.clone();
+
+        let batch_worker = tokio::task::spawn(async move {
+            loop {
+                let mut participants = participants.lock().await;
+                // Collect futures of denominator and share batches
+                let streams_future = future::try_join_all(
+                    participants.iter_mut().enumerate().map(
+                        |(i, stream)| async move {
+                            let mut batch = vec![[0_u16; 31]; BATCH_SIZE];
+                            let mut buffer: &mut [u8] =
+                                bytemuck::cast_slice_mut(batch.as_mut_slice());
+
+                            // We can not use read_exact here as we might get EOF before the
+                            // buffer is full But we should
+                            // still try to fill the entire buffer.
+                            // If nothing else, this guarantees that we read batches at a
+                            // [u16;31] boundary.
+                            while !buffer.is_empty() {
+                                let bytes_read =
+                                    stream.read_buf(&mut buffer).await?;
+                                if bytes_read == 0 {
+                                    let n_incomplete = (buffer.len()
+                                        + std::mem::size_of::<[u16; 31]>() //TODO: make this a const
+                                        - 1)
+                                        / std::mem::size_of::<[u16; 31]>(); //TODO: make this a const
+                                    batch.truncate(batch.len() - n_incomplete);
+                                    break;
+                                }
+                            }
+
+                            Ok::<_, eyre::Report>(batch)
+                        },
+                    ),
+                );
+
+                // Wait on all parts concurrently
+                let (denom, shares) =
+                    tokio::join!(denominator_rx.recv(), streams_future);
+
+                let mut denom = denom.unwrap_or_default();
+                let mut shares = shares?;
+
+                // Find the shortest prefix
+                let batch_size = shares
+                    .iter()
+                    .map(Vec::len)
+                    .fold(denom.len(), core::cmp::min);
+
+                denom.truncate(batch_size);
+                shares
+                    .iter_mut()
+                    .for_each(|batch| batch.truncate(batch_size));
+
+                // Send batches
+                processed_shares_tx.send((denom, shares)).await?;
+                if batch_size == 0 {
+                    break;
+                }
+            }
+            Ok(())
+        });
 
-        todo!()
+        (processed_shares_rx, batch_worker)
     }
 
     pub async fn process_results(