Skip to content

Commit

Permalink
online-phase: algebra: Allocate single result for share and MAC
Browse files Browse the repository at this point in the history
This is done to allow the offline phase correclty supply all necessary
values for a given result.
  • Loading branch information
joeykraut committed Apr 15, 2024
1 parent a47604d commit acdb983
Show file tree
Hide file tree
Showing 13 changed files with 1,052 additions and 2,576 deletions.
652 changes: 236 additions & 416 deletions online-phase/src/algebra/curve/authenticated_curve.rs

Large diffs are not rendered by default.

157 changes: 39 additions & 118 deletions online-phase/src/algebra/curve/curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ use itertools::Itertools;
use serde::{de::Error as DeError, Deserialize, Serialize};

use crate::{
algebra::{macros::*, scalar::*, ToBytes, AUTHENTICATED_POINT_RESULT_LEN},
algebra::{macros::*, scalar::*, PointShare, ToBytes},
fabric::{ResultHandle, ResultValue},
};

use super::{authenticated_curve::AuthenticatedPointResult, mpc_curve::MpcPointResult};
use super::authenticated_curve::AuthenticatedPointResult;

/// The number of points and scalars to pull from an iterated MSM when
/// performing a multi-scalar multiplication
Expand Down Expand Up @@ -478,41 +478,6 @@ impl<C: CurveGroup> CurvePointResult<C> {
})
}

/// Multiply a batch of `MpcScalarResult`s with a batch of
/// `CurvePointResult<C>`s
pub fn batch_mul_shared(
a: &[MpcScalarResult<C>],
b: &[CurvePointResult<C>],
) -> Vec<MpcPointResult<C>> {
assert_eq!(
a.len(),
b.len(),
"batch_mul_shared cannot compute on vectors of unequal length"
);

let n = a.len();
let fabric = a[0].fabric();

let lhs = a.iter().map(|r| r.id());
let rhs = b.iter().map(|r| r.id);
let all_ids = lhs.interleave(rhs).collect_vec();
fabric
.new_batch_gate_op(all_ids, n /* output_arity */, move |args| {
let mut res = Vec::with_capacity(n);
for mut chunk in &args.chunks(2) {
let lhs: Scalar<C> = chunk.next().unwrap().into();
let rhs: CurvePoint<C> = chunk.next().unwrap().into();

res.push(ResultValue::Point(lhs * rhs));
}

res
})
.into_iter()
.map(MpcPointResult::from)
.collect_vec()
}

/// Multiply a batch of `AuthenticatedScalarResult`s with a batch of
/// `CurvePointResult<C>`s
pub fn batch_mul_authenticated(
Expand All @@ -528,34 +493,27 @@ impl<C: CurveGroup> CurvePointResult<C> {
let n = a.len();
let fabric = a[0].fabric();

let chunk_size = AUTHENTICATED_SCALAR_RESULT_LEN + 1;
let mut all_ids = Vec::with_capacity(n * chunk_size);
let mut all_ids = Vec::with_capacity(n * 2);
for (a, b) in a.iter().zip(b.iter()) {
all_ids.extend(a.ids());
all_ids.push(a.id());
all_ids.push(b.id);
}

let results = fabric.new_batch_gate_op(
fabric.new_batch_gate_op(
all_ids,
AUTHENTICATED_POINT_RESULT_LEN * n, // output_arity
n, // output_arity
move |args| {
let mut results = Vec::with_capacity(n * AUTHENTICATED_POINT_RESULT_LEN);
for mut chunk in &args.chunks(chunk_size) {
let share = Scalar::from(&chunk.next().unwrap());
let mac = Scalar::from(&chunk.next().unwrap());
let public_modifier = Scalar::from(&chunk.next().unwrap());
let mut results = Vec::with_capacity(n);
for mut chunk in &args.chunks(2) {
let share = ScalarShare::from(chunk.next().unwrap());
let point = CurvePoint::from(&chunk.next().unwrap());

results.push(ResultValue::Point(point * share));
results.push(ResultValue::Point(point * mac));
results.push(ResultValue::Point(point * public_modifier));
results.push(ResultValue::PointShare(point * share));
}

results
},
);

AuthenticatedPointResult::from_flattened_iterator(results.into_iter())
)
}
}

Expand Down Expand Up @@ -669,37 +627,18 @@ impl<C: CurveGroup> CurvePoint<C> {

// Clone points to let the gate closure take ownership
let points = points.to_vec();
let res: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
scalar_ids,
AUTHENTICATED_SCALAR_RESULT_LEN, // output_arity
move |args| {
let mut shares = Vec::with_capacity(n);
let mut macs = Vec::with_capacity(n);
let mut modifiers = Vec::with_capacity(n);

for mut chunk in &args.map(Scalar::from).chunks(AUTHENTICATED_SCALAR_RESULT_LEN) {
shares.push(chunk.next().unwrap());
macs.push(chunk.next().unwrap());
modifiers.push(chunk.next().unwrap());
}

// Compute the MSM of the point
vec![
CurvePoint::msm(&shares, &points),
CurvePoint::msm(&macs, &points),
CurvePoint::msm(&modifiers, &points),
]
.into_iter()
.map(ResultValue::Point)
.collect_vec()
},
);
fabric.new_gate_op(scalar_ids, move |args| {
let mut shares = Vec::with_capacity(n);
let mut macs = Vec::with_capacity(n);
for val in args.into_iter().map(ScalarShare::from) {
shares.push(val.share());
macs.push(val.mac());
}

AuthenticatedPointResult {
share: res[0].to_owned().into(),
mac: res[1].to_owned().into(),
public_modifier: res[2].to_owned(),
}
let share_msm = CurvePoint::msm(&shares, &points);
let mac_msm = CurvePoint::msm(&macs, &points);
ResultValue::PointShare(PointShare::new(share_msm, mac_msm))
})
}

/// Compute the multiscalar multiplication of the given authenticated
Expand Down Expand Up @@ -768,45 +707,27 @@ impl<C: CurveGroup> CurvePointResult<C> {
let n = scalars.len();
let fabric = scalars[0].fabric();

let chunk_size = AUTHENTICATED_SCALAR_RESULT_LEN + 1;
let mut all_ids = Vec::with_capacity(n * chunk_size);

let mut all_ids = Vec::with_capacity(n * 2);
for (a, b) in scalars.iter().zip(points.iter()) {
all_ids.extend(a.ids());
all_ids.push(b.id);
all_ids.push(a.id());
all_ids.push(b.id());
}

let res = fabric.new_batch_gate_op(
all_ids,
AUTHENTICATED_POINT_RESULT_LEN, // output_arity
move |args| {
let mut shares = Vec::with_capacity(n);
let mut macs = Vec::with_capacity(n);
let mut modifiers = Vec::with_capacity(n);
let mut points = Vec::with_capacity(n);
for mut chunk in &args.chunks(chunk_size) {
shares.push(chunk.next().unwrap().into());
macs.push(chunk.next().unwrap().into());
modifiers.push(chunk.next().unwrap().into());
points.push(chunk.next().unwrap().into());
}

vec![
CurvePoint::msm(&shares, &points),
CurvePoint::msm(&macs, &points),
CurvePoint::msm(&modifiers, &points),
]
.into_iter()
.map(ResultValue::Point)
.collect_vec()
},
);
fabric.new_gate_op(all_ids, move |args| {
let mut shares = Vec::with_capacity(n);
let mut macs = Vec::with_capacity(n);
let mut points = Vec::with_capacity(n);
for mut chunk in &args.chunks(2) {
let share: ScalarShare<C> = chunk.next().unwrap().into();
shares.push(share.share());
macs.push(share.mac());
points.push(chunk.next().unwrap().into());
}

AuthenticatedPointResult {
share: res[0].to_owned().into(),
mac: res[1].to_owned().into(),
public_modifier: res[2].to_owned(),
}
let share_msm = CurvePoint::msm(&shares, &points);
let mac_msm = CurvePoint::msm(&macs, &points);
ResultValue::PointShare(PointShare::new(share_msm, mac_msm))
})
}

/// Compute the multiscalar multiplication of the given
Expand Down
4 changes: 2 additions & 2 deletions online-phase/src/algebra/curve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

mod authenticated_curve;
mod curve;
mod mpc_curve;
mod share;

pub use authenticated_curve::*;
pub use curve::*;
pub use mpc_curve::*;
pub use share::*;

#[cfg(feature = "test_helpers")]
pub use authenticated_curve::test_helpers as curve_test_helpers;
Loading

0 comments on commit acdb983

Please sign in to comment.