From 1195f5ddb1e50a4002fc36bc82669d5ec64a86eb Mon Sep 17 00:00:00 2001 From: Joey Kraut Date: Wed, 11 Oct 2023 13:05:44 -0700 Subject: [PATCH] algebra: Add sub-modular structure to `algebra` --- Cargo.toml | 1 + benches/circuit_msm_throughput.rs | 4 +--- benches/gate_throughput.rs | 4 ++-- benches/gate_throughput_traced.rs | 4 ++-- benches/growable_buffer.rs | 2 +- benches/native_msm.rs | 2 +- integration/authenticated_curve.rs | 7 ++---- integration/authenticated_scalar.rs | 5 +--- integration/circuits.rs | 5 +--- integration/fabric.rs | 2 +- integration/helpers.rs | 5 ++-- integration/main.rs | 2 +- integration/mpc_curve.rs | 6 +---- integration/mpc_scalar.rs | 5 +--- .../{ => curve}/authenticated_curve.rs | 5 ++-- src/algebra/{ => curve}/curve.rs | 13 +++------- src/algebra/curve/mod.rs | 13 ++++++++++ src/algebra/{ => curve}/mpc_curve.rs | 12 ++++------ src/algebra/mod.rs | 15 ++++++------ src/algebra/poly/authenticated_poly.rs | 24 +++++++++++++++++++ src/algebra/poly/mod.rs | 7 ++++++ .../{ => scalar}/authenticated_scalar.rs | 4 +--- src/algebra/scalar/mod.rs | 13 ++++++++++ src/algebra/{ => scalar}/mpc_scalar.rs | 11 ++++----- src/algebra/{ => scalar}/scalar.rs | 3 +-- src/beaver.rs | 2 +- src/commitment.rs | 5 +--- src/fabric.rs | 9 +++---- src/fabric/result.rs | 2 +- src/lib.rs | 2 +- src/network.rs | 2 +- 31 files changed, 107 insertions(+), 89 deletions(-) rename src/algebra/{ => curve}/authenticated_curve.rs (99%) rename src/algebra/{ => curve}/curve.rs (98%) create mode 100644 src/algebra/curve/mod.rs rename src/algebra/{ => curve}/mpc_curve.rs (98%) create mode 100644 src/algebra/poly/authenticated_poly.rs create mode 100644 src/algebra/poly/mod.rs rename src/algebra/{ => scalar}/authenticated_scalar.rs (99%) create mode 100644 src/algebra/scalar/mod.rs rename src/algebra/{ => scalar}/mpc_scalar.rs (99%) rename src/algebra/{ => scalar}/scalar.rs (99%) diff --git a/Cargo.toml b/Cargo.toml index 6ee3f30..ae1a979 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ path = "src/lib.rs" benchmarks = [] stats = ["benchmarks"] test_helpers = ["ark-bn254"] +poly = [] [[test]] name = "integration" diff --git a/benches/circuit_msm_throughput.rs b/benches/circuit_msm_throughput.rs index a94db3f..aa58aab 100644 --- a/benches/circuit_msm_throughput.rs +++ b/benches/circuit_msm_throughput.rs @@ -2,9 +2,7 @@ use std::time::{Duration, Instant}; -use ark_mpc::{ - algebra::authenticated_curve::AuthenticatedPointResult, test_helpers::execute_mock_mpc, -}; +use ark_mpc::{algebra::AuthenticatedPointResult, test_helpers::execute_mock_mpc}; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use itertools::Itertools; use tokio::runtime::Builder as RuntimeBuilder; diff --git a/benches/gate_throughput.rs b/benches/gate_throughput.rs index fde669e..036a765 100644 --- a/benches/gate_throughput.rs +++ b/benches/gate_throughput.rs @@ -1,8 +1,8 @@ use std::{path::Path, sync::Mutex}; use ark_mpc::{ - algebra::scalar::Scalar, beaver::PartyIDBeaverSource, network::NoRecvNetwork, - test_helpers::TestCurve, MpcFabric, PARTY0, + algebra::Scalar, beaver::PartyIDBeaverSource, network::NoRecvNetwork, test_helpers::TestCurve, + MpcFabric, PARTY0, }; use cpuprofiler::{Profiler as CpuProfiler, PROFILER}; use criterion::{ diff --git a/benches/gate_throughput_traced.rs b/benches/gate_throughput_traced.rs index b95d672..b88b944 100644 --- a/benches/gate_throughput_traced.rs +++ b/benches/gate_throughput_traced.rs @@ -4,8 +4,8 @@ use std::time::Instant; use ark_mpc::{ - algebra::scalar::Scalar, beaver::PartyIDBeaverSource, network::NoRecvNetwork, - test_helpers::TestCurve, MpcFabric, PARTY0, + algebra::Scalar, beaver::PartyIDBeaverSource, network::NoRecvNetwork, test_helpers::TestCurve, + MpcFabric, PARTY0, }; use clap::Parser; use cpuprofiler::PROFILER; diff --git a/benches/growable_buffer.rs b/benches/growable_buffer.rs index dde700f..56e443f 100644 --- a/benches/growable_buffer.rs +++ b/benches/growable_buffer.rs @@ -1,4 +1,4 @@ -use ark_mpc::{algebra::scalar::Scalar, buffer::GrowableBuffer, test_helpers::TestCurve}; +use ark_mpc::{algebra::Scalar, buffer::GrowableBuffer, test_helpers::TestCurve}; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; // -------------- diff --git a/benches/native_msm.rs b/benches/native_msm.rs index 496cf26..fbfb4a5 100644 --- a/benches/native_msm.rs +++ b/benches/native_msm.rs @@ -1,7 +1,7 @@ //! Defines a benchmark for native multiscalar-multiplication on `Scalar` and `CurvePoint` types use ark_mpc::{ - algebra::{curve::CurvePoint, scalar::Scalar}, + algebra::{CurvePoint, Scalar}, random_point, test_helpers::TestCurve, }; diff --git a/integration/authenticated_curve.rs b/integration/authenticated_curve.rs index 8635c5b..71bb7a2 100644 --- a/integration/authenticated_curve.rs +++ b/integration/authenticated_curve.rs @@ -2,11 +2,8 @@ use ark_mpc::{ algebra::{ - authenticated_curve::{ - test_helpers::{modify_mac, modify_public_modifier, modify_share}, - AuthenticatedPointResult, - }, - scalar::Scalar, + curve_test_helpers::{modify_mac, modify_public_modifier, modify_share}, + AuthenticatedPointResult, Scalar, }, random_point, PARTY0, PARTY1, }; diff --git a/integration/authenticated_scalar.rs b/integration/authenticated_scalar.rs index 939717e..5016169 100644 --- a/integration/authenticated_scalar.rs +++ b/integration/authenticated_scalar.rs @@ -2,10 +2,7 @@ //! a malicious-secure primitive use ark_mpc::{ - algebra::{ - authenticated_scalar::{test_helpers::*, AuthenticatedScalarResult}, - scalar::Scalar, - }, + algebra::{scalar_test_helpers::*, AuthenticatedScalarResult, Scalar}, ResultValue, PARTY0, PARTY1, }; use itertools::Itertools; diff --git a/integration/circuits.rs b/integration/circuits.rs index 5a31bc6..3312f0b 100644 --- a/integration/circuits.rs +++ b/integration/circuits.rs @@ -1,10 +1,7 @@ //! Tests for more complicated operations (i.e. circuits) use ark_mpc::{ - algebra::{ - authenticated_curve::AuthenticatedPointResult, - authenticated_scalar::AuthenticatedScalarResult, scalar::Scalar, - }, + algebra::{AuthenticatedPointResult, AuthenticatedScalarResult, Scalar}, random_point, PARTY0, PARTY1, }; use itertools::Itertools; diff --git a/integration/fabric.rs b/integration/fabric.rs index ed955aa..bd5a86f 100644 --- a/integration/fabric.rs +++ b/integration/fabric.rs @@ -1,6 +1,6 @@ //! Defines tests for the fabric directly -use ark_mpc::{algebra::scalar::Scalar, PARTY0, PARTY1}; +use ark_mpc::{algebra::Scalar, PARTY0, PARTY1}; use crate::{ helpers::{assert_scalars_eq, await_result, share_scalar}, diff --git a/integration/helpers.rs b/integration/helpers.rs index e579cee..e305fc3 100644 --- a/integration/helpers.rs +++ b/integration/helpers.rs @@ -4,9 +4,8 @@ use std::fmt::Debug; use ark_mpc::{ algebra::{ - authenticated_curve::AuthenticatedPointResult, - authenticated_scalar::AuthenticatedScalarResult, mpc_curve::MpcPointResult, - mpc_scalar::MpcScalarResult, scalar::Scalar, + AuthenticatedPointResult, AuthenticatedScalarResult, MpcPointResult, MpcScalarResult, + Scalar, }, beaver::SharedValueSource, network::{NetworkPayload, PartyId}, diff --git a/integration/main.rs b/integration/main.rs index 21c2a03..b183696 100644 --- a/integration/main.rs +++ b/integration/main.rs @@ -2,7 +2,7 @@ use std::{borrow::Borrow, io::Write, net::SocketAddr, process::exit, thread, tim use ark_bn254::G1Projective as Bn254Projective; use ark_mpc::{ - algebra::{curve::CurvePoint, scalar::Scalar}, + algebra::{CurvePoint, Scalar}, network::{NetworkOutbound, NetworkPayload, QuicTwoPartyNet}, MpcFabric, PARTY0, }; diff --git a/integration/mpc_curve.rs b/integration/mpc_curve.rs index c5f1c6c..458beab 100644 --- a/integration/mpc_curve.rs +++ b/integration/mpc_curve.rs @@ -1,11 +1,7 @@ //! Defines tests for the `MpcPointResult` type and arithmetic on this type use ark_mpc::{ - algebra::{ - curve::CurvePointResult, - mpc_curve::MpcPointResult, - scalar::{Scalar, ScalarResult}, - }, + algebra::{CurvePointResult, MpcPointResult, Scalar, ScalarResult}, random_point, PARTY0, PARTY1, }; use itertools::Itertools; diff --git a/integration/mpc_scalar.rs b/integration/mpc_scalar.rs index 352c5ba..10fa673 100644 --- a/integration/mpc_scalar.rs +++ b/integration/mpc_scalar.rs @@ -1,9 +1,6 @@ //! Defines unit tests for `MpcScalarResult` types use ark_mpc::{ - algebra::{ - mpc_scalar::MpcScalarResult, - scalar::{Scalar, ScalarResult}, - }, + algebra::{MpcScalarResult, Scalar, ScalarResult}, PARTY0, PARTY1, }; use itertools::Itertools; diff --git a/src/algebra/authenticated_curve.rs b/src/algebra/curve/authenticated_curve.rs similarity index 99% rename from src/algebra/authenticated_curve.rs rename to src/algebra/curve/authenticated_curve.rs index ddef4d0..1f33b85 100644 --- a/src/algebra/authenticated_curve.rs +++ b/src/algebra/curve/authenticated_curve.rs @@ -14,6 +14,8 @@ use futures::{Future, FutureExt}; use itertools::{izip, Itertools}; use crate::{ + algebra::macros::*, + algebra::scalar::*, commitment::{HashCommitment, HashCommitmentResult}, error::MpcError, fabric::{MpcFabric, ResultValue}, @@ -21,11 +23,8 @@ use crate::{ }; use super::{ - authenticated_scalar::AuthenticatedScalarResult, curve::{BatchCurvePointResult, CurvePoint, CurvePointResult}, - macros::{impl_borrow_variants, impl_commutative}, mpc_curve::MpcPointResult, - scalar::{Scalar, ScalarResult}, }; /// The number of underlying results in an `AuthenticatedPointResult` diff --git a/src/algebra/curve.rs b/src/algebra/curve/curve.rs similarity index 98% rename from src/algebra/curve.rs rename to src/algebra/curve/curve.rs index eefe29d..17516a1 100644 --- a/src/algebra/curve.rs +++ b/src/algebra/curve/curve.rs @@ -24,20 +24,13 @@ use serde::{de::Error as DeError, Deserialize, Serialize}; use crate::{ algebra::{ - authenticated_curve::AUTHENTICATED_POINT_RESULT_LEN, - authenticated_scalar::AUTHENTICATED_SCALAR_RESULT_LEN, + macros::*, n_bytes_field, scalar::*, AUTHENTICATED_POINT_RESULT_LEN, + AUTHENTICATED_SCALAR_RESULT_LEN, }, fabric::{ResultHandle, ResultValue}, }; -use super::{ - authenticated_curve::AuthenticatedPointResult, - authenticated_scalar::AuthenticatedScalarResult, - macros::{impl_borrow_variants, impl_commutative}, - mpc_curve::MpcPointResult, - mpc_scalar::MpcScalarResult, - scalar::{n_bytes_field, Scalar, ScalarResult}, -}; +use super::{authenticated_curve::AuthenticatedPointResult, mpc_curve::MpcPointResult}; /// The number of points and scalars to pull from an iterated MSM when /// performing a multiscalar multiplication diff --git a/src/algebra/curve/mod.rs b/src/algebra/curve/mod.rs new file mode 100644 index 0000000..c45f811 --- /dev/null +++ b/src/algebra/curve/mod.rs @@ -0,0 +1,13 @@ +//! Defines curve types for shared authenticated, shared unauthenticated, and plaintext curve points +#![allow(clippy::module_inception)] + +mod authenticated_curve; +mod curve; +mod mpc_curve; + +pub use authenticated_curve::*; +pub use curve::*; +pub use mpc_curve::*; + +#[cfg(feature = "test_helpers")] +pub use authenticated_curve::test_helpers as curve_test_helpers; diff --git a/src/algebra/mpc_curve.rs b/src/algebra/curve/mpc_curve.rs similarity index 98% rename from src/algebra/mpc_curve.rs rename to src/algebra/curve/mpc_curve.rs index 773a80d..9e74d6e 100644 --- a/src/algebra/mpc_curve.rs +++ b/src/algebra/curve/mpc_curve.rs @@ -6,15 +6,13 @@ use std::ops::{Add, Mul, Neg, Sub}; use ark_ec::CurveGroup; use itertools::Itertools; -use crate::{fabric::ResultValue, network::NetworkPayload, MpcFabric, ResultId, PARTY0}; - -use super::{ - curve::{BatchCurvePointResult, CurvePoint, CurvePointResult}, - macros::{impl_borrow_variants, impl_commutative}, - mpc_scalar::MpcScalarResult, - scalar::{Scalar, ScalarResult}, +use crate::{ + algebra::macros::*, algebra::scalar::*, fabric::ResultValue, network::NetworkPayload, + MpcFabric, ResultId, PARTY0, }; +use super::curve::{BatchCurvePointResult, CurvePoint, CurvePointResult}; + /// Defines a secret shared type of a curve point #[derive(Clone, Debug)] pub struct MpcPointResult { diff --git a/src/algebra/mod.rs b/src/algebra/mod.rs index e4a9160..6520411 100644 --- a/src/algebra/mod.rs +++ b/src/algebra/mod.rs @@ -1,9 +1,10 @@ //! Defines algebraic MPC types and operations on them -pub mod authenticated_curve; -pub mod authenticated_scalar; -pub mod curve; -pub mod macros; -pub mod mpc_curve; -pub mod mpc_scalar; -pub mod scalar; +mod curve; +mod macros; +mod poly; +mod scalar; + +pub use curve::*; +pub use poly::*; +pub use scalar::*; diff --git a/src/algebra/poly/authenticated_poly.rs b/src/algebra/poly/authenticated_poly.rs new file mode 100644 index 0000000..6cc16b4 --- /dev/null +++ b/src/algebra/poly/authenticated_poly.rs @@ -0,0 +1,24 @@ +//! An authenticated polynomial over a `CurveGroup`'s scalar field +//! +//! Modeled after the `ark_poly::DensePolynomial` type, but allocated in an MPC fabric + +use ark_ec::CurveGroup; + +use crate::algebra::AuthenticatedScalarResult; + +/// An authenticated polynomial; i.e. a polynomial in which the coefficients are secret +/// shared between parties +/// +/// This is modeled after the `ark_poly::DensePolynomial` [source](https://github.com/arkworks-rs/algebra/blob/master/poly/src/polynomial/univariate/dense.rs#L22) +#[derive(Debug, Clone)] +pub struct AuthenticatedDensePoly { + /// A vector of coefficients, the coefficient for `x^i` is stored at index `i` + pub coeffs: Vec>, +} + +impl AuthenticatedDensePoly { + /// Constructor + pub fn from_coeffs(coeffs: Vec>) -> Self { + Self { coeffs } + } +} diff --git a/src/algebra/poly/mod.rs b/src/algebra/poly/mod.rs new file mode 100644 index 0000000..be70730 --- /dev/null +++ b/src/algebra/poly/mod.rs @@ -0,0 +1,7 @@ +//! Polynomial types over secret shared fields +//! +//! Modeled after the `ark_poly` implementation + +mod authenticated_poly; + +pub use authenticated_poly::*; diff --git a/src/algebra/authenticated_scalar.rs b/src/algebra/scalar/authenticated_scalar.rs similarity index 99% rename from src/algebra/authenticated_scalar.rs rename to src/algebra/scalar/authenticated_scalar.rs index 985c1a0..2ac27be 100644 --- a/src/algebra/authenticated_scalar.rs +++ b/src/algebra/scalar/authenticated_scalar.rs @@ -13,6 +13,7 @@ use futures::{Future, FutureExt}; use itertools::{izip, Itertools}; use crate::{ + algebra::{macros::*, AuthenticatedPointResult, CurvePoint, CurvePointResult}, commitment::{PedersenCommitment, PedersenCommitmentResult}, error::MpcError, fabric::{MpcFabric, ResultId, ResultValue}, @@ -20,9 +21,6 @@ use crate::{ }; use super::{ - authenticated_curve::AuthenticatedPointResult, - curve::{CurvePoint, CurvePointResult}, - macros::{impl_borrow_variants, impl_commutative}, mpc_scalar::MpcScalarResult, scalar::{BatchScalarResult, Scalar, ScalarResult}, }; diff --git a/src/algebra/scalar/mod.rs b/src/algebra/scalar/mod.rs new file mode 100644 index 0000000..494e7a1 --- /dev/null +++ b/src/algebra/scalar/mod.rs @@ -0,0 +1,13 @@ +//! Scalar type arithmetic with shared authenticated, shared non-authenticated, and plaintext types +#![allow(clippy::module_inception)] + +mod authenticated_scalar; +mod mpc_scalar; +mod scalar; + +pub use authenticated_scalar::*; +pub use mpc_scalar::*; +pub use scalar::*; + +#[cfg(feature = "test_helpers")] +pub use authenticated_scalar::test_helpers as scalar_test_helpers; diff --git a/src/algebra/mpc_scalar.rs b/src/algebra/scalar/mpc_scalar.rs similarity index 99% rename from src/algebra/mpc_scalar.rs rename to src/algebra/scalar/mpc_scalar.rs index 9937fce..3d7761d 100644 --- a/src/algebra/mpc_scalar.rs +++ b/src/algebra/scalar/mpc_scalar.rs @@ -7,18 +7,15 @@ use ark_ec::CurveGroup; use itertools::Itertools; use crate::{ - algebra::scalar::BatchScalarResult, + algebra::macros::*, + algebra::BatchScalarResult, + algebra::{CurvePoint, CurvePointResult, MpcPointResult}, fabric::{MpcFabric, ResultValue}, network::NetworkPayload, PARTY0, }; -use super::{ - curve::{CurvePoint, CurvePointResult}, - macros::{impl_borrow_variants, impl_commutative}, - mpc_curve::MpcPointResult, - scalar::{Scalar, ScalarResult}, -}; +use super::scalar::{Scalar, ScalarResult}; /// Defines a secret shared type over the `Scalar` field #[derive(Clone, Debug)] diff --git a/src/algebra/scalar.rs b/src/algebra/scalar/scalar.rs similarity index 99% rename from src/algebra/scalar.rs rename to src/algebra/scalar/scalar.rs index 6fd61a9..315e6ef 100644 --- a/src/algebra/scalar.rs +++ b/src/algebra/scalar/scalar.rs @@ -18,10 +18,9 @@ use num_bigint::BigUint; use rand::{CryptoRng, RngCore}; use serde::{Deserialize, Serialize}; +use crate::algebra::macros::*; use crate::fabric::{ResultHandle, ResultValue}; -use super::macros::{impl_borrow_variants, impl_commutative}; - // ----------- // | Helpers | // ----------- diff --git a/src/beaver.rs b/src/beaver.rs index 7dc6efd..e53e030 100644 --- a/src/beaver.rs +++ b/src/beaver.rs @@ -4,7 +4,7 @@ use ark_ec::CurveGroup; use itertools::Itertools; -use crate::algebra::scalar::Scalar; +use crate::algebra::Scalar; /// SharedValueSource implements both the functionality for: /// 1. Single additively shared values [x] where party 1 holds diff --git a/src/commitment.rs b/src/commitment.rs index f6c0aff..03b5b27 100644 --- a/src/commitment.rs +++ b/src/commitment.rs @@ -6,10 +6,7 @@ use rand::thread_rng; use sha3::{Digest, Sha3_256}; use crate::{ - algebra::{ - curve::{CurvePoint, CurvePointResult}, - scalar::{Scalar, ScalarResult}, - }, + algebra::{CurvePoint, CurvePointResult, Scalar, ScalarResult}, fabric::ResultValue, }; diff --git a/src/fabric.rs b/src/fabric.rs index fef35d4..9bdbb56 100644 --- a/src/fabric.rs +++ b/src/fabric.rs @@ -35,12 +35,9 @@ use itertools::Itertools; use crate::{ algebra::{ - authenticated_curve::AuthenticatedPointResult, - authenticated_scalar::AuthenticatedScalarResult, - curve::{BatchCurvePointResult, CurvePoint, CurvePointResult}, - mpc_curve::MpcPointResult, - mpc_scalar::MpcScalarResult, - scalar::{BatchScalarResult, Scalar, ScalarResult}, + AuthenticatedPointResult, AuthenticatedScalarResult, BatchCurvePointResult, + BatchScalarResult, CurvePoint, CurvePointResult, MpcPointResult, MpcScalarResult, Scalar, + ScalarResult, }, beaver::SharedValueSource, network::{MpcNetwork, NetworkOutbound, NetworkPayload, PartyId}, diff --git a/src/fabric/result.rs b/src/fabric/result.rs index 8d027bf..bbbf587 100644 --- a/src/fabric/result.rs +++ b/src/fabric/result.rs @@ -14,7 +14,7 @@ use ark_ec::CurveGroup; use futures::Future; use crate::{ - algebra::{curve::CurvePoint, scalar::Scalar}, + algebra::{CurvePoint, Scalar}, network::NetworkPayload, Shared, }; diff --git a/src/lib.rs b/src/lib.rs index b0fc54d..8e6f0c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,7 +9,7 @@ use std::sync::{Arc, RwLock}; -use algebra::{curve::CurvePoint, scalar::Scalar}; +use algebra::{CurvePoint, Scalar}; use ark_ec::CurveGroup; use rand::thread_rng; diff --git a/src/network.rs b/src/network.rs index 3ead927..a8d2c7e 100644 --- a/src/network.rs +++ b/src/network.rs @@ -17,7 +17,7 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use crate::{ - algebra::{curve::CurvePoint, scalar::Scalar}, + algebra::{CurvePoint, Scalar}, error::MpcNetworkError, fabric::ResultId, };