Skip to content

Commit

Permalink
Add function to shuffle vector with a seed.
Browse files Browse the repository at this point in the history
  • Loading branch information
ple13 committed Jan 9, 2024
1 parent 6345767 commit a1650d0
Show file tree
Hide file tree
Showing 10 changed files with 306 additions and 9 deletions.
17 changes: 17 additions & 0 deletions src/main/cc/any_sketch/crypto/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,20 @@ cc_binary(
"@com_google_absl//absl/strings",
],
)

# Library exposing ShuffleWithSeed (shuffle.h / shuffle.cc): a seeded shuffle of
# a vector. The deps mirror the headers included by shuffle.cc and shuffle.h:
# the OpenSSL-backed pseudorandom generator, the PrngSeed proto, absl::Status
# and the common_cpp status macros.
cc_library(
name = "shuffle",
srcs = [
"shuffle.cc",
],
hdrs = [
"shuffle.h",
],
strip_include_prefix = _INCLUDE_PREFIX,
deps = [
"//src/main/cc/math:open_ssl_uniform_random_generator",
"//src/main/proto/wfa/any_sketch:secret_share_cc_proto",
"@com_google_absl//absl/status:status",
"@wfa_common_cpp//src/main/cc/common_cpp/macros",
],
)
65 changes: 65 additions & 0 deletions src/main/cc/any_sketch/crypto/shuffle.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2024 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "any_sketch/crypto/shuffle.h"

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "common_cpp/macros/macros.h"
#include "math/open_ssl_uniform_random_generator.h"

namespace wfa::measurement::common::crypto {

using math::CreatePrngFromSeed;
using math::OpenSslUniformPseudorandomGenerator;
using math::UniformPseudorandomGenerator;

// Shuffles the vector data in place using the Fisher-Yates approach, drawing
// all randomness from a pseudorandom generator initialized with seed. The
// result is therefore fully determined by seed and data.size().
//
// Let n be the size of data, the Fisher-Yates shuffle is as below.
//   For i = (n-1) down to 1:
//     Draws a random index j in the range [0; i].
//     Swaps data[i] and data[j].
//
// Returns OkStatus on success. Inputs of size 0 or 1 are returned untouched
// without looking at seed. Otherwise any error from creating the generator or
// from generating the random bytes is propagated to the caller.
absl::Status ShuffleWithSeed(std::vector<uint32_t>& data,
                             const PrngSeed& seed) {
  // Does nothing if the input is empty or has size 1.
  if (data.size() <= 1) {
    return absl::OkStatus();
  }

  // Initializes the pseudorandom generator using the provided seed.
  ASSIGN_OR_RETURN(std::unique_ptr<UniformPseudorandomGenerator> prng,
                   CreatePrngFromSeed(seed));

  // Samples one 128-bit random value per position of data. The value for
  // position 0 is never consumed, but it is still generated so that the sample
  // used for position i stays at byte offset i * kBytesPerSample and the
  // permutation obtained from a given seed does not change.
  constexpr size_t kBytesPerSample = sizeof(unsigned __int128);
  ASSIGN_OR_RETURN(
      std::vector<unsigned char> random_bytes,
      prng->GeneratePseudorandomBytes(data.size() * kBytesPerSample));
  if (random_bytes.size() != data.size() * kBytesPerSample) {
    return absl::InternalError(
        "The pseudorandom generator returned an unexpected number of bytes.");
  }

  for (size_t i = data.size() - 1; i >= 1; --i) {
    // Reads the i-th sample with memcpy instead of reinterpreting the byte
    // buffer as an array of 128-bit integers, which would break the aliasing
    // rules and could be a misaligned access.
    unsigned __int128 sample;
    std::memcpy(&sample, random_bytes.data() + i * kBytesPerSample,
                kBytesPerSample);

    // Ideally, to make sure that the sampled permutation is not biased, the
    // sample needs to be re-drawn if sample >= 2^128 - (2^128 % (i+1)).
    // However, the probability that this happens with any i in
    // [1; data.size() - 1] is less than (data.size())^2/2^{128}, which is less
    // than 2^{-40} for any input vector of size less than 2^{43}.
    const size_t index = static_cast<size_t>(sample % (i + 1));

    // Swaps the element at current position with the one at position index.
    std::swap(data[i], data[index]);
  }

  return absl::OkStatus();
}

} // namespace wfa::measurement::common::crypto
34 changes: 34 additions & 0 deletions src/main/cc/any_sketch/crypto/shuffle.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2024 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SHUFFLE_H_
#define SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SHUFFLE_H_

#include <string>
#include <utility>

#include "absl/status/status.h"
#include "math/open_ssl_uniform_random_generator.h"
#include "wfa/any_sketch/secret_share.pb.h"

namespace wfa::measurement::common::crypto {

using any_sketch::PrngSeed;
using wfa::math::UniformPseudorandomGenerator;

// Shuffles `data` in place with a Fisher-Yates shuffle whose random values come
// from a pseudorandom generator created from `seed`.
//
// Returns OkStatus on success. When `data` has fewer than two elements it is
// left unchanged and OkStatus is returned without `seed` being used. Otherwise
// an error from creating the generator (e.g. a key or IV of the wrong length)
// or from generating random bytes is returned.
absl::Status ShuffleWithSeed(std::vector<uint32_t>& data, const PrngSeed& seed);

} // namespace wfa::measurement::common::crypto

#endif // SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SHUFFLE_H_
1 change: 1 addition & 0 deletions src/main/cc/math/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ cc_library(
strip_include_prefix = _INCLUDE_PREFIX,
deps = [
":uniform_pseudorandom_generator",
"//src/main/proto/wfa/any_sketch:secret_share_cc_proto",
"@boringssl//:ssl",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down
10 changes: 10 additions & 0 deletions src/main/cc/math/open_ssl_uniform_random_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,14 @@ OpenSslUniformPseudorandomGenerator::GenerateUniformRandomRange(
return ret;
}

// Builds a pseudorandom generator from the key and IV stored in `seed`.
// The two byte strings are copied into byte vectors and handed to
// OpenSslUniformPseudorandomGenerator::Create; its error, if any, is returned
// unchanged.
absl::StatusOr<std::unique_ptr<UniformPseudorandomGenerator>>
CreatePrngFromSeed(const PrngSeed &seed) {
  const auto &raw_key = seed.key();
  const auto &raw_iv = seed.iv();

  std::vector<unsigned char> key_bytes(raw_key.begin(), raw_key.end());
  std::vector<unsigned char> iv_bytes(raw_iv.begin(), raw_iv.end());

  ASSIGN_OR_RETURN(
      std::unique_ptr<UniformPseudorandomGenerator> generator,
      OpenSslUniformPseudorandomGenerator::Create(key_bytes, iv_bytes));
  return generator;
}

} // namespace wfa::math
8 changes: 8 additions & 0 deletions src/main/cc/math/open_ssl_uniform_random_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@
#include "absl/status/statusor.h"
#include "absl/strings/substitute.h"
#include "math/uniform_pseudorandom_generator.h"
#include "wfa/any_sketch/secret_share.pb.h"

namespace wfa::math {

using any_sketch::PrngSeed;

// Key length for EVP_aes_256_ctr.
// See https://www.openssl.org/docs/man1.1.1/man3/EVP_aes_256_ctr.html
inline constexpr int kBytesPerAes256Key = 32;
Expand Down Expand Up @@ -98,6 +101,11 @@ class OpenSslUniformPseudorandomGenerator
EVP_CIPHER_CTX* ctx_;
};

// Creates a pseudorandom generator from a PrngSeed, whose key and IV are used
// to initialize the AES 256 counter mode.
//
// Returns an InvalidArgument error when the key is not kBytesPerAes256Key
// bytes long or the IV is not kBytesPerAes256Iv bytes long.
absl::StatusOr<std::unique_ptr<UniformPseudorandomGenerator>>
CreatePrngFromSeed(const PrngSeed& seed);

} // namespace wfa::math

#endif // SRC_MAIN_CC_MATH_OPEN_SSL_UNIFORM_RANDOM_GENERATOR_H_
14 changes: 7 additions & 7 deletions src/main/proto/wfa/any_sketch/secret_share.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ message SecretShareParameter {
uint32 modulus = 1;
}

message SecretShare {
// Seed to initialize the AES 256 counter mode.
message ShareSeed {
bytes key = 1;
bytes iv = 2;
}
// Seed to initialize the AES 256 counter mode.
message PrngSeed {
// Key for AES 256 counter mode. The C++ generator rejects keys that are not
// exactly 32 bytes long.
bytes key = 1;
// Initialization vector for AES 256 counter mode. The C++ generator rejects
// IVs whose length differs from its kBytesPerAes256Iv constant.
bytes iv = 2;
}

ShareSeed share_seed = 1;
// A secret share, carried as a PRNG seed and/or an explicit vector of values.
// NOTE(review): presumably a share is represented either by the seed it can be
// expanded from or by the explicit vector - confirm with the producers of this
// message.
message SecretShare {
// Seed (AES 256 counter mode key and IV) associated with this share.
PrngSeed share_seed = 1;
// Explicit share values.
repeated uint32 share_vector = 2;
}
14 changes: 14 additions & 0 deletions src/test/cc/any_sketch/crypto/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,17 @@ cc_test(
"@wfa_common_cpp//src/main/cc/common_cpp/testing:status",
],
)

# Unit tests for ShuffleWithSeed (shuffle_test.cc). The AES key/IV size
# constants and absl helpers used by the test are reached through the
# ":shuffle" library's headers rather than through direct deps.
cc_test(
name = "shuffle_test",
size = "small",
srcs = [
"shuffle_test.cc",
],
deps = [
"//src/main/cc/any_sketch/crypto:shuffle",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
"@wfa_common_cpp//src/main/cc/common_cpp/testing:status",
],
)
93 changes: 93 additions & 0 deletions src/test/cc/any_sketch/crypto/shuffle_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright 2024 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "any_sketch/crypto/shuffle.h"

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "common_cpp/testing/status_macros.h"
#include "common_cpp/testing/status_matchers.h"
#include "gtest/gtest.h"

namespace wfa::any_sketch::crypto {
namespace {

using measurement::common::crypto::PrngSeed;
using measurement::common::crypto::ShuffleWithSeed;
using ::wfa::StatusIs;
using ::wfa::math::kBytesPerAes256Iv;
using ::wfa::math::kBytesPerAes256Key;

// An empty vector is accepted even with a default (empty) seed and stays
// empty.
TEST(ShuffleWithSeed, EmptyInputSucceeds) {
  std::vector<uint32_t> data;
  PrngSeed seed;

  absl::Status status = ShuffleWithSeed(data, seed);

  ASSERT_EQ(status, absl::OkStatus());
  ASSERT_EQ(data.size(), 0);
}

// A single-element vector is accepted even with a default (empty) seed and
// its only element is left in place.
TEST(ShuffleWithSeed, InputHasOneElementSucceeds) {
  std::vector<uint32_t> data = {1};
  PrngSeed seed;

  absl::Status status = ShuffleWithSeed(data, seed);

  ASSERT_EQ(status, absl::OkStatus());
  ASSERT_EQ(data.size(), 1);
  EXPECT_EQ(data[0], 1);
}

// With at least two elements the seed is actually used, so a key that is one
// byte too short must be reported as an invalid argument.
TEST(ShuffleWithSeed, InputSizeAtLeastTwoInvalidSeedFails) {
  PrngSeed seed;
  seed.set_key(std::string(kBytesPerAes256Key - 1, 'a'));
  seed.set_iv(std::string(kBytesPerAes256Iv, 'b'));

  std::vector<uint32_t> data = {1, 2};
  absl::Status status = ShuffleWithSeed(data, seed);

  EXPECT_THAT(status,
              StatusIs(absl::StatusCode::kInvalidArgument,
                       absl::Substitute(
                           "The uniform pseudorandom generator key has "
                           "length of $0 bytes but $1 bytes are required.",
                           seed.key().size(), kBytesPerAes256Key)));
}

// Shuffling the sequence 0..kInputSize-1 with a valid seed must succeed, change
// the order of the elements, and keep exactly the same set of elements.
TEST(ShuffleWithSeed, ShufflingSucceeds) {
  PrngSeed seed;
  *seed.mutable_key() = std::string(kBytesPerAes256Key, 'a');
  *seed.mutable_iv() = std::string(kBytesPerAes256Iv, 'b');

  // Unsigned so that comparisons against data.size() and the uint32_t
  // elements do not mix signed and unsigned operands.
  constexpr uint32_t kInputSize = 100;
  std::vector<uint32_t> data(kInputSize);
  for (uint32_t i = 0; i < kInputSize; i++) {
    data[i] = i;
  }
  const std::vector<uint32_t> input = data;

  absl::Status ret = ShuffleWithSeed(data, seed);
  ASSERT_EQ(ret, absl::OkStatus());
  ASSERT_EQ(data.size(), kInputSize);

  // Verifies that the output array is different from the input array.
  // The seed above is fixed, so this check is deterministic. For a uniformly
  // random permutation the chance of getting the identity back is only
  // 1/(kInputSize!).
  EXPECT_NE(input, data);

  // Verifies that the input array and the output array have the same elements.
  std::sort(data.begin(), data.end());
  for (uint32_t i = 0; i < kInputSize; i++) {
    EXPECT_EQ(data[i], i);
  }
}

} // namespace
} // namespace wfa::any_sketch::crypto
59 changes: 57 additions & 2 deletions src/test/cc/math/open_ssl_uniform_random_generator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace wfa::math {
namespace {

TEST(OpenSslUniformPseudorandomGenerator,
CreateTheGeneratorWithValidkeyAndIVSucceeds) {
CreateTheGeneratorWithValidKeyAndIVSucceeds) {
std::vector<unsigned char> key(kBytesPerAes256Key);
std::vector<unsigned char> iv(kBytesPerAes256Iv);
RAND_bytes(key.data(), key.size());
Expand All @@ -35,7 +35,7 @@ TEST(OpenSslUniformPseudorandomGenerator,
}

TEST(OpenSslUniformPseudorandomGenerator,
CreateTheGeneratorWithInvalidkeySizeFails) {
CreateTheGeneratorWithInvalidKeySizeFails) {
std::vector<unsigned char> key(kBytesPerAes256Key - 1);
std::vector<unsigned char> iv(kBytesPerAes256Iv);
RAND_bytes(key.data(), key.size());
Expand Down Expand Up @@ -66,6 +66,61 @@ TEST(OpenSslUniformPseudorandomGenerator,
iv.size(), kBytesPerAes256Iv)));
}

// A seed carrying a key and an IV of the required lengths yields a generator.
TEST(OpenSslUniformPseudorandomGenerator,
     CreateTheGeneratorWithValidSeedSucceeds) {
  std::vector<unsigned char> key_bytes(kBytesPerAes256Key);
  std::vector<unsigned char> iv_bytes(kBytesPerAes256Iv);
  RAND_bytes(key_bytes.data(), key_bytes.size());
  RAND_bytes(iv_bytes.data(), iv_bytes.size());

  PrngSeed seed;
  seed.set_key(std::string(key_bytes.begin(), key_bytes.end()));
  seed.set_iv(std::string(iv_bytes.begin(), iv_bytes.end()));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<UniformPseudorandomGenerator> generator,
                       CreatePrngFromSeed(seed));
}

// A seed whose key is one byte too short is rejected as an invalid argument.
TEST(OpenSslUniformPseudorandomGenerator,
     CreateTheGeneratorFromSeedWithInvalidKeyFails) {
  std::vector<unsigned char> key_bytes(kBytesPerAes256Key - 1);
  std::vector<unsigned char> iv_bytes(kBytesPerAes256Iv);
  RAND_bytes(key_bytes.data(), key_bytes.size());
  RAND_bytes(iv_bytes.data(), iv_bytes.size());

  PrngSeed seed;
  seed.set_key(std::string(key_bytes.begin(), key_bytes.end()));
  seed.set_iv(std::string(iv_bytes.begin(), iv_bytes.end()));

  auto result = CreatePrngFromSeed(seed);

  EXPECT_THAT(
      result.status(),
      StatusIs(absl::StatusCode::kInvalidArgument,
               absl::Substitute("The uniform pseudorandom generator key has "
                                "length of $0 bytes but $1 bytes are required.",
                                key_bytes.size(), kBytesPerAes256Key)));
}

// A seed whose IV is one byte too short is rejected as an invalid argument.
TEST(OpenSslUniformPseudorandomGenerator,
     CreateTheGeneratorFromSeedWithInvalidIVFails) {
  std::vector<unsigned char> key_bytes(kBytesPerAes256Key);
  std::vector<unsigned char> iv_bytes(kBytesPerAes256Iv - 1);
  RAND_bytes(key_bytes.data(), key_bytes.size());
  RAND_bytes(iv_bytes.data(), iv_bytes.size());

  PrngSeed seed;
  seed.set_key(std::string(key_bytes.begin(), key_bytes.end()));
  seed.set_iv(std::string(iv_bytes.begin(), iv_bytes.end()));

  auto result = CreatePrngFromSeed(seed);

  EXPECT_THAT(
      result.status(),
      StatusIs(absl::StatusCode::kInvalidArgument,
               absl::Substitute("The uniform pseudorandom generator IV has "
                                "length of $0 bytes but $1 bytes are required.",
                                iv_bytes.size(), kBytesPerAes256Iv)));
}

TEST(OpenSslUniformPseudorandomGenerator,
GeneratingNonPositiveNumberOfRandomBytesFails) {
std::vector<unsigned char> key(kBytesPerAes256Key);
Expand Down

0 comments on commit a1650d0

Please sign in to comment.