From 36b9c7dbe00907d40a27a2a219c923a3b038a1dd Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Mon, 16 Oct 2023 13:45:04 +1100 Subject: [PATCH 1/2] Restructure the bucketing code Move the function to a file on its own. Skip iterations when the number of buckets is less than half the number of bits and `robust` is false. Refactor tests. Add an overflow test and guard against having more than 128 buckets (which we can't support). Fix gen_range() invocations. Closes #803. --- Cargo.toml | 2 +- src/ff/prime_field.rs | 12 + src/protocol/prf_sharding/bucket.rs | 234 ++++++++++++++++++ .../prf_sharding/feature_label_dot_product.rs | 10 +- src/protocol/prf_sharding/mod.rs | 181 +------------- 5 files changed, 261 insertions(+), 178 deletions(-) create mode 100644 src/protocol/prf_sharding/bucket.rs diff --git a/Cargo.toml b/Cargo.toml index e264d503f..2ba6d8b67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -106,7 +106,7 @@ cfg_aliases = "0.1.1" command-fds = "0.2.2" hex = "0.4" permutation = "0.4.1" -proptest = "1.0.0" +proptest = "1.3" tempfile = "3" tokio-rustls = { version = "0.24.0", features = ["dangerous_configuration"] } diff --git a/src/ff/prime_field.rs b/src/ff/prime_field.rs index 49962a83d..56391a8a0 100644 --- a/src/ff/prime_field.rs +++ b/src/ff/prime_field.rs @@ -132,6 +132,18 @@ macro_rules! 
field_impl { } } + impl std::iter::Sum for $field { + fn sum>(iter: I) -> Self { + iter.fold(Self::ZERO, |a, b| a + b) + } + } + + impl<'a> std::iter::Sum<&'a $field> for $field { + fn sum>(iter: I) -> Self { + iter.fold(Self::ZERO, |a, b| a + *b) + } + } + impl TryFrom for $field { type Error = crate::error::Error; diff --git a/src/protocol/prf_sharding/bucket.rs b/src/protocol/prf_sharding/bucket.rs new file mode 100644 index 000000000..6349f91a6 --- /dev/null +++ b/src/protocol/prf_sharding/bucket.rs @@ -0,0 +1,234 @@ +use embed_doc_image::embed_doc_image; + +use crate::{ + error::Error, + ff::{GaloisField, PrimeField, Serializable}, + protocol::{ + basics::SecureMul, context::UpgradedContext, prf_sharding::BinaryTreeDepthStep, + step::BitOpStep, RecordId, + }, + secret_sharing::{ + replicated::malicious::ExtendableField, BitDecomposed, Linear as LinearSecretSharing, + }, +}; + +#[embed_doc_image("tree-aggregation", "images/tree_aggregation.png")] +/// This function moves a single value to a correct bucket using tree aggregation approach +/// +/// Here is how it works +/// The combined value, [`value`] forms the root of a binary tree as follows: +/// ![Tree propagation][tree-aggregation] +/// +/// This value is propagated through the tree, with each subsequent iteration doubling the number of multiplications. +/// In the first round, r=BK-1, multiply the most significant bit ,[`bd_key`]_r by the value to get [`bd_key`]_r.[`value`]. From that, +/// produce [`row_contribution`]_r,0 =[`value`]-[`bd_key`]_r.[`value`] and [`row_contribution`]_r,1=[`bd_key`]_r.[`value`]. +/// This takes the most significant bit of `bd_key` and places value in one of the two child nodes of the binary tree. 
+/// At each successive round, the next most significant bit is propagated from the leaf nodes of the tree into further leaf nodes: +/// [`row_contribution`]_r+1,q,0 =[`row_contribution`]_r,q - [`bd_key`]_r+1.[`row_contribution`]_r,q and [`row_contribution`]_r+1,q,1 =[`bd_key`]_r+1.[`row_contribution`]_r,q. +/// The work of each iteration therefore doubles relative to the one preceding. +/// +/// In case a malicious entity sends an out-of-range breakdown key (i.e. greater than the max count) to this function, we need to do some +/// extra processing to ensure contribution doesn't end up in a wrong bucket. However, this requires extra multiplications. +/// This would potentially not be needed in IPA (as the breakdown key is provided by the report collector, so a bad value only spoils their own result) but useful for PAM. +/// This can be done by passing `robust` as true. +pub async fn move_single_value_to_bucket( + ctx: C, + record_id: RecordId, + bd_key: BitDecomposed, + value: S, + breakdown_count: usize, + robust: bool, +) -> Result, Error> +where + BK: GaloisField, + C: UpgradedContext, + S: LinearSecretSharing + Serializable + SecureMul, + F: PrimeField + ExtendableField, +{ + let mut step: usize = 1 << BK::BITS; + + assert!( + breakdown_count <= 1 << BK::BITS, + "Asking for more buckets ({breakdown_count}) than bits in the key ({}) allow", + BK::BITS + ); + assert!( + breakdown_count <= 128, + "Our step implementation (BitOpStep) cannot go past 64" + ); + let mut row_contribution = vec![value; breakdown_count]; + + for (tree_depth, bit_of_bdkey) in bd_key.iter().enumerate().rev() { + let span = step >> 1; + if !robust && span > breakdown_count { + step = span; + continue; + } + + let depth_c = ctx.narrow(&BinaryTreeDepthStep::from(tree_depth)); + let mut futures = Vec::with_capacity(breakdown_count / step); + + for (i, tree_index) in (0..breakdown_count).step_by(step).enumerate() { + let bit_c = depth_c.narrow(&BitOpStep::from(i)); + + if robust || tree_index + span 
< breakdown_count { + futures.push(row_contribution[tree_index].multiply(bit_of_bdkey, bit_c, record_id)); + } + } + let contributions = ctx.parallel_join(futures).await?; + + for (index, bdbit_contribution) in contributions.into_iter().enumerate() { + let left_index = index * step; + let right_index = left_index + span; + + row_contribution[left_index] -= &bdbit_contribution; + if right_index < breakdown_count { + row_contribution[right_index] = bdbit_contribution; + } + } + step = span; + } + Ok(row_contribution) +} + +#[cfg(all(test, unit_test))] +pub mod tests { + use rand::thread_rng; + + use crate::{ + ff::{Field, Fp32BitPrime, Gf5Bit, Gf8Bit}, + protocol::{ + context::{Context, UpgradableContext, Validator}, + prf_sharding::bucket::move_single_value_to_bucket, + RecordId, + }, + rand::Rng, + secret_sharing::SharedValue, + test_executor::run, + test_fixture::{get_bits, Reconstruct, Runner, TestWorld}, + }; + + const MAX_BREAKDOWN_COUNT: usize = 1 << Gf5Bit::BITS; + const VALUE: u32 = 10; + + async fn move_to_bucket(count: usize, breakdown_key: usize, robust: bool) -> Vec { + let breakdown_key_bits = + get_bits::(breakdown_key.try_into().unwrap(), Gf5Bit::BITS); + let value = Fp32BitPrime::truncate_from(VALUE); + + TestWorld::default() + .semi_honest( + (breakdown_key_bits, value), + |ctx, (breakdown_key_share, value_share)| async move { + let validator = ctx.validator(); + let ctx = validator.context(); + move_single_value_to_bucket::( + ctx.set_total_records(1), + RecordId::from(0), + breakdown_key_share, + value_share, + count, + robust, + ) + .await + .unwrap() + }, + ) + .await + .reconstruct() + } + + #[test] + fn semi_honest_move_in_range() { + run(|| async move { + let mut rng = thread_rng(); + let count = rng.gen_range(1..MAX_BREAKDOWN_COUNT); + let breakdown_key = rng.gen_range(0..count); + let mut expected = vec![Fp32BitPrime::ZERO; count]; + expected[breakdown_key] = Fp32BitPrime::truncate_from(VALUE); + + let result = move_to_bucket(count, 
breakdown_key, false).await; + assert_eq!(result, expected, "expected value at index {breakdown_key}"); + }); + } + + #[test] + fn semi_honest_move_in_range_robust() { + run(|| async move { + let mut rng = thread_rng(); + let count = rng.gen_range(1..MAX_BREAKDOWN_COUNT); + let breakdown_key = rng.gen_range(0..count); + let mut expected = vec![Fp32BitPrime::ZERO; count]; + expected[breakdown_key] = Fp32BitPrime::truncate_from(VALUE); + + let result = move_to_bucket(count, breakdown_key, true).await; + assert_eq!(result, expected, "expected value at index {breakdown_key}"); + }); + } + + #[test] + fn semi_honest_move_out_of_range() { + run(move || async move { + let mut rng: rand::rngs::ThreadRng = thread_rng(); + let count = rng.gen_range(2..MAX_BREAKDOWN_COUNT - 1); + let breakdown_key = rng.gen_range(count..MAX_BREAKDOWN_COUNT); + + let result = move_to_bucket(count, breakdown_key, false).await; + assert_eq!(result.len(), count); + assert_eq!( + result.into_iter().sum::(), + Fp32BitPrime::truncate_from(VALUE) + ); + }); + } + + #[test] + fn semi_honest_move_out_of_range_robust() { + run(move || async move { + let mut rng: rand::rngs::ThreadRng = thread_rng(); + let count = rng.gen_range(2..MAX_BREAKDOWN_COUNT - 1); + let breakdown_key = rng.gen_range(count..MAX_BREAKDOWN_COUNT); + + let result = move_to_bucket(count, breakdown_key, true).await; + assert_eq!(result.len(), count); + assert!(result.into_iter().all(|x| x == Fp32BitPrime::ZERO)); + }); + } + + #[test] + #[should_panic] + fn move_out_of_range_too_many_buckets_type() { + run(move || async move { + _ = move_to_bucket(MAX_BREAKDOWN_COUNT + 1, 0, false).await; + }); + } + + #[test] + #[should_panic] + fn move_out_of_range_too_many_buckets_steps() { + run(move || async move { + let breakdown_key_bits = get_bits::(0, Gf8Bit::BITS); + let value = Fp32BitPrime::truncate_from(VALUE); + + _ = TestWorld::default() + .semi_honest( + (breakdown_key_bits, value), + |ctx, (breakdown_key_share, value_share)| async 
move { + let validator = ctx.validator(); + let ctx = validator.context(); + move_single_value_to_bucket::( + ctx.set_total_records(1), + RecordId::from(0), + breakdown_key_share, + value_share, + 129, + false, + ) + .await + .unwrap() + }, + ) + .await; + }); + } +} diff --git a/src/protocol/prf_sharding/feature_label_dot_product.rs b/src/protocol/prf_sharding/feature_label_dot_product.rs index 0dfc28f3a..9ef0b0ad6 100644 --- a/src/protocol/prf_sharding/feature_label_dot_product.rs +++ b/src/protocol/prf_sharding/feature_label_dot_product.rs @@ -261,10 +261,10 @@ where let num_user_rows = rows_for_user.len(); let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned(); let record_ids = record_id_for_row_depth[..num_user_rows].to_owned(); + record_id_for_row_depth[..num_user_rows] + .iter_mut() + .for_each(|count| *count += 1); - for count in &mut record_id_for_row_depth[..num_user_rows] { - *count += 1; - } #[allow(clippy::async_yields_async)] // this is ok, because seq join wants a stream of futures async move { @@ -273,7 +273,7 @@ where })); // Execute all of the async futures (sequentially), and flatten the result - let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits) + let flattened_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits) .flat_map(|x| stream_iter(x.unwrap())); // modulus convert feature vector bits from shares in `Z_2` to shares in `Z_p` @@ -281,7 +281,7 @@ where prime_field_ctx .narrow(&Step::ModulusConvertFeatureVectorBits) .set_total_records(num_outputs), - flattenned_stream, + flattened_stream, 0..FV::BITS, ); diff --git a/src/protocol/prf_sharding/mod.rs b/src/protocol/prf_sharding/mod.rs index 8b35b371f..cd589bc02 100644 --- a/src/protocol/prf_sharding/mod.rs +++ b/src/protocol/prf_sharding/mod.rs @@ -1,21 +1,18 @@ use std::iter::{repeat, zip}; -use embed_doc_image::embed_doc_image; use futures::{stream::iter as stream_iter, TryStreamExt}; use futures_util::{future::try_join, StreamExt}; 
use ipa_macros::Step; -use super::{ - basics::if_else, boolean::saturating_sum::SaturatingSum, modulus_conversion::convert_bits, - step::BitOpStep, -}; use crate::{ error::Error, ff::{Field, GaloisField, Gf2, PrimeField, Serializable}, protocol::{ - basics::{SecureMul, ShareKnownValue}, - boolean::or::or, + basics::{if_else, SecureMul, ShareKnownValue}, + boolean::{or::or, saturating_sum::SaturatingSum}, context::{UpgradableContext, UpgradedContext, Validator}, + modulus_conversion::convert_bits, + step::BitOpStep, RecordId, }, secret_sharing::{ @@ -28,6 +25,7 @@ use crate::{ seq_join::{seq_join, seq_try_join_all}, }; +pub mod bucket; #[cfg(feature = "descriptive-gate")] pub mod feature_label_dot_product; @@ -635,7 +633,7 @@ where let record_id: RecordId = RecordId::from(i); let bd_key = bk_bits.unwrap(); async move { - move_single_value_to_bucket::( + bucket::move_single_value_to_bucket::( ctx, record_id, bd_key, @@ -662,86 +660,14 @@ where .await } -#[embed_doc_image("tree-aggregation", "images/tree_aggregation.png")] -/// This function moves a single value to a correct bucket using tree aggregation approach -/// -/// Here is how it works -/// The combined value, [`value`] forms the root of a binary tree as follows: -/// ![Tree propagation][tree-aggregation] -/// -/// This value is propagated through the tree, with each subsequent iteration doubling the number of multiplications. -/// In the first round, r=BK-1, multiply the most significant bit ,[`bd_key`]_r by the value to get [`bd_key`]_r.[`value`]. From that, -/// produce [`row_contribution`]_r,0 =[`value`]-[`bd_key`]_r.[`value`] and [`row_contribution`]_r,1=[`bd_key`]_r.[`value`]. -/// This takes the most significant bit of `bd_key` and places value in one of the two child nodes of the binary tree. 
-/// At each successive round, the next most significant bit is propagated from the leaf nodes of the tree into further leaf nodes: -/// [`row_contribution`]_r+1,q,0 =[`row_contribution`]_r,q - [`bd_key`]_r+1.[`row_contribution`]_r,q and [`row_contribution`]_r+1,q,1 =[`bd_key`]_r+1.[`row_contribution`]_r,q. -/// The work of each iteration therefore doubles relative to the one preceding. -/// -/// In case a malicious entity sends a out of range breakdown key (i.e. greater than the max count) to this function, we need to do some -/// extra processing to ensure contribution doesn't end up in a wrong bucket. However, this requires extra multiplications. -/// This would potentially not be needed in IPA (as aggregation is done after pre-processing which should be able to throw such input) but useful for PAM. -/// This can be by passing `robust_for_breakdown_key_gt_count as true -async fn move_single_value_to_bucket( - ctx: C, - record_id: RecordId, - bd_key: BitDecomposed, - value: S, - breakdown_count: usize, - robust_for_breakdown_key_gt_count: bool, -) -> Result, Error> -where - BK: GaloisField, - C: UpgradedContext, - S: LinearSecretSharing + Serializable + SecureMul, - F: PrimeField + ExtendableField, -{ - let mut step: usize = 1 << BK::BITS; - - assert!(breakdown_count <= 1 << BK::BITS); - let mut row_contribution = vec![value; breakdown_count]; - - for (tree_depth, bit_of_bdkey) in bd_key.iter().rev().enumerate() { - let depth_c = ctx.narrow(&BinaryTreeDepthStep::from(tree_depth)); - let span = step >> 1; - let mut futures = Vec::with_capacity(breakdown_count / step); - - for (i, tree_index) in (0..breakdown_count).step_by(step).enumerate() { - let bit_c = depth_c.narrow(&BitOpStep::from(i)); - - if robust_for_breakdown_key_gt_count || tree_index + span < breakdown_count { - futures.push(row_contribution[tree_index].multiply(bit_of_bdkey, bit_c, record_id)); - } - } - let contributions = ctx.parallel_join(futures).await?; - - for (index, bdbit_contribution) in 
contributions.into_iter().enumerate() { - let left_index = index * step; - let right_index = left_index + span; - - row_contribution[left_index] -= &bdbit_contribution; - if right_index < breakdown_count { - row_contribution[right_index] = bdbit_contribution; - } - } - step = span; - } - Ok(row_contribution) -} - #[cfg(all(test, unit_test))] pub mod tests { - use rand::thread_rng; - use super::{attribution_and_capping, CappedAttributionOutputs, PrfShardedIpaInputRow}; use crate::{ - ff::{Field, Fp32BitPrime, GaloisField, Gf2, Gf3Bit, Gf5Bit, Gf8Bit}, + ff::{Field, Fp32BitPrime, GaloisField, Gf2, Gf3Bit, Gf5Bit}, protocol::{ - context::{Context, UpgradableContext, Validator}, - prf_sharding::{ - attribution_and_capping_and_aggregation, do_aggregation, - move_single_value_to_bucket, - }, - RecordId, + context::{UpgradableContext, Validator}, + prf_sharding::{attribution_and_capping_and_aggregation, do_aggregation}, }, rand::Rng, secret_sharing::{ @@ -1070,93 +996,4 @@ pub mod tests { assert_eq!(result, &expected); }); } - - #[test] - fn semi_honest_move_value_to_single_bucket_in_range() { - const MAX_BREAKDOWN_COUNT: usize = 127; - for _ in 1..10 { - run(|| async move { - let world = TestWorld::default(); - let mut rng: rand::rngs::ThreadRng = thread_rng(); - let count = rng.gen_range(1..MAX_BREAKDOWN_COUNT); - let breakdown_key = rng.gen_range(1..count); - - let value: Fp32BitPrime = Fp32BitPrime::truncate_from(10_u128); - let mut expected = vec![Fp32BitPrime::truncate_from(0_u128); count]; - expected[breakdown_key] = value; - - let breakdown_key_bits = - get_bits::(breakdown_key.try_into().unwrap(), Gf8Bit::BITS); - - let result: Vec<_> = world - .semi_honest( - (breakdown_key_bits, value), - |ctx, (breakdown_key_share, value_share)| async move { - let validator = ctx.validator(); - let ctx = validator.context(); - move_single_value_to_bucket::( - ctx.set_total_records(1), - RecordId::from(0), - breakdown_key_share, - value_share, - count, - false, - ) - .await - 
.unwrap() - }, - ) - .await - .reconstruct(); - assert_eq!(result, expected); - }); - } - } - - #[test] - fn semi_honest_move_value_to_single_bucket_out_of_range() { - const MAX_BREAKDOWN_COUNT: usize = 127; - for robust_for_breakdown_key_gt_count in [true, false] { - for _ in 1..10 { - run(move || async move { - let world = TestWorld::default(); - let mut rng: rand::rngs::ThreadRng = thread_rng(); - let count = rng.gen_range(1..MAX_BREAKDOWN_COUNT); - let breakdown_key = rng.gen_range(count..MAX_BREAKDOWN_COUNT); - - let value: Fp32BitPrime = Fp32BitPrime::truncate_from(10_u128); - let expected = vec![Fp32BitPrime::truncate_from(0_u128); count]; - - let breakdown_key_bits = - get_bits::(breakdown_key.try_into().unwrap(), Gf8Bit::BITS); - - let result: Vec<_> = world - .semi_honest( - (breakdown_key_bits, value), - |ctx, (breakdown_key_share, value_share)| async move { - let validator = ctx.validator(); - let ctx = validator.context(); - move_single_value_to_bucket::( - ctx.set_total_records(1), - RecordId::from(0), - breakdown_key_share, - value_share, - count, - robust_for_breakdown_key_gt_count, - ) - .await - .unwrap() - }, - ) - .await - .reconstruct(); - if robust_for_breakdown_key_gt_count { - assert_eq!(result, expected); - } else { - assert_ne!(result, expected); - } - }); - } - } - } } From 05626b1975743c2fab15d700d26365788d82cc5c Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Mon, 16 Oct 2023 14:17:01 +1100 Subject: [PATCH 2/2] Remove proptest bump --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 2ba6d8b67..34b4f01d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -106,7 +106,7 @@ cfg_aliases = "0.1.1" command-fds = "0.2.2" hex = "0.4" permutation = "0.4.1" -proptest = "1.3" +proptest = "1" tempfile = "3" tokio-rustls = { version = "0.24.0", features = ["dangerous_configuration"] }