Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

online-phase: algebra: Allocate single result for share and MAC #79

Merged
merged 1 commit into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
652 changes: 236 additions & 416 deletions online-phase/src/algebra/curve/authenticated_curve.rs

Large diffs are not rendered by default.

157 changes: 39 additions & 118 deletions online-phase/src/algebra/curve/curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ use itertools::Itertools;
use serde::{de::Error as DeError, Deserialize, Serialize};

use crate::{
algebra::{macros::*, scalar::*, ToBytes, AUTHENTICATED_POINT_RESULT_LEN},
algebra::{macros::*, scalar::*, PointShare, ToBytes},
fabric::{ResultHandle, ResultValue},
};

use super::{authenticated_curve::AuthenticatedPointResult, mpc_curve::MpcPointResult};
use super::authenticated_curve::AuthenticatedPointResult;

/// The number of points and scalars to pull from an iterated MSM when
/// performing a multi-scalar multiplication
Expand Down Expand Up @@ -478,41 +478,6 @@ impl<C: CurveGroup> CurvePointResult<C> {
})
}

/// Multiply a batch of `MpcScalarResult`s with a batch of
/// `CurvePointResult<C>`s
pub fn batch_mul_shared(
a: &[MpcScalarResult<C>],
b: &[CurvePointResult<C>],
) -> Vec<MpcPointResult<C>> {
assert_eq!(
a.len(),
b.len(),
"batch_mul_shared cannot compute on vectors of unequal length"
);

let n = a.len();
let fabric = a[0].fabric();

let lhs = a.iter().map(|r| r.id());
let rhs = b.iter().map(|r| r.id);
let all_ids = lhs.interleave(rhs).collect_vec();
fabric
.new_batch_gate_op(all_ids, n /* output_arity */, move |args| {
let mut res = Vec::with_capacity(n);
for mut chunk in &args.chunks(2) {
let lhs: Scalar<C> = chunk.next().unwrap().into();
let rhs: CurvePoint<C> = chunk.next().unwrap().into();

res.push(ResultValue::Point(lhs * rhs));
}

res
})
.into_iter()
.map(MpcPointResult::from)
.collect_vec()
}

/// Multiply a batch of `AuthenticatedScalarResult`s with a batch of
/// `CurvePointResult<C>`s
pub fn batch_mul_authenticated(
Expand All @@ -528,34 +493,27 @@ impl<C: CurveGroup> CurvePointResult<C> {
let n = a.len();
let fabric = a[0].fabric();

let chunk_size = AUTHENTICATED_SCALAR_RESULT_LEN + 1;
let mut all_ids = Vec::with_capacity(n * chunk_size);
let mut all_ids = Vec::with_capacity(n * 2);
for (a, b) in a.iter().zip(b.iter()) {
all_ids.extend(a.ids());
all_ids.push(a.id());
all_ids.push(b.id);
}

let results = fabric.new_batch_gate_op(
fabric.new_batch_gate_op(
all_ids,
AUTHENTICATED_POINT_RESULT_LEN * n, // output_arity
n, // output_arity
move |args| {
let mut results = Vec::with_capacity(n * AUTHENTICATED_POINT_RESULT_LEN);
for mut chunk in &args.chunks(chunk_size) {
let share = Scalar::from(&chunk.next().unwrap());
let mac = Scalar::from(&chunk.next().unwrap());
let public_modifier = Scalar::from(&chunk.next().unwrap());
let mut results = Vec::with_capacity(n);
for mut chunk in &args.chunks(2) {
let share = ScalarShare::from(chunk.next().unwrap());
let point = CurvePoint::from(&chunk.next().unwrap());

results.push(ResultValue::Point(point * share));
results.push(ResultValue::Point(point * mac));
results.push(ResultValue::Point(point * public_modifier));
results.push(ResultValue::PointShare(point * share));
}

results
},
);

AuthenticatedPointResult::from_flattened_iterator(results.into_iter())
)
}
}

Expand Down Expand Up @@ -669,37 +627,18 @@ impl<C: CurveGroup> CurvePoint<C> {

// Clone points to let the gate closure take ownership
let points = points.to_vec();
let res: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
scalar_ids,
AUTHENTICATED_SCALAR_RESULT_LEN, // output_arity
move |args| {
let mut shares = Vec::with_capacity(n);
let mut macs = Vec::with_capacity(n);
let mut modifiers = Vec::with_capacity(n);

for mut chunk in &args.map(Scalar::from).chunks(AUTHENTICATED_SCALAR_RESULT_LEN) {
shares.push(chunk.next().unwrap());
macs.push(chunk.next().unwrap());
modifiers.push(chunk.next().unwrap());
}

// Compute the MSM of the point
vec![
CurvePoint::msm(&shares, &points),
CurvePoint::msm(&macs, &points),
CurvePoint::msm(&modifiers, &points),
]
.into_iter()
.map(ResultValue::Point)
.collect_vec()
},
);
fabric.new_gate_op(scalar_ids, move |args| {
let mut shares = Vec::with_capacity(n);
let mut macs = Vec::with_capacity(n);
for val in args.into_iter().map(ScalarShare::from) {
shares.push(val.share());
macs.push(val.mac());
}

AuthenticatedPointResult {
share: res[0].to_owned().into(),
mac: res[1].to_owned().into(),
public_modifier: res[2].to_owned(),
}
let share_msm = CurvePoint::msm(&shares, &points);
let mac_msm = CurvePoint::msm(&macs, &points);
ResultValue::PointShare(PointShare::new(share_msm, mac_msm))
})
}

/// Compute the multiscalar multiplication of the given authenticated
Expand Down Expand Up @@ -768,45 +707,27 @@ impl<C: CurveGroup> CurvePointResult<C> {
let n = scalars.len();
let fabric = scalars[0].fabric();

let chunk_size = AUTHENTICATED_SCALAR_RESULT_LEN + 1;
let mut all_ids = Vec::with_capacity(n * chunk_size);

let mut all_ids = Vec::with_capacity(n * 2);
for (a, b) in scalars.iter().zip(points.iter()) {
all_ids.extend(a.ids());
all_ids.push(b.id);
all_ids.push(a.id());
all_ids.push(b.id());
}

let res = fabric.new_batch_gate_op(
all_ids,
AUTHENTICATED_POINT_RESULT_LEN, // output_arity
move |args| {
let mut shares = Vec::with_capacity(n);
let mut macs = Vec::with_capacity(n);
let mut modifiers = Vec::with_capacity(n);
let mut points = Vec::with_capacity(n);
for mut chunk in &args.chunks(chunk_size) {
shares.push(chunk.next().unwrap().into());
macs.push(chunk.next().unwrap().into());
modifiers.push(chunk.next().unwrap().into());
points.push(chunk.next().unwrap().into());
}

vec![
CurvePoint::msm(&shares, &points),
CurvePoint::msm(&macs, &points),
CurvePoint::msm(&modifiers, &points),
]
.into_iter()
.map(ResultValue::Point)
.collect_vec()
},
);
fabric.new_gate_op(all_ids, move |args| {
let mut shares = Vec::with_capacity(n);
let mut macs = Vec::with_capacity(n);
let mut points = Vec::with_capacity(n);
for mut chunk in &args.chunks(2) {
let share: ScalarShare<C> = chunk.next().unwrap().into();
shares.push(share.share());
macs.push(share.mac());
points.push(chunk.next().unwrap().into());
}

AuthenticatedPointResult {
share: res[0].to_owned().into(),
mac: res[1].to_owned().into(),
public_modifier: res[2].to_owned(),
}
let share_msm = CurvePoint::msm(&shares, &points);
let mac_msm = CurvePoint::msm(&macs, &points);
ResultValue::PointShare(PointShare::new(share_msm, mac_msm))
})
}

/// Compute the multiscalar multiplication of the given
Expand Down
4 changes: 2 additions & 2 deletions online-phase/src/algebra/curve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

mod authenticated_curve;
mod curve;
mod mpc_curve;
mod share;

pub use authenticated_curve::*;
pub use curve::*;
pub use mpc_curve::*;
pub use share::*;

#[cfg(feature = "test_helpers")]
pub use authenticated_curve::test_helpers as curve_test_helpers;
Loading
Loading