diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index db603501c..da56e67e3 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -203,6 +203,10 @@ impl RequestHandler for Inner { let query_id = ext_query_id(&req)?; HelperResponse::from(qp.complete(query_id).await?) } + RouteId::KillQuery => { + let query_id = ext_query_id(&req)?; + HelperResponse::from(qp.kill(query_id)?) + } }) } } diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index fc73caf5d..07018fb14 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -203,10 +203,9 @@ impl GatewaySenders { match self.inner.entry(channel_id.clone()) { Entry::Occupied(entry) => Arc::clone(entry.get()), Entry::Vacant(entry) => { - let sender = Self::new_sender( - &SendChannelConfig::new::(config, total_records), - channel_id.clone(), - ); + let config = SendChannelConfig::new::(config, total_records); + tracing::trace!("send configuration for {channel_id:?}: {config:?}"); + let sender = Self::new_sender(&config, channel_id.clone()); entry.insert(Arc::clone(&sender)); tokio::spawn({ diff --git a/ipa-core/src/helpers/transport/handler.rs b/ipa-core/src/helpers/transport/handler.rs index 42981d097..525edb67e 100644 --- a/ipa-core/src/helpers/transport/handler.rs +++ b/ipa-core/src/helpers/transport/handler.rs @@ -12,7 +12,7 @@ use crate::{ }, query::{ NewQueryError, PrepareQueryError, ProtocolResult, QueryCompletionError, QueryInputError, - QueryStatus, QueryStatusError, + QueryKillStatus, QueryKilled, QueryStatus, QueryStatusError, }, sync::{Arc, Mutex, Weak}, }; @@ -135,6 +135,13 @@ impl From for HelperResponse { } } +impl From for HelperResponse { + fn from(value: QueryKilled) -> Self { + let v = serde_json::to_vec(&json!({"query_id": value.0, "status": "killed"})).unwrap(); + Self { body: v } + } +} + impl> From for HelperResponse { fn from(value: R) -> Self { let v = value.as_ref().to_bytes(); @@ -156,6 +163,8 @@ pub enum Error { #[error(transparent)] QueryStatus(#[from] QueryStatusError), #[error(transparent)] + QueryKill(#[from] QueryKillStatus), + #[error(transparent)] DeserializationFailure(#[from] serde_json::Error), #[error("MalformedRequest: {0}")] BadRequest(BoxError), diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 3c1a9e926..cd7324e89 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -119,7 +119,8 @@ impl InMemoryTransport { | RouteId::PrepareQuery | RouteId::QueryInput | RouteId::QueryStatus - | RouteId::CompleteQuery => { + | RouteId::CompleteQuery + | RouteId::KillQuery => { handler .as_ref() .expect("Handler is set") diff --git a/ipa-core/src/helpers/transport/routing.rs b/ipa-core/src/helpers/transport/routing.rs index 4d8f44796..3d9c2bb5f 100644 --- a/ipa-core/src/helpers/transport/routing.rs +++ b/ipa-core/src/helpers/transport/routing.rs @@ -16,6 +16,7 @@ pub enum RouteId { QueryInput, QueryStatus, CompleteQuery, + KillQuery, } /// The header/metadata of the incoming request. diff --git a/ipa-core/src/helpers/transport/stream/collection.rs b/ipa-core/src/helpers/transport/stream/collection.rs index 09e4f5e63..f19fd7ce5 100644 --- a/ipa-core/src/helpers/transport/stream/collection.rs +++ b/ipa-core/src/helpers/transport/stream/collection.rs @@ -114,6 +114,26 @@ impl StreamCollection { let mut streams = self.inner.lock().unwrap(); streams.clear(); } + + /// Returns the number of streams inside this collection. + /// + /// ## Panics + /// if mutex is poisoned. + #[cfg(test)] + #[must_use] + pub fn len(&self) -> usize { + self.inner.lock().unwrap().len() + } + + /// Returns `true` if this collection is empty. + /// + /// ## Panics + /// if mutex is poisoned. + #[must_use] + #[cfg(test)] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } } /// Describes the lifecycle of records stream inside [`StreamCollection`] diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index 927ae4a4d..1965c15ce 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -533,4 +533,82 @@ pub mod query { pub const AXUM_PATH: &str = "/:query_id/complete"; } + + pub mod kill { + use serde::{Deserialize, Serialize}; + + use crate::{ + helpers::{routing::RouteId, HelperResponse, NoStep, RouteParams}, + protocol::QueryId, + }; + + pub struct Request { + pub query_id: QueryId, + } + + impl RouteParams for Request { + type Params = String; + + fn resource_identifier(&self) -> RouteId { + RouteId::KillQuery + } + + fn query_id(&self) -> QueryId { + self.query_id + } + + fn gate(&self) -> NoStep { + NoStep + } + + fn extra(&self) -> Self::Params { + String::new() + } + } + + impl Request { + /// Currently, it is only possible to kill + /// a query by issuing an HTTP request manually. + /// Maybe report collector can support this API, + /// but for now, only tests exercise this path + /// hence methods here are hidden behind feature + /// flags + #[cfg(all(test, unit_test))] + pub fn new(query_id: QueryId) -> Self { + Self { query_id } + } + + #[cfg(all(test, unit_test))] + pub fn try_into_http_request( + self, + scheme: axum::http::uri::Scheme, + authority: axum::http::uri::Authority, + ) -> crate::net::http_serde::OutgoingRequest { + let uri = axum::http::uri::Uri::builder() + .scheme(scheme) + .authority(authority) + .path_and_query(format!( + "{}/{}/kill", + crate::net::http_serde::query::BASE_AXUM_PATH, + self.query_id.as_ref() + )) + .build()?; + Ok(hyper::Request::post(uri).body(axum::body::Body::empty())?) + } + } + + #[derive(Clone, Debug, Serialize, Deserialize)] + pub struct ResponseBody { + pub query_id: QueryId, + pub status: String, + } + + impl From for ResponseBody { + fn from(value: HelperResponse) -> Self { + serde_json::from_slice(value.into_body().as_slice()).unwrap() + } + } + + pub const AXUM_PATH: &str = "/:query_id/kill"; + } } diff --git a/ipa-core/src/net/server/handlers/query/kill.rs b/ipa-core/src/net/server/handlers/query/kill.rs new file mode 100644 index 000000000..aae68b993 --- /dev/null +++ b/ipa-core/src/net/server/handlers/query/kill.rs @@ -0,0 +1,136 @@ +use axum::{extract::Path, routing::post, Extension, Json, Router}; +use hyper::StatusCode; + +use crate::{ + helpers::{ApiError, BodyStream, Transport}, + net::{ + http_serde::query::{kill, kill::Request}, + server::Error, + Error::QueryIdNotFound, + HttpTransport, + }, + protocol::QueryId, + query::QueryKillStatus, + sync::Arc, +}; + +async fn handler( + transport: Extension>, + Path(query_id): Path, +) -> Result, Error> { + let req = Request { query_id }; + let transport = Transport::clone_ref(&*transport); + match transport.dispatch(req, BodyStream::empty()).await { + Ok(state) => Ok(Json(kill::ResponseBody::from(state))), + Err(ApiError::QueryKill(QueryKillStatus::NoSuchQuery(query_id))) => Err( + Error::application(StatusCode::NOT_FOUND, QueryIdNotFound(query_id)), + ), + Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), + } +} + +pub fn router(transport: Arc) -> Router { + Router::new() + .route(kill::AXUM_PATH, post(handler)) + .layer(Extension(transport)) +} + +#[cfg(all(test, unit_test))] +mod tests { + use axum::{ + body::Body, + http::uri::{Authority, Scheme}, + }; + use hyper::StatusCode; + + use crate::{ + helpers::{ + make_owned_handler, + routing::{Addr, RouteId}, + ApiError, BodyStream, HelperIdentity, HelperResponse, + }, + net::{ + http_serde, + server::handlers::query::test_helpers::{ + assert_fails_with, assert_fails_with_handler, assert_success_with, + }, + }, + protocol::QueryId, + query::{QueryKillStatus, QueryKilled}, + }; + + #[tokio::test] + async fn calls_kill() { + let expected_query_id = QueryId; + + let handler = make_owned_handler( + move |addr: Addr, _data: BodyStream| async move { + let RouteId::KillQuery = addr.route else { + panic!("unexpected call: {addr:?}"); + }; + assert_eq!(addr.query_id, Some(expected_query_id)); + Ok(HelperResponse::from(QueryKilled(expected_query_id))) + }, + ); + + let req = http_serde::query::kill::Request::new(QueryId); + let req = req + .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost")) + .unwrap(); + assert_success_with(req, handler).await; + } + + #[tokio::test] + async fn no_such_query() { + let handler = make_owned_handler( + move |_addr: Addr, _data: BodyStream| async move { + Err(QueryKillStatus::NoSuchQuery(QueryId).into()) + }, + ); + + let req = http_serde::query::kill::Request::new(QueryId) + .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost")) + .unwrap(); + assert_fails_with_handler(req, handler, StatusCode::NOT_FOUND).await; + } + + #[tokio::test] + async fn unknown_error() { + let handler = make_owned_handler( + move |_addr: Addr, _data: BodyStream| async move { + Err(ApiError::DeserializationFailure( + serde_json::from_str::<()>("not-a-json").unwrap_err(), + )) + }, + ); + + let req = http_serde::query::kill::Request::new(QueryId) + .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost")) + .unwrap(); + assert_fails_with_handler(req, handler, StatusCode::INTERNAL_SERVER_ERROR).await; + } + + struct OverrideReq { + query_id: String, + } + + impl From for hyper::Request { + fn from(val: OverrideReq) -> Self { + let uri = format!( + "http://localhost{}/{}/kill", + http_serde::query::BASE_AXUM_PATH, + val.query_id + ); + hyper::Request::post(uri).body(Body::empty()).unwrap() + } + } + + #[tokio::test] + async fn malformed_query_id() { + let req = OverrideReq { + query_id: "not-a-query-id".into(), + }; + + assert_fails_with(req.into(), StatusCode::BAD_REQUEST).await; + } +} diff --git a/ipa-core/src/net/server/handlers/query/mod.rs b/ipa-core/src/net/server/handlers/query/mod.rs index 49f18e0a8..616308eea 100644 --- a/ipa-core/src/net/server/handlers/query/mod.rs +++ b/ipa-core/src/net/server/handlers/query/mod.rs @@ -1,5 +1,6 @@ mod create; mod input; +mod kill; mod prepare; mod results; mod status; @@ -31,6 +32,7 @@ pub fn query_router(transport: Arc) -> Router { .merge(create::router(Arc::clone(&transport))) .merge(input::router(Arc::clone(&transport))) .merge(status::router(Arc::clone(&transport))) + .merge(kill::router(Arc::clone(&transport))) .merge(results::router(transport)) } @@ -139,6 +141,19 @@ pub mod test_helpers { assert_eq!(resp.status(), expected_status); } + pub async fn assert_fails_with_handler( + req: hyper::Request, + handler: Arc>, + expected_status: StatusCode, + ) { + let test_server = TestServer::builder() + .with_request_handler(handler) + .build() + .await; + let resp = test_server.server.handle_req(req).await; + assert_eq!(resp.status(), expected_status); + } + pub async fn assert_success_with( req: hyper::Request, handler: Arc>, diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 81d4bdcce..508bfc8d5 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -135,7 +135,7 @@ impl HttpTransport { .expect("A Handler should be set by now") .handle(Addr::from_route(None, req), body); - if let RouteId::CompleteQuery = route_id { + if let RouteId::CompleteQuery | RouteId::KillQuery = route_id { ClearOnDrop { transport: Arc::clone(&self), inner: r, @@ -210,7 +210,8 @@ impl Transport for Arc { evt @ (RouteId::QueryInput | RouteId::ReceiveQuery | RouteId::QueryStatus - | RouteId::CompleteQuery) => { + | RouteId::CompleteQuery + | RouteId::KillQuery) => { unimplemented!( "attempting to send client-specific request {evt:?} to another helper" ) @@ -283,7 +284,10 @@ mod tests { use crate::{ config::{NetworkConfig, ServerConfig}, ff::{FieldType, Fp31, Serializable}, - helpers::query::{QueryInput, QueryType::TestMultiply}, + helpers::{ + make_owned_handler, + query::{QueryInput, QueryType::TestMultiply}, + }, net::{ client::ClientIdentity, test::{get_test_identity, TestConfig, TestConfigBuilder, TestServer}, @@ -295,6 +299,32 @@ mod tests { static STEP: Lazy = Lazy::new(|| Gate::from("http-transport")); + #[tokio::test] + async fn clean_on_kill() { + let noop_handler = make_owned_handler(|_, _| async move { + { + Ok(HelperResponse::ok()) + } + }); + let TestServer { transport, .. } = TestServer::builder() + .with_request_handler(Arc::clone(&noop_handler)) + .build() + .await; + + transport.record_streams.add_stream( + (QueryId, HelperIdentity::ONE, Gate::default()), + BodyStream::empty(), + ); + assert_eq!(1, transport.record_streams.len()); + + Transport::clone_ref(&transport) + .dispatch((RouteId::KillQuery, QueryId), BodyStream::empty()) + .await + .unwrap(); + + assert!(transport.record_streams.is_empty()); + } + #[tokio::test] async fn receive_stream() { let (tx, rx) = channel::>>(1); diff --git a/ipa-core/src/protocol/context/batcher.rs b/ipa-core/src/protocol/context/batcher.rs index f7021d988..cdbfddbce 100644 --- a/ipa-core/src/protocol/context/batcher.rs +++ b/ipa-core/src/protocol/context/batcher.rs @@ -97,6 +97,10 @@ impl<'a, B> Batcher<'a, B> { self.total_records = self.total_records.overwrite(total_records.into()); } + pub fn records_per_batch(&self) -> usize { + self.records_per_batch + } + fn batch_offset(&self, record_id: RecordId) -> usize { let batch_index = usize::from(record_id) / self.records_per_batch; batch_index @@ -110,7 +114,7 @@ impl<'a, B> Batcher<'a, B> { while self.batches.len() <= batch_offset { let (validation_result, _) = watch::channel::(false); let state = BatchState { - batch: (self.batch_constructor)(self.first_batch + batch_offset), + batch: (self.batch_constructor)(self.first_batch + self.batches.len()), validation_result, pending_count: 0, pending_records: bitvec![0; self.records_per_batch], @@ -292,6 +296,23 @@ mod tests { ); } + #[test] + fn makes_batches_out_of_order() { + // Regression test for a bug where, when adding batches i..j to fill in a gap in + // the batch deque prior to out-of-order requested batch j, the batcher passed + // batch index `j` to the constructor for all of them, as opposed to the correct + // sequence of indices i..=j. + + let batcher = Batcher::new(1, 2, Box::new(std::convert::identity)); + let mut batcher = batcher.lock().unwrap(); + + batcher.get_batch(RecordId::from(1)); + batcher.get_batch(RecordId::from(0)); + + assert_eq!(batcher.get_batch(RecordId::from(0)).batch, 0); + assert_eq!(batcher.get_batch(RecordId::from(1)).batch, 1); + } + #[tokio::test] async fn validates_batches() { let batcher = Batcher::new(2, 4, Box::new(|_| Vec::new())); diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index 73cfda40c..70dd3d2af 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -29,6 +29,7 @@ use crate::{ pub struct DZKPUpgraded<'a> { validator_inner: Weak>, base_ctx: MaliciousContext<'a>, + active_work: NonZeroUsize, } impl<'a> DZKPUpgraded<'a> { @@ -36,9 +37,30 @@ impl<'a> DZKPUpgraded<'a> { validator_inner: &Arc>, base_ctx: MaliciousContext<'a>, ) -> Self { + let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch(); + let active_work = if records_per_batch == 1 { + // If records_per_batch is 1, let active_work be anything. This only happens + // in tests; there shouldn't be a risk of deadlocks with one record per + // batch; and UnorderedReceiver capacity (which is set from active_work) + // must be at least two. + base_ctx.active_work() + } else { + // Adjust active_work to match records_per_batch. If it is less, we will + // certainly stall, since every record in the batch remains incomplete until + // the batch is validated. It is possible that it can be larger, but making + // it the same seems safer for now. + let active_work = NonZeroUsize::new(records_per_batch).unwrap(); + tracing::debug!( + "Changed active_work from {} to {} to match batch size", + base_ctx.active_work().get(), + active_work, + ); + active_work + }; Self { validator_inner: Arc::downgrade(validator_inner), base_ctx, + active_work, } } @@ -130,7 +152,7 @@ impl<'a> super::Context for DZKPUpgraded<'a> { impl<'a> SeqJoin for DZKPUpgraded<'a> { fn active_work(&self) -> NonZeroUsize { - self.base_ctx.active_work() + self.active_work } } diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 517d4db46..f5586336f 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -825,35 +825,158 @@ mod tests { }; use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec}; - use futures::{StreamExt, TryStreamExt}; + use futures::{stream, StreamExt, TryStreamExt}; use futures_util::stream::iter; - use proptest::{prop_compose, proptest, sample::select}; - use rand::{thread_rng, Rng}; + use proptest::{ + prelude::{Just, Strategy}, + prop_compose, prop_oneof, proptest, + test_runner::Config as ProptestConfig, + }; + use rand::{distributions::Standard, prelude::Distribution}; use crate::{ error::Error, - ff::{boolean::Boolean, Fp61BitPrime}, + ff::{ + boolean::Boolean, + boolean_array::{BooleanArray, BA16, BA20, BA256, BA3, BA32, BA64, BA8}, + Fp61BitPrime, + }, protocol::{ - basics::SecureMul, + basics::{select, BooleanArrayMul, SecureMul}, context::{ dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, dzkp_validator::{ Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE, }, - Context, UpgradableContext, TEST_DZKP_STEPS, + Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, + UpgradableContext, TEST_DZKP_STEPS, }, Gate, RecordId, }, + rand::{thread_rng, Rng}, secret_sharing::{ replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue, Vectorizable, }, - seq_join::{seq_join, SeqJoin}, + seq_join::seq_join, + sharding::NotSharded, test_fixture::{join3v, Reconstruct, Runner, TestWorld}, }; + async fn test_select_semi_honest() + where + V: BooleanArray, + for<'a> Replicated: BooleanArrayMul>, + Standard: Distribution, + { + let world = TestWorld::default(); + let context = world.contexts(); + let mut rng = thread_rng(); + + let bit = rng.gen::(); + let a = rng.gen::(); + let b = rng.gen::(); + + let bit_shares = bit.share_with(&mut rng); + let a_shares = a.share_with(&mut rng); + let b_shares = b.share_with(&mut rng); + + let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map( + |(ctx, (bit_share, (a_share, b_share)))| async move { + let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); + let sh_ctx = v.context(); + + let result = select( + sh_ctx.set_total_records(1), + RecordId::from(0), + &bit_share, + &a_share, + &b_share, + ) + .await?; + + v.validate().await?; + + Ok::<_, Error>(result) + }, + ); + + let [ab0, ab1, ab2] = join3v(futures).await; + + let ab = [ab0, ab1, ab2].reconstruct(); + + assert_eq!(ab, if bit.into() { a } else { b }); + } + #[tokio::test] - async fn dzkp_malicious() { + async fn select_semi_honest() { + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + } + + async fn test_select_malicious() + where + V: BooleanArray, + for<'a> Replicated: BooleanArrayMul>, + Standard: Distribution, + { + let world = TestWorld::default(); + let context = world.malicious_contexts(); + let mut rng = thread_rng(); + + let bit = rng.gen::(); + let a = rng.gen::(); + let b = rng.gen::(); + + let bit_shares = bit.share_with(&mut rng); + let a_shares = a.share_with(&mut rng); + let b_shares = b.share_with(&mut rng); + + let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map( + |(ctx, (bit_share, (a_share, b_share)))| async move { + let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); + let m_ctx = v.context(); + + let result = select( + m_ctx.set_total_records(1), + RecordId::from(0), + &bit_share, + &a_share, + &b_share, + ) + .await?; + + v.validate().await?; + + Ok::<_, Error>(result) + }, + ); + + let [ab0, ab1, ab2] = join3v(futures).await; + + let ab = [ab0, ab1, ab2].reconstruct(); + + assert_eq!(ab, if bit.into() { a } else { b }); + } + + #[tokio::test] + async fn select_malicious() { + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + } + + #[tokio::test] + async fn two_multiplies_malicious() { const COUNT: usize = 32; let mut rng = thread_rng(); @@ -913,9 +1036,54 @@ mod tests { } } + /// Similar to `test_select_malicious`, but operating on vectors + async fn multi_select_malicious(count: usize, max_multiplications_per_gate: usize) + where + V: BooleanArray, + for<'a> Replicated: BooleanArrayMul>, + Standard: Distribution, + { + let mut rng = thread_rng(); + + let bit: Vec = repeat_with(|| rng.gen::()).take(count).collect(); + let a: Vec = repeat_with(|| rng.gen()).take(count).collect(); + let b: Vec = repeat_with(|| rng.gen()).take(count).collect(); + + let [ab0, ab1, ab2]: [Vec>; 3] = TestWorld::default() + .malicious( + zip(bit.clone(), zip(a.clone(), b.clone())), + |ctx, inputs| async move { + let v = ctx + .set_total_records(count) + .dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); + let m_ctx = v.context(); + + v.validated_seq_join(stream::iter(inputs).enumerate().map( + |(i, (bit_share, (a_share, b_share)))| { + let m_ctx = m_ctx.clone(); + async move { + select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share) + .await + } + }, + )) + .try_collect() + .await + }, + ) + .await + .map(Result::unwrap); + + let ab: Vec = [ab0, ab1, ab2].reconstruct(); + + for i in 0..count { + assert_eq!(ab[i], if bit[i].into() { a[i] } else { b[i] }); + } + } + /// test for testing `validated_seq_join` - /// similar to `complex_circuit` in `validator.rs` - async fn complex_circuit_dzkp( + /// similar to `complex_circuit` in `validator.rs` (which has a more detailed comment) + async fn chained_multiplies_dzkp( count: usize, max_multiplications_per_gate: usize, ) -> Result<(), Error> { @@ -945,7 +1113,7 @@ mod tests { .map(|(ctx, input_shares)| async move { let v = ctx .set_total_records(count - 1) - .dzkp_validator(TEST_DZKP_STEPS, ctx.active_work().get()); + .dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); let m_ctx = v.context(); let m_results = v @@ -1021,19 +1189,63 @@ mod tests { Ok(()) } + fn record_count_strategy() -> impl Strategy { + // The chained_multiplies test has count - 1 records, so 1 is not a valid input size. + // It is for multi_select though. + prop_oneof![2usize..=512, (1u32..=9).prop_map(|i| 1usize << i)] + } + + fn max_multiplications_per_gate_strategy(record_count: usize) -> impl Strategy { + let max_max_mults = record_count.min(128); + prop_oneof![ + 1usize..=max_max_mults, + (0u32..=max_max_mults.ilog2()).prop_map(|i| 1usize << i) + ] + } + prop_compose! { - fn arb_count_and_chunk()((log_count, log_multiplication_amount) in select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) { - (1usize< (usize, usize) + { + (record_count, max_mults) } } proptest! { + #![proptest_config(ProptestConfig::with_cases(50))] #[test] - fn test_complex_circuit_dzkp((count, multiplication_amount) in arb_count_and_chunk()){ - let future = async { - let _ = complex_circuit_dzkp(count, multiplication_amount).await; - }; - tokio::runtime::Runtime::new().unwrap().block_on(future); + fn batching_proptest((record_count, max_multiplications_per_gate) in batching()) { + println!("record_count {record_count} batch {max_multiplications_per_gate}"); + if record_count / max_multiplications_per_gate >= 192 { + // TODO: #1269, or even if we don't fix that, don't hardcode the limit. + println!("skipping config because batch count exceeds limit of 192"); + } + // This condition is correct only for active_work = 16 and record size of 1 byte. + else if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 { + // TODO: #1300, read_size | batch_size. + // Note: for active work < 2048, read size matches active work. + + // Besides read_size | batch_size, there is also a constraint + // something like active_work > read_size + batch_size - 1. + println!("skipping config due to read_size vs. batch_size constraints"); + } else { + tokio::runtime::Runtime::new().unwrap().block_on(async { + chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap(); + /* + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + */ + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + /* + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + */ + }); + } } } diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index a93a8edfb..18b9b8e29 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -178,6 +178,16 @@ pub struct Upgraded<'a, F: ExtendableField> { impl<'a, F: ExtendableField> Upgraded<'a, F> { pub(super) fn new(batch: &Arc>, ctx: Context<'a>) -> Self { + // The DZKP malicious context adjusts active_work to match records_per_batch. + // The MAC validator currently configures the batcher with records_per_batch = + // active_work. If the latter behavior changes, this code may need to be + // updated. + let records_per_batch = batch.lock().unwrap().records_per_batch(); + let active_work = ctx.active_work().get(); + assert_eq!( + records_per_batch, active_work, + "Expect MAC validation batch size ({records_per_batch}) to match active work ({active_work})", + ); Self { batch: Arc::downgrade(batch), base_ctx: ctx, diff --git a/ipa-core/src/protocol/context/validator.rs b/ipa-core/src/protocol/context/validator.rs index 33303fb9b..e57ae3c6a 100644 --- a/ipa-core/src/protocol/context/validator.rs +++ b/ipa-core/src/protocol/context/validator.rs @@ -217,7 +217,7 @@ impl<'a, F: ExtendableField> BatchValidator<'a, F> { // TODO: Right now we set the batch work to be equal to active_work, // but it does not need to be. We can make this configurable if needed. - let records_per_batch = ctx.active_work().get().min(total_records.get()); + let records_per_batch = ctx.active_work().get(); Self { protocol_ctx: ctx.narrow(&Step::MaliciousProtocol), diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 267631da0..9a1f8f278 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -510,8 +510,8 @@ where protocol: &Step::Attribute, validate: &Step::AttributeValidate, }, - // The size of a single batch should not exceed the active work limit, - // otherwise it will stall + // TODO: this should not be necessary, but probably can't be removed + // until we align read_size with the batch size. std::cmp::min(sh_ctx.active_work().get(), chunk_size), ); dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap()); diff --git a/ipa-core/src/query/mod.rs b/ipa-core/src/query/mod.rs index aaa437b7a..6e6650862 100644 --- a/ipa-core/src/query/mod.rs +++ b/ipa-core/src/query/mod.rs @@ -8,7 +8,7 @@ use completion::Handle as CompletionHandle; pub use executor::Result as ProtocolResult; pub use processor::{ NewQueryError, PrepareQueryError, Processor as QueryProcessor, QueryCompletionError, - QueryInputError, QueryStatusError, + QueryInputError, QueryKillStatus, QueryKilled, QueryStatusError, }; pub use runner::OprfIpaQuery; pub use state::QueryStatus; diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index c399bb019..a8694012e 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -5,6 +5,7 @@ use std::{ }; use futures::{future::try_join, stream}; +use serde::Serialize; use crate::{ error::Error as ProtocolError, @@ -328,6 +329,36 @@ impl Processor { Ok(handle.await?) } + + /// Terminates a query with the given id. If query is running, then it + /// is unregistered and its task is terminated. + /// + /// ## Errors + /// if query is not registered on this helper. + /// + /// ## Panics + /// If failed to obtain exclusive access to the query collection. + pub fn kill(&self, query_id: QueryId) -> Result { + let mut queries = self.queries.inner.lock().unwrap(); + let Some(state) = queries.remove(&query_id) else { + return Err(QueryKillStatus::NoSuchQuery(query_id)); + }; + + if let QueryState::Running(handle) = state { + handle.join_handle.abort(); + } + + Ok(QueryKilled(query_id)) + } +} + +#[derive(Clone, Serialize)] +pub struct QueryKilled(pub QueryId); + +#[derive(thiserror::Error, Debug)] +pub enum QueryKillStatus { + #[error("failed to kill a query: {0} does not exist.")] + NoSuchQuery(QueryId), } #[cfg(all(test, unit_test))] @@ -549,6 +580,102 @@ mod tests { } } + mod kill { + use std::sync::Arc; + + use crate::{ + ff::FieldType, + helpers::{ + query::{ + QueryConfig, + QueryType::{TestAddInPrimeField, TestMultiply}, + }, + HandlerBox, HelperIdentity, InMemoryMpcNetwork, Transport, + }, + protocol::QueryId, + query::{ + processor::{tests::respond_ok, Processor}, + state::{QueryState, RunningQuery}, + QueryKillStatus, + }, + test_executor::run, + }; + + #[test] + fn non_existent_query() { + let processor = Processor::default(); + assert!(matches!( + processor.kill(QueryId), + Err(QueryKillStatus::NoSuchQuery(QueryId)) + )); + } + + #[test] + fn existing_query() { + run(|| async move { + let h2 = respond_ok(); + let h3 = respond_ok(); + let network = InMemoryMpcNetwork::new([ + None, + Some(HandlerBox::owning_ref(&h2)), + Some(HandlerBox::owning_ref(&h3)), + ]); + let identities = HelperIdentity::make_three(); + let processor = Processor::default(); + let transport = network.transport(identities[0]); + processor + .new_query( + Transport::clone_ref(&transport), + QueryConfig::new(TestMultiply, FieldType::Fp31, 1).unwrap(), + ) + .await + .unwrap(); + + processor.kill(QueryId).unwrap(); + + // start query again - it should work because the query was killed + processor + .new_query( + transport, + QueryConfig::new(TestAddInPrimeField, FieldType::Fp32BitPrime, 1).unwrap(), + ) + .await + .unwrap(); + }); + } + + #[test] + fn aborts_protocol_task() { + run(|| async move { + let processor = Processor::default(); + let (_tx, rx) = tokio::sync::oneshot::channel(); + let counter = Arc::new(1); + let task = tokio::spawn({ + let counter = Arc::clone(&counter); + async move { + loop { + tokio::task::yield_now().await; + let _ = *counter.as_ref(); + } + } + }); + processor.queries.inner.lock().unwrap().insert( + QueryId, + QueryState::Running(RunningQuery { + result: rx, + join_handle: task, + }), + ); + + assert_eq!(2, Arc::strong_count(&counter)); + processor.kill(QueryId).unwrap(); + while Arc::strong_count(&counter) > 1 { + tokio::task::yield_now().await; + } + }); + } + } + mod e2e { use std::time::Duration; diff --git a/ipa-core/src/secret_sharing/vector/impls.rs b/ipa-core/src/secret_sharing/vector/impls.rs index 974bfd259..b5b043b4d 100644 --- a/ipa-core/src/secret_sharing/vector/impls.rs +++ b/ipa-core/src/secret_sharing/vector/impls.rs @@ -55,107 +55,6 @@ macro_rules! boolean_vector { AdditiveShare::new(*value.left_arr(), *value.right_arr()) } } - - #[cfg(all(test, unit_test))] - mod tests { - use std::iter::zip; - - use super::*; - use crate::{ - error::Error, - protocol::{ - basics::select, - context::{ - dzkp_validator::DZKPValidator, Context, UpgradableContext, - TEST_DZKP_STEPS, - }, - RecordId, - }, - rand::{thread_rng, Rng}, - secret_sharing::into_shares::IntoShares, - test_fixture::{join3v, Reconstruct, TestWorld}, - }; - - #[tokio::test] - async fn simplest_circuit_malicious() { - let world = TestWorld::default(); - let context = world.malicious_contexts(); - let mut rng = thread_rng(); - - let bit = rng.gen::(); - let a = rng.gen::<$vec>(); - let b = rng.gen::<$vec>(); - - let bit_shares = bit.share_with(&mut rng); - let a_shares = a.share_with(&mut rng); - let b_shares = b.share_with(&mut rng); - - let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))) - .map(|(ctx, (bit_share, (a_share, b_share)))| async move { - let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); - let m_ctx = v.context(); - - let result = select( - m_ctx.set_total_records(1), - RecordId::from(0), - &bit_share, - &a_share, - &b_share, - ) - .await?; - - v.validate().await?; - - Ok::<_, Error>(result) - }); - - let [ab0, ab1, ab2] = join3v(futures).await; - - let ab = [ab0, ab1, ab2].reconstruct(); - - assert_eq!(ab, if bit.into() { a } else { b }); - } - - #[tokio::test] - async fn simplest_circuit_semi_honest() { - let world = TestWorld::default(); - let context = world.contexts(); - let mut rng = thread_rng(); - - let bit = rng.gen::(); - let a = rng.gen::<$vec>(); - let b = rng.gen::<$vec>(); - - let bit_shares = bit.share_with(&mut rng); - let a_shares = a.share_with(&mut rng); - let b_shares = b.share_with(&mut rng); - - let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))) - .map(|(ctx, (bit_share, (a_share, b_share)))| async move { - let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); - let sh_ctx = v.context(); - - let result = select( - sh_ctx.set_total_records(1), - RecordId::from(0), - &bit_share, - &a_share, - &b_share, - ) - .await?; - - v.validate().await?; - - Ok::<_, Error>(result) - }); - - let [ab0, ab1, ab2] = join3v(futures).await; - - let ab = [ab0, ab1, ab2].reconstruct(); - - assert_eq!(ab, if bit.into() { a } else { b }); - } - } } }; } diff --git a/ipa-core/src/test_fixture/hybrid_event_gen.rs b/ipa-core/src/test_fixture/hybrid_event_gen.rs index bc1a463dd..792b1cc37 100644 --- a/ipa-core/src/test_fixture/hybrid_event_gen.rs +++ b/ipa-core/src/test_fixture/hybrid_event_gen.rs @@ -186,8 +186,8 @@ mod tests { const EXPECTED_HISTOGRAM_WITH_TOLERANCE: [(i32, f64); 12] = [ (0, 0.0), (647_634, 0.01), - (137_626, 0.01), - (20_652, 0.02), + (137_626, 0.02), + (20_652, 0.03), (3_085, 0.05), (463, 0.12), (70, 0.5),