Skip to content

Commit

Permalink
parallel decryption
Browse files Browse the repository at this point in the history
  • Loading branch information
andyleiserson committed Dec 19, 2024
1 parent c626296 commit d607dec
Showing 1 changed file with 18 additions and 23 deletions.
41 changes: 18 additions & 23 deletions ipa-core/src/query/runner/hybrid.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use std::{
convert::{Infallible, Into},
marker::PhantomData,
ops::Add,
sync::Arc,
convert::{Infallible, Into}, marker::PhantomData, ops::Add, sync::Arc
};

use futures::{future::lazy, stream, StreamExt};
use futures::{future::lazy, StreamExt, TryStreamExt};
use generic_array::ArrayLength;

use super::QueryResult;
Expand All @@ -19,8 +16,7 @@ use crate::{
Serializable, U128Conversions,
},
helpers::{
query::{DpMechanism, HybridQueryParams, QueryConfig, QuerySize},
setup_cross_shard_prss, BodyStream, Gateway, LengthDelimitedStream,
query::{DpMechanism, HybridQueryParams, QueryConfig, QuerySize}, setup_cross_shard_prss, stream::TryFlattenItersExt, BodyStream, Gateway, LengthDelimitedStream
},
hpke::PrivateKeyRegistry,
protocol::{
Expand Down Expand Up @@ -118,26 +114,25 @@ where
}

let stream = LengthDelimitedStream::<EncryptedHybridReport<BA8, BA3>, _>::new(input_stream)
.map(|enc_reports_res| {
lazy(|_| stream::iter(match enc_reports_res {
Ok(enc_reports) => {
println!("decrypting on {}", tokio::task::id());
enc_reports.into_iter().map(|enc_report| {
let dec_report = enc_report
.decrypt(key_registry.as_ref())
.map_err(Into::<Error>::into);
let unique_tag = UniqueTag::from_unique_bytes(&enc_report);
dec_report.map(|dec_report1| (dec_report1, unique_tag))
})
.collect::<Vec<_>>()
.map_err(Into::into)
.try_flatten_iters()
.map(|enc_report_res| {
lazy(|_| match enc_report_res {
Ok(enc_report) => {
let dec_report = enc_report
.decrypt(key_registry.as_ref())
.map_err(Into::<Error>::into);
let unique_tag = UniqueTag::from_unique_bytes(&enc_report);
dec_report.map(|dec_report1| (dec_report1, unique_tag))
}
Err(err) => vec![Err(err.into())],
}))
});
Err(err) => Err(err.into()),
})
})
.take(sz);

let (decrypted_reports, resharded_tags) = reshard_aad(
ctx.narrow(&HybridStep::ReshardByTag),
seq_join(ctx.active_work(), stream).flatten().take(sz),
seq_join(ctx.active_work(), stream),
|ctx, _, tag| tag.shard_picker(ctx.shard_count()),
)
.await?;
Expand Down

0 comments on commit d607dec

Please sign in to comment.