diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index d1b88af9f..bbd913031 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -25,6 +25,7 @@ use crate::{ labels::{ROLE, STEP}, metrics::{BYTES_SENT, RECORDS_SENT}, }, + utils::non_zero_prev_power_of_two, }; /// Sending end of the gateway channel. @@ -256,14 +257,6 @@ impl SendChannelConfig { total_records: TotalRecords, record_size: usize, ) -> Self { - // this computes the greatest positive power of 2 that is - // less than or equal to target. - fn non_zero_prev_power_of_two(target: usize) -> usize { - let bits = usize::BITS - target.leading_zeros(); - - 1 << (std::cmp::max(1, bits) - 1) - } - assert!(record_size > 0, "Message size cannot be 0"); let total_capacity = gateway_config.active.get() * record_size; diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 63a412265..fd616fc9c 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -20,7 +20,7 @@ use crate::{ }, ipa_prf::{ validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, - LargeProofGenerator, SmallProofGenerator, + CompressedProofGenerator, FirstProofGenerator, }, Gate, RecordId, RecordIdRange, }, @@ -50,6 +50,9 @@ const BIT_ARRAY_SHIFT: usize = BIT_ARRAY_LEN.ilog2() as usize; // A smaller value is used for tests, to enable covering some corner cases with a // reasonable runtime. Some of these tests use TARGET_PROOF_SIZE directly, so for tests // it does need to be a power of two. +// +// TARGET_PROOF_SIZE is closely related to MAX_PROOF_RECURSION; see the assertion that +// `uv_values.len() <= max_uv_values` in `ProofBatch` for more detail. #[cfg(test)] pub const TARGET_PROOF_SIZE: usize = 8192; #[cfg(not(test))] @@ -73,7 +76,7 @@ pub const TARGET_PROOF_SIZE: usize = 50_000_000; // to blocks of 256), leaving some margin is advised. // // The implementation requires that MAX_PROOF_RECURSION is at least 2. -pub const MAX_PROOF_RECURSION: usize = 9; +pub const MAX_PROOF_RECURSION: usize = 14; /// `MultiplicationInputsBlock` is a block of fixed size of intermediate values /// that occur duringa multiplication. @@ -601,8 +604,8 @@ impl Batch { ctx: Base<'_, B>, batch_index: usize, ) -> Result<(), Error> { - const PRSS_RECORDS_PER_BATCH: usize = LargeProofGenerator::PROOF_LENGTH - + (MAX_PROOF_RECURSION - 1) * SmallProofGenerator::PROOF_LENGTH + const PRSS_RECORDS_PER_BATCH: usize = FirstProofGenerator::PROOF_LENGTH + + (MAX_PROOF_RECURSION - 1) * CompressedProofGenerator::PROOF_LENGTH + 2; // P and Q masks let proof_ctx = ctx.narrow(&Step::GenerateProof); diff --git a/ipa-core/src/protocol/hybrid/oprf.rs b/ipa-core/src/protocol/hybrid/oprf.rs index adc1da9ff..bf1a14f3b 100644 --- a/ipa-core/src/protocol/hybrid/oprf.rs +++ b/ipa-core/src/protocol/hybrid/oprf.rs @@ -1,3 +1,5 @@ +use std::cmp::max; + use futures::{stream, StreamExt, TryStreamExt}; use typenum::Const; @@ -17,8 +19,9 @@ use crate::{ protocol::{ basics::{BooleanProtocols, Reveal}, context::{ - dzkp_validator::DZKPValidator, reshard_try_stream, DZKPUpgraded, MacUpgraded, - MaliciousProtocolSteps, ShardedContext, UpgradableContext, Validator, + dzkp_validator::{DZKPValidator, TARGET_PROOF_SIZE}, + reshard_try_stream, DZKPUpgraded, MacUpgraded, MaliciousProtocolSteps, ShardedContext, + UpgradableContext, Validator, }, hybrid::step::HybridStep, ipa_prf::{ @@ -34,7 +37,9 @@ use crate::{ Vectorizable, }, seq_join::seq_join, + utils::non_zero_prev_power_of_two, }; + // In theory, we could support (runtime-configured breakdown count) ≤ (compile-time breakdown count) // ≤ 2^|bk|, with all three values distinct, but at present, there is no runtime configuration and // the latter two must be equal. The implementation of `move_single_value_to_bucket` does support a @@ -64,13 +69,17 @@ pub const CONV_CHUNK: usize = 256; /// Vectorization dimension for PRF pub const PRF_CHUNK: usize = 16; -// We expect 2*256 = 512 gates in total for two additions per conversion. The vectorization factor -// is CONV_CHUNK. Let `len` equal the number of converted shares. The total amount of -// multiplications is CONV_CHUNK*512*len. We want CONV_CHUNK*512*len ≈ 50M, or len ≈ 381, for a -// reasonably-sized proof. There is also a constraint on proof chunks to be powers of two, so -// we pick the closest power of two close to 381 but less than that value. 256 gives us around 33M -// multiplications per batch -const CONV_PROOF_CHUNK: usize = 256; +/// Returns a suitable proof chunk size (in records) for use with `convert_to_fp25519`. +/// +/// We expect 2*256 = 512 gates in total for two additions per conversion. The +/// vectorization factor is `CONV_CHUNK`. Let `len` equal the number of converted +/// shares. The total amount of multiplications is `CONV_CHUNK`*512*len. We want +/// `CONV_CHUNK`*512*len ≈ 50M for a reasonably-sized proof. There is also a constraint +/// on proof chunks to be powers of two, and we don't want to compute a proof chunk +/// of zero when `TARGET_PROOF_SIZE` is smaller for tests. +fn conv_proof_chunk() -> usize { + non_zero_prev_power_of_two(max(2, TARGET_PROOF_SIZE / CONV_CHUNK / 512)) +} /// This computes the Dodis-Yampolsky PRF value on every match key from input, /// and reshards the reports according to the computed PRF. At the end, reports with the @@ -101,7 +110,7 @@ where protocol: &HybridStep::ConvertFp25519, validate: &HybridStep::ConvertFp25519Validate, }, - CONV_PROOF_CHUNK, + conv_proof_chunk(), ); let m_ctx = validator.context(); diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs index d3e78fd8c..6a9adb345 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs @@ -25,6 +25,7 @@ use crate::{ replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, FieldSimd, SharedValue, TransposeFrom, Vectorizable, }, + utils::non_zero_prev_power_of_two, }; pub(crate) mod breakdown_reveal; @@ -96,8 +97,14 @@ pub type AggResult = Result /// saturating the output) is: /// /// $\sum_{i = 1}^k 2^{k - i} (b + i - 1) \approx 2^k (b + 1) = N (b + 1)$ +/// +/// We set a floor of 2 to avoid computing a chunk of zero when `TARGET_PROOF_SIZE` is +/// smaller for tests. pub fn aggregate_values_proof_chunk(input_width: usize, input_item_bits: usize) -> usize { - max(2, TARGET_PROOF_SIZE / input_width / (input_item_bits + 1)).next_power_of_two() + non_zero_prev_power_of_two(max( + 2, + TARGET_PROOF_SIZE / input_width / (input_item_bits + 1), + )) } // This is the step count for AggregateChunkStep. We need it to size RecordId arrays. diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs index 2dabdc3f4..b98fe9612 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs @@ -381,7 +381,7 @@ mod tests { helpers::stream::process_slice_by_chunks, protocol::{ context::{dzkp_validator::DZKPValidator, UpgradableContext, TEST_DZKP_STEPS}, - ipa_prf::{CONV_CHUNK, CONV_PROOF_CHUNK, PRF_CHUNK}, + ipa_prf::{conv_proof_chunk, CONV_CHUNK, PRF_CHUNK}, }, rand::thread_rng, secret_sharing::SharedValue, @@ -415,7 +415,7 @@ mod tests { let [res0, res1, res2] = world .semi_honest(records.into_iter(), |ctx, records| async move { let c_ctx = ctx.set_total_records((COUNT + CONV_CHUNK - 1) / CONV_CHUNK); - let validator = &c_ctx.dzkp_validator(TEST_DZKP_STEPS, CONV_PROOF_CHUNK); + let validator = &c_ctx.dzkp_validator(TEST_DZKP_STEPS, conv_proof_chunk()); let m_ctx = validator.context(); seq_join( m_ctx.active_work(), diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs index af451b458..26ce5c5d0 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs @@ -1,14 +1,15 @@ use std::{borrow::Borrow, iter::zip, marker::PhantomData}; -#[cfg(all(test, unit_test))] -use crate::ff::Fp31; use crate::{ error::Error::{self, DZKPMasks}, ff::{Fp61BitPrime, PrimeField}, helpers::hashing::{compute_hash, hash_to_field}, protocol::{ context::Context, - ipa_prf::malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + ipa_prf::{ + malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + CompressedProofGenerator, + }, prss::SharedRandomness, RecordId, RecordIdRange, }, @@ -84,8 +85,8 @@ where // compute final uv values let (u_values, v_values) = &mut self.uv_chunks[0]; // shift first element to last position - u_values[SmallProofGenerator::RECURSION_FACTOR - 1] = u_values[0]; - v_values[SmallProofGenerator::RECURSION_FACTOR - 1] = v_values[0]; + u_values[CompressedProofGenerator::RECURSION_FACTOR - 1] = u_values[0]; + v_values[CompressedProofGenerator::RECURSION_FACTOR - 1] = v_values[0]; // set masks in first position u_values[0] = my_p_mask; v_values[0] = my_q_mask; @@ -105,15 +106,11 @@ pub struct ProofGenerator, } -#[cfg(all(test, unit_test))] -pub type TestProofGenerator = ProofGenerator; - // Compression Factor is L // P, Proof size is 2*L - 1 // M, the number of interpolated points is L - 1 // The reason we need these is that Rust doesn't support basic math operations on const generics -pub type SmallProofGenerator = ProofGenerator; -pub type LargeProofGenerator = ProofGenerator; +pub type SmallProofGenerator = ProofGenerator; impl ProofGenerator { // define constants such that they can be used externally @@ -265,7 +262,7 @@ mod test { context::Context, ipa_prf::malicious_security::{ lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, - prover::{LargeProofGenerator, SmallProofGenerator, TestProofGenerator, UVValues}, + prover::{ProofGenerator, SmallProofGenerator, UVValues}, }, RecordId, RecordIdRange, }, @@ -274,6 +271,9 @@ mod test { test_fixture::{Runner, TestWorld}, }; + type TestProofGenerator = ProofGenerator; + type LargeProofGenerator = ProofGenerator; + fn zip_chunks(a: I, b: J) -> UVValues where I: IntoIterator, diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index f4225a0d8..0ffe2adbc 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -1,4 +1,4 @@ -use std::{convert::Infallible, iter::zip, num::NonZeroU32, ops::Add}; +use std::{cmp::max, convert::Infallible, iter::zip, num::NonZeroU32, ops::Add}; use futures::{stream, StreamExt, TryStreamExt}; use generic_array::{ArrayLength, GenericArray}; @@ -24,8 +24,8 @@ use crate::{ protocol::{ basics::{BooleanArrayMul, BooleanProtocols, Reveal}, context::{ - dzkp_validator::DZKPValidator, DZKPUpgraded, MacUpgraded, MaliciousProtocolSteps, - UpgradableContext, + dzkp_validator::{DZKPValidator, TARGET_PROOF_SIZE}, + DZKPUpgraded, MacUpgraded, MaliciousProtocolSteps, UpgradableContext, }, ipa_prf::{ boolean_ops::convert_to_fp25519, @@ -44,6 +44,7 @@ use crate::{ BitDecomposed, FieldSimd, SharedValue, TransposeFrom, Vectorizable, }, seq_join::seq_join, + utils::non_zero_prev_power_of_two, }; pub(crate) mod aggregation; @@ -58,7 +59,9 @@ pub(crate) mod shuffle; pub(crate) mod step; pub mod validation_protocol; -pub use malicious_security::prover::{LargeProofGenerator, SmallProofGenerator}; +pub type FirstProofGenerator = malicious_security::prover::SmallProofGenerator; +pub type CompressedProofGenerator = malicious_security::prover::SmallProofGenerator; + pub use shuffle::Shuffle; /// Match key type @@ -409,13 +412,17 @@ where Ok(noisy_output_histogram) } -// We expect 2*256 = 512 gates in total for two additions per conversion. The vectorization factor -// is CONV_CHUNK. Let `len` equal the number of converted shares. The total amount of -// multiplications is CONV_CHUNK*512*len. We want CONV_CHUNK*512*len ≈ 50M, or len ≈ 381, for a -// reasonably-sized proof. There is also a constraint on proof chunks to be powers of two, so -// we pick the closest power of two close to 381 but less than that value. 256 gives us around 33M -// multiplications per batch -const CONV_PROOF_CHUNK: usize = 256; +/// Returns a suitable proof chunk size (in records) for use with `convert_to_fp25519`. +/// +/// We expect 2*256 = 512 gates in total for two additions per conversion. The +/// vectorization factor is `CONV_CHUNK`. Let `len` equal the number of converted +/// shares. The total amount of multiplications is `CONV_CHUNK`*512*len. We want +/// `CONV_CHUNK`*512*len ≈ 50M for a reasonably-sized proof. There is also a constraint +/// on proof chunks to be powers of two, and we don't want to compute a proof chunk +/// of zero when `TARGET_PROOF_SIZE` is smaller for tests. +fn conv_proof_chunk() -> usize { + non_zero_prev_power_of_two(max(2, TARGET_PROOF_SIZE / CONV_CHUNK / 512)) +} #[tracing::instrument(name = "compute_prf_for_inputs", skip_all)] async fn compute_prf_for_inputs( @@ -443,7 +450,7 @@ where protocol: &Step::ConvertFp25519, validate: &Step::ConvertFp25519Validate, }, - CONV_PROOF_CHUNK, + conv_proof_chunk(), ); let m_ctx = validator.context(); diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index f6bf2e339..3617ab767 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -52,6 +52,7 @@ use crate::{ replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing}, BitDecomposed, FieldSimd, SharedValue, TransposeFrom, Vectorizable, }, + utils::non_zero_prev_power_of_two, }; pub mod feature_label_dot_product; @@ -515,7 +516,10 @@ where // TODO: this override was originally added to work around problems with // read_size vs. batch size alignment. Those are now fixed (in #1332), but this // is still observed to help performance (see #1376), so has been retained. - std::cmp::min(sh_ctx.active_work().get(), chunk_size.next_power_of_two()), + std::cmp::min( + sh_ctx.active_work().get(), + non_zero_prev_power_of_two(chunk_size), + ), ); dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap()); let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?; diff --git a/ipa-core/src/protocol/ipa_prf/quicksort.rs b/ipa-core/src/protocol/ipa_prf/quicksort.rs index 943dfb1ec..26131224d 100644 --- a/ipa-core/src/protocol/ipa_prf/quicksort.rs +++ b/ipa-core/src/protocol/ipa_prf/quicksort.rs @@ -30,6 +30,7 @@ use crate::{ Vectorizable, }, seq_join::seq_join, + utils::non_zero_prev_power_of_two, }; impl ChunkBuffer for (Vec>, Vec>) @@ -98,7 +99,7 @@ where } fn quicksort_proof_chunk(key_bits: usize) -> usize { - (TARGET_PROOF_SIZE / key_bits / SORT_CHUNK).next_power_of_two() + non_zero_prev_power_of_two(TARGET_PROOF_SIZE / key_bits / SORT_CHUNK) } /// Insecure quicksort using MPC comparisons and a key extraction function `get_key`. diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs index cb2754e5f..0a658614e 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs @@ -1,9 +1,8 @@ -use std::{array, iter::zip}; +use std::{array, iter::zip, ops::Mul}; -use typenum::{UInt, UTerm, Unsigned, B0, B1}; +use typenum::{Unsigned, U, U8}; use crate::{ - const_assert_eq, error::Error, ff::{Fp61BitPrime, Serializable}, helpers::{Direction, MpcMessage, TotalRecords}, @@ -13,9 +12,9 @@ use crate::{ dzkp_validator::MAX_PROOF_RECURSION, Context, }, - ipa_prf::malicious_security::{ - lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, - prover::{LargeProofGenerator, SmallProofGenerator}, + ipa_prf::{ + malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + CompressedProofGenerator, FirstProofGenerator, }, prss::SharedRandomness, RecordId, RecordIdRange, @@ -25,8 +24,8 @@ use crate::{ /// This a `ProofBatch` generated by a prover. pub struct ProofBatch { - pub first_proof: [Fp61BitPrime; LargeProofGenerator::PROOF_LENGTH], - pub proofs: Vec<[Fp61BitPrime; SmallProofGenerator::PROOF_LENGTH]>, + pub first_proof: [Fp61BitPrime; FirstProofGenerator::PROOF_LENGTH], + pub proofs: Vec<[Fp61BitPrime; CompressedProofGenerator::PROOF_LENGTH]>, } impl FromIterator for ProofBatch { @@ -35,10 +34,11 @@ impl FromIterator for ProofBatch { // consume the first P elements let first_proof = iterator .by_ref() - .take(LargeProofGenerator::PROOF_LENGTH) - .collect::<[Fp61BitPrime; LargeProofGenerator::PROOF_LENGTH]>(); + .take(FirstProofGenerator::PROOF_LENGTH) + .collect::<[Fp61BitPrime; FirstProofGenerator::PROOF_LENGTH]>(); // consume the rest - let proofs = iterator.collect::>(); + let proofs = + iterator.collect::>(); ProofBatch { first_proof, proofs, @@ -51,7 +51,8 @@ impl ProofBatch { #[allow(clippy::len_without_is_empty)] #[must_use] pub fn len(&self) -> usize { - self.proofs.len() * SmallProofGenerator::PROOF_LENGTH + LargeProofGenerator::PROOF_LENGTH + FirstProofGenerator::PROOF_LENGTH + + self.proofs.len() * CompressedProofGenerator::PROOF_LENGTH } #[allow(clippy::unnecessary_box_returns)] // clippy bug? `Array` exceeds unnecessary-box-size @@ -89,19 +90,19 @@ impl ProofBatch { C: Context, I: Iterator> + Clone, { - const LRF: usize = LargeProofGenerator::RECURSION_FACTOR; - const LLL: usize = LargeProofGenerator::LAGRANGE_LENGTH; - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; - const SLL: usize = SmallProofGenerator::LAGRANGE_LENGTH; - const SPL: usize = SmallProofGenerator::PROOF_LENGTH; + const FRF: usize = FirstProofGenerator::RECURSION_FACTOR; + const FLL: usize = FirstProofGenerator::LAGRANGE_LENGTH; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; + const CLL: usize = CompressedProofGenerator::LAGRANGE_LENGTH; + const CPL: usize = CompressedProofGenerator::PROOF_LENGTH; // precomputation for first proof - let first_denominator = CanonicalLagrangeDenominator::::new(); - let first_lagrange_table = LagrangeTable::::from(first_denominator); + let first_denominator = CanonicalLagrangeDenominator::::new(); + let first_lagrange_table = LagrangeTable::::from(first_denominator); // generate first proof from input iterator let (mut uv_values, first_proof_from_left, my_first_proof_left_share) = - LargeProofGenerator::gen_artefacts_from_recursive_step( + FirstProofGenerator::gen_artefacts_from_recursive_step( ctx, &mut prss_record_ids, &first_lagrange_table, @@ -110,9 +111,9 @@ impl ProofBatch { // `MAX_PROOF_RECURSION - 2` because: // * The first level of recursion has already happened. - // * We need (SRF - 1) at the last level to have room for the masks. + // * We need (CRF - 1) at the last level to have room for the masks. let max_uv_values: usize = - (SRF - 1) * SRF.pow(u32::try_from(MAX_PROOF_RECURSION - 2).unwrap()); + (CRF - 1) * CRF.pow(u32::try_from(MAX_PROOF_RECURSION - 2).unwrap()); assert!( uv_values.len() <= max_uv_values, "Proof batch is too large: have {} uv_values, max is {}", @@ -122,9 +123,9 @@ impl ProofBatch { // storage for other proofs let mut my_proofs_left_shares = - Vec::<[Fp61BitPrime; SPL]>::with_capacity(MAX_PROOF_RECURSION - 1); + Vec::<[Fp61BitPrime; CPL]>::with_capacity(MAX_PROOF_RECURSION - 1); let mut shares_of_proofs_from_prover_left = - Vec::<[Fp61BitPrime; SPL]>::with_capacity(MAX_PROOF_RECURSION - 1); + Vec::<[Fp61BitPrime; CPL]>::with_capacity(MAX_PROOF_RECURSION - 1); // generate masks // Prover `P_i` and verifier `P_{i-1}` both compute p(x) @@ -138,25 +139,25 @@ impl ProofBatch { let (q_mask_from_left_prover, my_q_mask) = ctx.prss().generate_fields(prss_record_ids.expect_next()); - let denominator = CanonicalLagrangeDenominator::::new(); - let lagrange_table = LagrangeTable::::from(denominator); + let denominator = CanonicalLagrangeDenominator::::new(); + let lagrange_table = LagrangeTable::::from(denominator); // The last recursion can only include (λ - 1) u/v value pairs, because it needs to put the - // masks in the constant term. If we compress to `uv_values.len() == SRF`, then we need to - // do two more iterations: compressing SRF u/v values to 1 pair of (unmasked) u/v values, + // masks in the constant term. If we compress to `uv_values.len() == CRF`, then we need to + // do two more iterations: compressing CRF u/v values to 1 pair of (unmasked) u/v values, // and then compressing that pair and the masks to the final u/v value. // // There is a test for this corner case in validation.rs. let mut did_set_masks = false; - // recursively generate proofs via SmallProofGenerator + // recursively generate proofs via CompressedProofGenerator while !did_set_masks { - if uv_values.len() < SRF { + if uv_values.len() < CRF { did_set_masks = true; uv_values.set_masks(my_p_mask, my_q_mask).unwrap(); } let (uv_values_new, share_of_proof_from_prover_left, my_proof_left_share) = - SmallProofGenerator::gen_artefacts_from_recursive_step( + CompressedProofGenerator::gen_artefacts_from_recursive_step( ctx, &mut prss_record_ids, &lagrange_table, @@ -235,25 +236,25 @@ impl ProofBatch { inputs: I, ) -> impl Iterator< Item = ( - [Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR], - [Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR], + [Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR], + [Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR], ), > + Clone where I: Iterator> + Clone, { - assert_eq!(BLOCK_SIZE % LargeProofGenerator::RECURSION_FACTOR, 0); + assert_eq!(BLOCK_SIZE % FirstProofGenerator::RECURSION_FACTOR, 0); inputs.flat_map(|(u_block, v_block)| { - (0usize..(BLOCK_SIZE / LargeProofGenerator::RECURSION_FACTOR)).map(move |i| { + (0usize..(BLOCK_SIZE / FirstProofGenerator::RECURSION_FACTOR)).map(move |i| { ( - <[Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR]>::try_from( - &u_block[i * LargeProofGenerator::RECURSION_FACTOR - ..(i + 1) * LargeProofGenerator::RECURSION_FACTOR], + <[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>::try_from( + &u_block[i * FirstProofGenerator::RECURSION_FACTOR + ..(i + 1) * FirstProofGenerator::RECURSION_FACTOR], ) .unwrap(), - <[Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR]>::try_from( - &v_block[i * LargeProofGenerator::RECURSION_FACTOR - ..(i + 1) * LargeProofGenerator::RECURSION_FACTOR], + <[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>::try_from( + &v_block[i * FirstProofGenerator::RECURSION_FACTOR + ..(i + 1) * FirstProofGenerator::RECURSION_FACTOR], ) .unwrap(), ) @@ -262,21 +263,12 @@ impl ProofBatch { } } -const_assert_eq!( - MAX_PROOF_RECURSION, - 9, - "following impl valid only for MAX_PROOF_RECURSION = 9" -); - -#[rustfmt::skip] -type U1464 = UInt, B0>, B1>, B1>, B0>, B1>, B1>, B1>, B0>, B0>, B0>; - -const ARRAY_LEN: usize = 183; +const ARRAY_LEN: usize = FirstProofGenerator::PROOF_LENGTH + + (MAX_PROOF_RECURSION - 1) * CompressedProofGenerator::PROOF_LENGTH; type Array = [Fp61BitPrime; ARRAY_LEN]; impl Serializable for Box { - type Size = U1464; + type Size = as Mul>::Output; type DeserializationError = ::DeserializationError; diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs index b197dcfa3..3f656b3eb 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs @@ -5,7 +5,7 @@ use std::{ use futures_util::future::{try_join, try_join4}; use subtle::ConstantTimeEq; -use typenum::{Unsigned, U288, U80}; +use typenum::{Unsigned, U120, U448}; use crate::{ const_assert_eq, @@ -20,11 +20,11 @@ use crate::{ dzkp_validator::MAX_PROOF_RECURSION, step::DzkpProofVerifyStep as Step, Context, }, ipa_prf::{ - malicious_security::{ - prover::{LargeProofGenerator, SmallProofGenerator}, - verifier::{compute_g_differences, recursively_compute_final_check}, + malicious_security::verifier::{ + compute_g_differences, recursively_compute_final_check, }, validation_protocol::proof_generation::ProofBatch, + CompressedProofGenerator, FirstProofGenerator, }, RecordId, }, @@ -45,10 +45,10 @@ use crate::{ #[derive(Debug)] #[allow(clippy::struct_field_names)] pub struct BatchToVerify { - first_proof_from_left_prover: [Fp61BitPrime; LargeProofGenerator::PROOF_LENGTH], - first_proof_from_right_prover: [Fp61BitPrime; LargeProofGenerator::PROOF_LENGTH], - proofs_from_left_prover: Vec<[Fp61BitPrime; SmallProofGenerator::PROOF_LENGTH]>, - proofs_from_right_prover: Vec<[Fp61BitPrime; SmallProofGenerator::PROOF_LENGTH]>, + first_proof_from_left_prover: [Fp61BitPrime; FirstProofGenerator::PROOF_LENGTH], + first_proof_from_right_prover: [Fp61BitPrime; FirstProofGenerator::PROOF_LENGTH], + proofs_from_left_prover: Vec<[Fp61BitPrime; CompressedProofGenerator::PROOF_LENGTH]>, + proofs_from_right_prover: Vec<[Fp61BitPrime; CompressedProofGenerator::PROOF_LENGTH]>, p_mask_from_right_prover: Fp61BitPrime, q_mask_from_left_prover: Fp61BitPrime, } @@ -105,13 +105,13 @@ impl BatchToVerify { where C: Context, { - const LRF: usize = LargeProofGenerator::RECURSION_FACTOR; - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; + const FRF: usize = FirstProofGenerator::RECURSION_FACTOR; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; // exclude for first proof - let exclude_large = u128::try_from(LRF).unwrap(); + let exclude_large = u128::try_from(FRF).unwrap(); // exclude for other proofs - let exclude_small = u128::try_from(SRF).unwrap(); + let exclude_small = u128::try_from(CRF).unwrap(); // generate hashes let my_hashes_prover_left = ProofHashes::generate_hashes(self, Direction::Left); @@ -175,17 +175,17 @@ impl BatchToVerify { U: Iterator + Send, V: Iterator + Send, { - const LRF: usize = LargeProofGenerator::RECURSION_FACTOR; - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; + const FRF: usize = FirstProofGenerator::RECURSION_FACTOR; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; // compute p_r - let p_r_right_prover = recursively_compute_final_check::<_, _, LRF, SRF>( + let p_r_right_prover = recursively_compute_final_check::<_, _, FRF, CRF>( u_from_right_prover.into_iter(), challenges_for_right_prover, self.p_mask_from_right_prover, ); // compute q_r - let q_r_left_prover = recursively_compute_final_check::<_, _, LRF, SRF>( + let q_r_left_prover = recursively_compute_final_check::<_, _, FRF, CRF>( v_from_left_prover.into_iter(), challenges_for_left_prover, self.q_mask_from_left_prover, @@ -242,11 +242,11 @@ impl BatchToVerify { where C: Context, { - const LRF: usize = LargeProofGenerator::RECURSION_FACTOR; - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; + const FRF: usize = FirstProofGenerator::RECURSION_FACTOR; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; - const LPL: usize = LargeProofGenerator::PROOF_LENGTH; - const SPL: usize = SmallProofGenerator::PROOF_LENGTH; + const FPL: usize = FirstProofGenerator::PROOF_LENGTH; + const CPL: usize = CompressedProofGenerator::PROOF_LENGTH; let p_times_q_right = Self::compute_p_times_q( ctx.narrow(&Step::PTimesQ), @@ -257,7 +257,7 @@ impl BatchToVerify { .await?; // add Zero for p_times_q and sum since they are not secret shared - let diff_left = compute_g_differences::<_, SPL, SRF, LPL, LRF>( + let diff_left = compute_g_differences::<_, CPL, CRF, FPL, FRF>( &self.first_proof_from_left_prover, &self.proofs_from_left_prover, challenges_for_left_prover, @@ -265,7 +265,7 @@ impl BatchToVerify { Fp61BitPrime::ZERO, ); - let diff_right = compute_g_differences::<_, SPL, SRF, LPL, LRF>( + let diff_right = compute_g_differences::<_, CPL, CRF, FPL, FRF>( &self.first_proof_from_right_prover, &self.proofs_from_right_prover, challenges_for_right_prover, @@ -375,12 +375,12 @@ impl ProofHashes { const_assert_eq!( MAX_PROOF_RECURSION, - 9, - "following impl valid only for MAX_PROOF_RECURSION = 9" + 14, + "following impl valid only for MAX_PROOF_RECURSION = 14" ); impl Serializable for [Hash; MAX_PROOF_RECURSION] { - type Size = U288; + type Size = U448; type DeserializationError = ::DeserializationError; @@ -409,14 +409,14 @@ impl MpcMessage for [Hash; MAX_PROOF_RECURSION] {} const_assert_eq!( MAX_PROOF_RECURSION, - 9, - "following impl valid only for MAX_PROOF_RECURSION = 9" + 14, + "following impl valid only for MAX_PROOF_RECURSION = 14" ); type ProofDiff = [Fp61BitPrime; MAX_PROOF_RECURSION + 1]; impl Serializable for ProofDiff { - type Size = U80; + type Size = U120; type DeserializationError = ::DeserializationError; @@ -459,10 +459,10 @@ pub mod test { ipa_prf::{ malicious_security::{ lagrange::CanonicalLagrangeDenominator, - prover::{LargeProofGenerator, SmallProofGenerator}, verifier::{compute_sum_share, interpolate_at_r}, }, validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, + CompressedProofGenerator, FirstProofGenerator, }, prss::SharedRandomness, RecordId, RecordIdRange, @@ -484,7 +484,7 @@ pub mod test { // first proof has correct length assert_eq!( left_verifier.first_proof_from_left_prover.len(), - LargeProofGenerator::PROOF_LENGTH + FirstProofGenerator::PROOF_LENGTH ); assert_eq!( left_verifier.first_proof_from_left_prover.len(), @@ -494,7 +494,7 @@ pub mod test { for i in 0..left_verifier.proofs_from_left_prover.len() { assert_eq!( (i, left_verifier.proofs_from_left_prover[i].len()), - (i, SmallProofGenerator::PROOF_LENGTH) + (i, CompressedProofGenerator::PROOF_LENGTH) ); assert_eq!( (i, left_verifier.proofs_from_left_prover[i].len()), @@ -513,29 +513,29 @@ pub mod test { // check first proof, // compute simple proof without lagrange interpolated points let simple_proof = { - let block_to_polynomial = BLOCK_SIZE / LargeProofGenerator::RECURSION_FACTOR; + let block_to_polynomial = BLOCK_SIZE / FirstProofGenerator::RECURSION_FACTOR; let simple_proof_uv = (0usize..100 * block_to_polynomial) .map(|i| { ( - (LargeProofGenerator::RECURSION_FACTOR * i - ..LargeProofGenerator::RECURSION_FACTOR * (i + 1)) + (FirstProofGenerator::RECURSION_FACTOR * i + ..FirstProofGenerator::RECURSION_FACTOR * (i + 1)) .map(|j| Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h) - .collect::<[Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR]>(), - (LargeProofGenerator::RECURSION_FACTOR * i - ..LargeProofGenerator::RECURSION_FACTOR * (i + 1)) + .collect::<[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>(), + (FirstProofGenerator::RECURSION_FACTOR * i + ..FirstProofGenerator::RECURSION_FACTOR * (i + 1)) .map(|j| Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h) - .collect::<[Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR]>(), + .collect::<[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>(), ) }) .collect::>(); simple_proof_uv.iter().fold( - [Fp61BitPrime::ZERO; LargeProofGenerator::RECURSION_FACTOR], + [Fp61BitPrime::ZERO; FirstProofGenerator::RECURSION_FACTOR], |mut acc, (left, right)| { - for i in 0..LargeProofGenerator::RECURSION_FACTOR { + for i in 0..FirstProofGenerator::RECURSION_FACTOR { acc[i] += left[i] * right[i]; } acc @@ -558,7 +558,7 @@ pub mod test { (h.as_u128(), simple_proof.to_vec()), ( h.as_u128(), - proof_computed[0..LargeProofGenerator::RECURSION_FACTOR].to_vec() + proof_computed[0..FirstProofGenerator::RECURSION_FACTOR].to_vec() ) ); } @@ -777,9 +777,9 @@ pub mod test { } fn assert_batch(left: &BatchToVerify, right: &BatchToVerify, challenges: &[Fp61BitPrime]) { - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; - const SPL: usize = SmallProofGenerator::PROOF_LENGTH; - const LPL: usize = LargeProofGenerator::PROOF_LENGTH; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; + const CPL: usize = CompressedProofGenerator::PROOF_LENGTH; + const FPL: usize = FirstProofGenerator::PROOF_LENGTH; let first = recombine( &left.first_proof_from_left_prover, @@ -791,19 +791,19 @@ pub mod test { .zip(right.proofs_from_right_prover.iter()) .map(|(left, right)| recombine(left, right)) .collect::>(); - let denominator_first = CanonicalLagrangeDenominator::<_, LPL>::new(); - let denominator = CanonicalLagrangeDenominator::<_, SPL>::new(); + let denominator_first = CanonicalLagrangeDenominator::<_, FPL>::new(); + let denominator = CanonicalLagrangeDenominator::<_, CPL>::new(); let length = others.len(); let mut out = interpolate_at_r(&first, &challenges[0], &denominator_first); for (i, proof) in others.iter().take(length - 1).enumerate() { - assert_eq!((i, out), (i, compute_sum_share::<_, SRF, SPL>(proof))); + assert_eq!((i, out), (i, compute_sum_share::<_, CRF, CPL>(proof))); out = interpolate_at_r(proof, &challenges[i + 1], &denominator); } // last sum without masks let masks = others[length - 1][0]; - let last_sum = compute_sum_share::<_, SRF, SPL>(&others[length - 1]); + let last_sum = compute_sum_share::<_, CRF, CPL>(&others[length - 1]); assert_eq!(out, last_sum - masks); } @@ -869,7 +869,7 @@ pub mod test { let denominator = CanonicalLagrangeDenominator::< Fp61BitPrime, - { SmallProofGenerator::PROOF_LENGTH }, + { CompressedProofGenerator::PROOF_LENGTH }, >::new(); let g_r_left = interpolate_at_r( @@ -988,9 +988,15 @@ pub mod test { // Test a batch that exercises the case where `uv_values.len() == 1` but `did_set_masks = // false` in `ProofBatch::generate`. - verify_batch( - LargeProofGenerator::RECURSION_FACTOR * SmallProofGenerator::RECURSION_FACTOR - / BLOCK_SIZE, - ); + // + // We divide by `BLOCK_SIZE` here because `generate_u_v`, which is used by + // `verify_batch` to generate test data, generates `len` chunks of u/v values of + // length `BLOCK_SIZE`. We want the input u/v values to compress to exactly one + // u/v pair after some number of proof steps. + let num_inputs = FirstProofGenerator::RECURSION_FACTOR + * CompressedProofGenerator::RECURSION_FACTOR + * CompressedProofGenerator::RECURSION_FACTOR; + assert!(num_inputs % BLOCK_SIZE == 0); + verify_batch(num_inputs / BLOCK_SIZE); } } diff --git a/ipa-core/src/utils/mod.rs b/ipa-core/src/utils/mod.rs index e8dfd95ae..6829f57fa 100644 --- a/ipa-core/src/utils/mod.rs +++ b/ipa-core/src/utils/mod.rs @@ -4,4 +4,4 @@ pub mod arraychunks; mod power_of_two; #[cfg(target_pointer_width = "64")] -pub use power_of_two::NonZeroU32PowerOfTwo; +pub use power_of_two::{non_zero_prev_power_of_two, NonZeroU32PowerOfTwo}; diff --git a/ipa-core/src/utils/power_of_two.rs b/ipa-core/src/utils/power_of_two.rs index a84455c92..b34ac0423 100644 --- a/ipa-core/src/utils/power_of_two.rs +++ b/ipa-core/src/utils/power_of_two.rs @@ -68,9 +68,19 @@ impl NonZeroU32PowerOfTwo { } } +/// Returns the largest power of two less than or equal to `target`. +/// +/// Returns 1 if `target` is zero. +pub fn non_zero_prev_power_of_two(target: usize) -> usize { + let bits = usize::BITS - target.leading_zeros(); + + 1 << (std::cmp::max(1, bits) - 1) +} + #[cfg(all(test, unit_test))] mod tests { use super::{ConvertError, NonZeroU32PowerOfTwo}; + use crate::utils::power_of_two::non_zero_prev_power_of_two; #[test] fn rejects_invalid_values() { @@ -107,4 +117,19 @@ mod tests { "3".parse::().unwrap_err() ); } + + #[test] + fn test_prev_power_of_two() { + const TWO_EXP_62: usize = 1usize << (usize::BITS - 2); + const TWO_EXP_63: usize = 1usize << (usize::BITS - 1); + assert_eq!(non_zero_prev_power_of_two(0), 1usize); + assert_eq!(non_zero_prev_power_of_two(1), 1usize); + assert_eq!(non_zero_prev_power_of_two(2), 2usize); + assert_eq!(non_zero_prev_power_of_two(3), 2usize); + assert_eq!(non_zero_prev_power_of_two(4), 4usize); + assert_eq!(non_zero_prev_power_of_two(TWO_EXP_63 - 1), TWO_EXP_62); + assert_eq!(non_zero_prev_power_of_two(TWO_EXP_63), TWO_EXP_63); + assert_eq!(non_zero_prev_power_of_two(TWO_EXP_63 + 1), TWO_EXP_63); + assert_eq!(non_zero_prev_power_of_two(usize::MAX), TWO_EXP_63); + } }