Skip to content

Commit

Permalink
sta-rs: MessageGenerator epoch represented by u8 slice instead of string
Browse files Browse the repository at this point in the history
  • Loading branch information
DJAndries committed Nov 25, 2023
1 parent 49a87db commit 8a8154e
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 31 deletions.
2 changes: 1 addition & 1 deletion star-wasm/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "star-wasm"
version = "0.2.0"
version = "0.2.1"
authors = ["Rémi Berson <[email protected]>"]
description = "WASM bindings for the STAR protocol"
repository = "https://github.com/brave/sta-rs"
Expand Down
2 changes: 1 addition & 1 deletion star-wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub fn create_share(measurement: &[u8], threshold: u32, epoch: &str) -> String {
let mg = MessageGenerator::new(
SingleMeasurement::new(measurement),
threshold,
epoch,
epoch.as_bytes(),
);
let share_result = mg.share_with_local_randomness();
if share_result.is_err() {
Expand Down
2 changes: 1 addition & 1 deletion star/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "sta-rs"
version = "0.2.2"
version = "0.3.0"
authors = ["Alex Davidson <[email protected]>"]
description = "Distributed Secret-Sharing for Threshold Aggregation Reporting"
documentation = "https://docs.rs/sta-rs"
Expand Down
10 changes: 5 additions & 5 deletions star/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fn criterion_benchmark(c: &mut Criterion) {

fn benchmark_client_randomness_sampling(c: &mut Criterion) {
c.bench_function("Client local randomness", |b| {
let client = client_zipf(10000, 1.03, 2, "t");
let client = client_zipf(10000, 1.03, 2, b"t");
let mut out = vec![0u8; 32];
b.iter(|| {
client.sample_local_randomness(&mut out);
Expand All @@ -38,7 +38,7 @@ fn benchmark_client_randomness_sampling(c: &mut Criterion) {

fn benchmark_client_triple_generation(c: &mut Criterion) {
c.bench_function("Client generate triple (local)", |b| {
let mg = client_zipf(10000, 1.03, 2, "t");
let mg = client_zipf(10000, 1.03, 2, b"t");
let mut rnd = [0u8; 32];
mg.sample_local_randomness(&mut rnd);
b.iter(|| Message::generate(&mg, &rnd, None).unwrap());
Expand All @@ -57,7 +57,7 @@ fn benchmark_client_triple_generation(c: &mut Criterion) {

c.bench_function("Client generate triple (local, aux)", |b| {
let random_bytes = rand::thread_rng().gen::<[u8; 32]>();
let mg = client_zipf(10000, 1.03, 2, "t");
let mg = client_zipf(10000, 1.03, 2, b"t");
let mut rnd = [0u8; 32];
mg.sample_local_randomness(&mut rnd);
b.iter(|| {
Expand All @@ -81,7 +81,7 @@ fn benchmark_client_triple_generation(c: &mut Criterion) {
}

fn benchmark_server_retrieval(c: &mut Criterion) {
let mg = client_zipf(10000, 1.03, 50, "t");
let mg = client_zipf(10000, 1.03, 50, b"t");
let mut rnd = [0u8; 32];
mg.sample_local_randomness(&mut rnd);
let messages: Vec<Message> =
Expand Down Expand Up @@ -128,7 +128,7 @@ fn benchmark_end_to_end(c: &mut Criterion) {
}

fn get_messages(params: &Params, epoch: &str) -> Vec<Message> {
let mg = client_zipf(params.n, params.s, params.threshold, epoch);
let mg = client_zipf(params.n, params.s, params.threshold, epoch.as_bytes());
let mut rnd = [0u8; 32];
if params.local {
iter::repeat_with(|| {
Expand Down
18 changes: 9 additions & 9 deletions star/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
//! # let threshold = 2;
//! # let epoch = "t";
//! let measurement = SingleMeasurement::new("hello world".as_bytes());
//! let mg = MessageGenerator::new(measurement, threshold, epoch);
//! let mg = MessageGenerator::new(measurement, threshold, epoch.as_bytes());
//! let mut rnd = [0u8; 32];
//! // NOTE: this is for STARLite. Randomness must be sampled from a
//! // randomness server in order to implement the full STAR protocol.
Expand Down Expand Up @@ -67,7 +67,7 @@
//! # let threshold = 2;
//! # let epoch = "t";
//! let measurement = SingleMeasurement::new("hello world".as_bytes());
//! let mg = MessageGenerator::new(measurement, threshold, epoch);
//! let mg = MessageGenerator::new(measurement, threshold, epoch.as_bytes());
//! let mut rnd = [0u8; 32];
//! // NOTE: this is for STARLite. Randomness must be sampled from a
//! // randomness server in order to implement the full STAR protocol.
Expand All @@ -92,7 +92,7 @@
//! # let epoch = "t";
//! # let measurement = SingleMeasurement::new("hello world".as_bytes());
//!
//! # let mg = MessageGenerator::new(measurement, threshold, epoch);
//! # let mg = MessageGenerator::new(measurement, threshold, epoch.as_bytes());
//! # for i in 0..3 {
//! # let mut rnd = [0u8; 32];
//! # mg.sample_local_randomness(&mut rnd);
Expand Down Expand Up @@ -364,14 +364,14 @@ pub struct WASMSharingMaterial {
pub struct MessageGenerator {
pub x: SingleMeasurement,
threshold: u32,
epoch: String,
epoch: Vec<u8>,
}
impl MessageGenerator {
pub fn new(x: SingleMeasurement, threshold: u32, epoch: &str) -> Self {
pub fn new(x: SingleMeasurement, threshold: u32, epoch: &[u8]) -> Self {
Self {
x,
threshold,
epoch: epoch.to_string(),
epoch: epoch.into(),
}
}

Expand Down Expand Up @@ -426,7 +426,7 @@ impl MessageGenerator {

fn derive_key(&self, r1: &[u8]) -> [u8; 16] {
let mut enc_key = [0u8; 16];
derive_ske_key(r1, self.epoch.as_bytes(), &mut enc_key);
derive_ske_key(r1, &self.epoch, &mut enc_key);
enc_key
}

Expand All @@ -445,7 +445,7 @@ impl MessageGenerator {
}
strobe_digest(
self.x.as_slice(),
&[self.epoch.as_bytes(), &self.threshold.to_le_bytes()],
&[&self.epoch, &self.threshold.to_le_bytes()],
"star_sample_local",
out,
);
Expand All @@ -458,7 +458,7 @@ impl MessageGenerator {
out: &mut [u8],
) {
let mds = oprf_server.get_valid_metadata_tags();
let index = mds.iter().position(|r| r == self.epoch.as_bytes()).unwrap();
let index = mds.iter().position(|r| r == &self.epoch).unwrap();
end_to_end_evaluation(oprf_server, self.x.as_slice(), index, true, out);
}
}
Expand Down
2 changes: 1 addition & 1 deletion star/test-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub fn client_zipf(
n: usize,
s: f64,
threshold: u32,
epoch: &str,
epoch: &[u8],
) -> MessageGenerator {
let x = measurement_zipf(n, s);
MessageGenerator::new(x, threshold, epoch)
Expand Down
38 changes: 25 additions & 13 deletions star/tests/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ pub struct PPOPRFServer;

#[test]
fn serialize_ciphertext() {
let mg = MessageGenerator::new(SingleMeasurement::new(b"foobar"), 0, "epoch");
let mg = MessageGenerator::new(
SingleMeasurement::new(b"foobar"),
0,
"epoch".as_bytes(),
);
let mut rnd = [0u8; 32];
mg.sample_local_randomness(&mut rnd);
let triple = Message::generate(&mg, &rnd, None)
Expand All @@ -21,7 +25,11 @@ fn serialize_ciphertext() {

#[test]
fn serialize_triple() {
let mg = MessageGenerator::new(SingleMeasurement::new(b"foobar"), 0, "epoch");
let mg = MessageGenerator::new(
SingleMeasurement::new(b"foobar"),
0,
"epoch".as_bytes(),
);
let mut rnd = [0u8; 32];
mg.sample_local_randomness(&mut rnd);
let triple = Message::generate(&mg, &rnd, None)
Expand All @@ -32,7 +40,11 @@ fn serialize_triple() {

#[test]
fn roundtrip() {
let mg = MessageGenerator::new(SingleMeasurement::new(b"foobar"), 1, "epoch");
let mg = MessageGenerator::new(
SingleMeasurement::new(b"foobar"),
1,
"epoch".as_bytes(),
);
let mut rnd = [0u8; 32];
mg.sample_local_randomness(&mut rnd);
let triple = Message::generate(&mg, &rnd, None)
Expand Down Expand Up @@ -116,19 +128,19 @@ fn star_no_aux_multiple_block(oprf_server: Option<PPOPRFServer>) {
clients.push(MessageGenerator::new(
SingleMeasurement::new(str1.as_bytes()),
threshold,
epoch,
epoch.as_bytes(),
));
} else if i % 4 == 0 {
clients.push(MessageGenerator::new(
SingleMeasurement::new(str2.as_bytes()),
threshold,
epoch,
epoch.as_bytes(),
));
} else {
clients.push(MessageGenerator::new(
SingleMeasurement::new(&[i as u8]),
threshold,
epoch,
epoch.as_bytes(),
));
}
}
Expand Down Expand Up @@ -177,19 +189,19 @@ fn star_no_aux_single_block(oprf_server: Option<PPOPRFServer>) {
clients.push(MessageGenerator::new(
SingleMeasurement::new(str1.as_bytes()),
threshold,
epoch,
epoch.as_bytes(),
));
} else if i % 4 == 0 {
clients.push(MessageGenerator::new(
SingleMeasurement::new(str2.as_bytes()),
threshold,
epoch,
epoch.as_bytes(),
));
} else {
clients.push(MessageGenerator::new(
SingleMeasurement::new(&[i as u8]),
threshold,
epoch,
epoch.as_bytes(),
));
}
}
Expand Down Expand Up @@ -239,19 +251,19 @@ fn star_with_aux_multiple_block(oprf_server: Option<PPOPRFServer>) {
clients.push(MessageGenerator::new(
SingleMeasurement::new(str1.as_bytes()),
threshold,
epoch,
epoch.as_bytes(),
));
} else if i % 4 == 0 {
clients.push(MessageGenerator::new(
SingleMeasurement::new(str2.as_bytes()),
threshold,
epoch,
epoch.as_bytes(),
));
} else {
clients.push(MessageGenerator::new(
SingleMeasurement::new(&[i as u8]),
threshold,
epoch,
epoch.as_bytes(),
));
}
}
Expand Down Expand Up @@ -313,7 +325,7 @@ fn star_rand_with_aux_multiple_block(oprf_server: Option<PPOPRFServer>) {
let threshold = 5;
let epoch = "t";
for _ in 0..254 {
clients.push(client_zipf(1000, 1.03, threshold, epoch));
clients.push(client_zipf(1000, 1.03, threshold, epoch.as_bytes()));
}
let agg_server = AggregationServer::new(threshold, epoch);

Expand Down

0 comments on commit 8a8154e

Please sign in to comment.