From 966b73282e680029678d0769edebabf5a9c7fe71 Mon Sep 17 00:00:00 2001 From: Phi Date: Fri, 19 Jan 2024 13:17:13 -0500 Subject: [PATCH] Add the function to shuffle a vector with a seed. (#39) --- src/main/cc/any_sketch/crypto/BUILD.bazel | 17 +++ src/main/cc/any_sketch/crypto/shuffle.cc | 64 ++++++++++ src/main/cc/any_sketch/crypto/shuffle.h | 37 ++++++ src/main/cc/math/BUILD.bazel | 1 + .../math/open_ssl_uniform_random_generator.cc | 12 ++ .../math/open_ssl_uniform_random_generator.h | 8 ++ .../proto/wfa/any_sketch/secret_share.proto | 14 +- src/test/cc/any_sketch/crypto/BUILD.bazel | 14 ++ src/test/cc/any_sketch/crypto/shuffle_test.cc | 120 ++++++++++++++++++ .../open_ssl_uniform_random_generator_test.cc | 59 ++++++++- 10 files changed, 337 insertions(+), 9 deletions(-) create mode 100644 src/main/cc/any_sketch/crypto/shuffle.cc create mode 100644 src/main/cc/any_sketch/crypto/shuffle.h create mode 100644 src/test/cc/any_sketch/crypto/shuffle_test.cc diff --git a/src/main/cc/any_sketch/crypto/BUILD.bazel b/src/main/cc/any_sketch/crypto/BUILD.bazel index 7786bb8..366b400 100644 --- a/src/main/cc/any_sketch/crypto/BUILD.bazel +++ b/src/main/cc/any_sketch/crypto/BUILD.bazel @@ -54,3 +54,20 @@ cc_binary( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "shuffle", + srcs = [ + "shuffle.cc", + ], + hdrs = [ + "shuffle.h", + ], + strip_include_prefix = _INCLUDE_PREFIX, + deps = [ + "//src/main/cc/math:open_ssl_uniform_random_generator", + "//src/main/proto/wfa/any_sketch:secret_share_cc_proto", + "@com_google_absl//absl/status", + "@wfa_common_cpp//src/main/cc/common_cpp/macros", + ], +) diff --git a/src/main/cc/any_sketch/crypto/shuffle.cc b/src/main/cc/any_sketch/crypto/shuffle.cc new file mode 100644 index 0000000..986081c --- /dev/null +++ b/src/main/cc/any_sketch/crypto/shuffle.cc @@ -0,0 +1,64 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file 
except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "any_sketch/crypto/shuffle.h" + +#include + +#include "absl/status/status.h" +#include "common_cpp/macros/macros.h" +#include "math/open_ssl_uniform_random_generator.h" + +namespace wfa::measurement::common::crypto { + +absl::Status SecureShuffleWithSeed(std::vector& data, + const any_sketch::PrngSeed& seed) { + // Does nothing if the input is empty or has size 1. + if (data.size() <= 1) { + return absl::OkStatus(); + } + + // Initializes the pseudorandom generator using the provided seed. + ASSIGN_OR_RETURN(std::unique_ptr prng, + math::CreatePrngFromSeed(seed)); + + // The custom implementation of Fisher-Yates shuffle is as below. It is not + // recommended to use std::shuffle because the implementation of std::shuffle + // is not dictated by the standard: even if exactly the same + // UniformRandomBitGenerator is used, different standard library + // implementations could produce different results. + + // Samples all the random values that will be used to compute all the swapping + // indices. + ASSIGN_OR_RETURN( + std::vector arr, + prng->GeneratePseudorandomBytes(data.size() * sizeof(absl::uint128))); + + absl::uint128* rand = (absl::uint128*)arr.data(); + int64_t num_elements = data.size(); + for (int64_t i = 0; i < num_elements - 1; i++) { + // Ideally, to make sure that the sampled permutation is not biased, rand[i] + // needs to be re-sampled if rand[i] >= 2^128 - (2^128 % (num_elements - + // i)). 
However, the probability that this happens with any i in [0; +// data.size() - 2] is less than num_elements^2/2^{128}, which is less than + // 2^{-40} for any input vector of size less than 2^{43}. + uint64_t index = i + static_cast(rand[i] % (num_elements - i)); + // Swaps the element at current position with the one at position index. + std::swap(data[i], data[index]); + } + + return absl::OkStatus(); +} + +} // namespace wfa::measurement::common::crypto diff --git a/src/main/cc/any_sketch/crypto/shuffle.h b/src/main/cc/any_sketch/crypto/shuffle.h new file mode 100644 index 0000000..d447831 --- /dev/null +++ b/src/main/cc/any_sketch/crypto/shuffle.h @@ -0,0 +1,37 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SHUFFLE_H_ +#define SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SHUFFLE_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "wfa/any_sketch/secret_share.pb.h" + +namespace wfa::measurement::common::crypto { + +// Shuffles the vector data using Fisher-Yates approach. Let n be the size of +// data, the Fisher-Yates shuffle is as below. 
+// For i = 0 to (n-2): +// Draws a random value j in the range [i; n-1] +// Swaps data[i] and data[j] +absl::Status SecureShuffleWithSeed(std::vector& data, + const any_sketch::PrngSeed& seed); + +} // namespace wfa::measurement::common::crypto + +#endif // SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SHUFFLE_H_ diff --git a/src/main/cc/math/BUILD.bazel b/src/main/cc/math/BUILD.bazel index 5c5ff99..2baf6e3 100644 --- a/src/main/cc/math/BUILD.bazel +++ b/src/main/cc/math/BUILD.bazel @@ -35,6 +35,7 @@ cc_library( strip_include_prefix = _INCLUDE_PREFIX, deps = [ ":uniform_pseudorandom_generator", + "//src/main/proto/wfa/any_sketch:secret_share_cc_proto", "@boringssl//:ssl", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/src/main/cc/math/open_ssl_uniform_random_generator.cc b/src/main/cc/math/open_ssl_uniform_random_generator.cc index faed781..5063470 100644 --- a/src/main/cc/math/open_ssl_uniform_random_generator.cc +++ b/src/main/cc/math/open_ssl_uniform_random_generator.cc @@ -155,4 +155,16 @@ OpenSslUniformPseudorandomGenerator::GenerateUniformRandomRange( return ret; } +absl::StatusOr> +CreatePrngFromSeed(const PrngSeed &seed) { + // TODO(@ple13): use absl::Cord instead of making vector copies once we can + // upgrade to Protobuf v23.0+. 
+ std::vector key(seed.key().begin(), seed.key().end()); + std::vector iv(seed.iv().begin(), seed.iv().end()); + + ASSIGN_OR_RETURN(std::unique_ptr prng, + OpenSslUniformPseudorandomGenerator::Create(key, iv)); + return prng; +} + } // namespace wfa::math diff --git a/src/main/cc/math/open_ssl_uniform_random_generator.h b/src/main/cc/math/open_ssl_uniform_random_generator.h index acdd6ed..84a914d 100644 --- a/src/main/cc/math/open_ssl_uniform_random_generator.h +++ b/src/main/cc/math/open_ssl_uniform_random_generator.h @@ -28,9 +28,12 @@ #include "absl/status/statusor.h" #include "absl/strings/substitute.h" #include "math/uniform_pseudorandom_generator.h" +#include "wfa/any_sketch/secret_share.pb.h" namespace wfa::math { +using any_sketch::PrngSeed; + // Key length for EVP_aes_256_ctr. // See https://www.openssl.org/docs/man1.1.1/man3/EVP_aes_256_ctr.html inline constexpr int kBytesPerAes256Key = 32; @@ -98,6 +101,11 @@ class OpenSslUniformPseudorandomGenerator EVP_CIPHER_CTX* ctx_; }; +// Create a pseudorandom generator with a PrngSeed which is used to initialize +// the AES 256 counter mode. +absl::StatusOr> +CreatePrngFromSeed(const PrngSeed& seed); + } // namespace wfa::math #endif // SRC_MAIN_CC_MATH_OPEN_SSL_UNIFORM_RANDOM_GENERATOR_H_ diff --git a/src/main/proto/wfa/any_sketch/secret_share.proto b/src/main/proto/wfa/any_sketch/secret_share.proto index bc9b0e5..609792a 100644 --- a/src/main/proto/wfa/any_sketch/secret_share.proto +++ b/src/main/proto/wfa/any_sketch/secret_share.proto @@ -28,13 +28,13 @@ message SecretShareParameter { uint32 modulus = 1; } -message SecretShare { - // Seed to initialize the AES 256 counter mode. - message ShareSeed { - bytes key = 1; - bytes iv = 2; - } +// Seed to initialize the AES 256 counter mode. 
+message PrngSeed { + bytes key = 1; + bytes iv = 2; +} - ShareSeed share_seed = 1; +message SecretShare { + PrngSeed share_seed = 1; repeated uint32 share_vector = 2; } diff --git a/src/test/cc/any_sketch/crypto/BUILD.bazel b/src/test/cc/any_sketch/crypto/BUILD.bazel index 8ea8514..15838bf 100644 --- a/src/test/cc/any_sketch/crypto/BUILD.bazel +++ b/src/test/cc/any_sketch/crypto/BUILD.bazel @@ -44,3 +44,17 @@ cc_test( "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", ], ) + +cc_test( + name = "shuffle_test", + size = "small", + srcs = [ + "shuffle_test.cc", + ], + deps = [ + "//src/main/cc/any_sketch/crypto:shuffle", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", + ], +) diff --git a/src/test/cc/any_sketch/crypto/shuffle_test.cc b/src/test/cc/any_sketch/crypto/shuffle_test.cc new file mode 100644 index 0000000..d2a4d1e --- /dev/null +++ b/src/test/cc/any_sketch/crypto/shuffle_test.cc @@ -0,0 +1,120 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "any_sketch/crypto/shuffle.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common_cpp/testing/status_macros.h" +#include "common_cpp/testing/status_matchers.h" +#include "gtest/gtest.h" +#include "math/open_ssl_uniform_random_generator.h" + +namespace wfa::any_sketch::crypto { +namespace { + +using any_sketch::PrngSeed; +using measurement::common::crypto::SecureShuffleWithSeed; +using ::wfa::StatusIs; +using ::wfa::math::kBytesPerAes256Iv; +using ::wfa::math::kBytesPerAes256Key; + +TEST(SecureShuffleWithSeed, EmptyInputSucceeds) { + PrngSeed seed; + std::vector data; + absl::Status ret = SecureShuffleWithSeed(data, seed); + ASSERT_EQ(ret, absl::OkStatus()); + ASSERT_EQ(data.size(), 0); +} + +TEST(SecureShuffleWithSeed, InputHasOneElementSucceeds) { + PrngSeed seed; + std::vector data = {1}; + absl::Status ret = SecureShuffleWithSeed(data, seed); + ASSERT_EQ(ret, absl::OkStatus()); + ASSERT_EQ(data.size(), 1); + EXPECT_EQ(data[0], 1); +} + +TEST(SecureShuffleWithSeed, InputSizeAtLeastTwoInvalidSeedFails) { + PrngSeed seed; + *seed.mutable_key() = std::string(kBytesPerAes256Key - 1, 'a'); + *seed.mutable_iv() = std::string(kBytesPerAes256Iv, 'b'); + + std::vector data = {1, 2}; + absl::Status ret = SecureShuffleWithSeed(data, seed); + EXPECT_THAT(ret, StatusIs(absl::StatusCode::kInvalidArgument, + absl::Substitute( + "The uniform pseudorandom generator key has " + "length of $0 bytes but $1 bytes are required.", + seed.key().size(), kBytesPerAes256Key))); +} + +TEST(SecureShuffleWithSeed, ShufflingSucceeds) { + PrngSeed seed; + *seed.mutable_key() = std::string(kBytesPerAes256Key, 'a'); + *seed.mutable_iv() = std::string(kBytesPerAes256Iv, 'b'); + + int kInputSize = 100; + std::vector data(kInputSize); + + for (int i = 0; i < kInputSize; i++) { + data[i] = i; + } + std::vector input = data; + absl::Status ret = SecureShuffleWithSeed(data, seed); + ASSERT_EQ(ret, absl::OkStatus()); + ASSERT_EQ(data.size(), kInputSize); + 
+ // Verifies that the output array is different from the input array. + // With a random seed, there is a negligible chance of 1/(kInputSize!) that + // the permutation does not modify the original array and causes this check to + // fail. + EXPECT_NE(input, data); + + // Verifies that the input array and the output array have the same elements. + std::sort(data.begin(), data.end()); + for (int i = 0; i < kInputSize; i++) { + EXPECT_EQ(data[i], i); + } +} + +TEST(SecureShuffleWithSeed, ShufflingWithSameSeedSucceeds) { + PrngSeed seed; + *seed.mutable_key() = std::string(kBytesPerAes256Key, 'a'); + *seed.mutable_iv() = std::string(kBytesPerAes256Iv, 'b'); + + int kInputSize = 100; + std::vector data_1(kInputSize); + + for (int i = 0; i < kInputSize; i++) { + data_1[i] = i; + } + + std::vector data_2 = data_1; + + absl::Status ret1 = SecureShuffleWithSeed(data_1, seed); + absl::Status ret2 = SecureShuffleWithSeed(data_2, seed); + + ASSERT_EQ(ret1, absl::OkStatus()); + ASSERT_EQ(ret2, absl::OkStatus()); + ASSERT_EQ(data_1.size(), kInputSize); + ASSERT_EQ(data_2.size(), kInputSize); + + // Verifies that the two vectors are shuffled using the same permutation. 
+ EXPECT_EQ(data_1, data_2); +} + +} // namespace +} // namespace wfa::any_sketch::crypto diff --git a/src/test/cc/math/open_ssl_uniform_random_generator_test.cc b/src/test/cc/math/open_ssl_uniform_random_generator_test.cc index c0be0a6..72b917d 100644 --- a/src/test/cc/math/open_ssl_uniform_random_generator_test.cc +++ b/src/test/cc/math/open_ssl_uniform_random_generator_test.cc @@ -26,7 +26,7 @@ namespace { using ::testing::IsEmpty; TEST(OpenSslUniformPseudorandomGenerator, - CreateTheGeneratorWithValidkeyAndIVSucceeds) { + CreateTheGeneratorWithValidKeyAndIVSucceeds) { std::vector key(kBytesPerAes256Key); std::vector iv(kBytesPerAes256Iv); RAND_bytes(key.data(), key.size()); @@ -37,7 +37,7 @@ TEST(OpenSslUniformPseudorandomGenerator, } TEST(OpenSslUniformPseudorandomGenerator, - CreateTheGeneratorWithInvalidkeySizeFails) { + CreateTheGeneratorWithInvalidKeySizeFails) { std::vector key(kBytesPerAes256Key - 1); std::vector iv(kBytesPerAes256Iv); RAND_bytes(key.data(), key.size()); @@ -68,6 +68,61 @@ TEST(OpenSslUniformPseudorandomGenerator, iv.size(), kBytesPerAes256Iv))); } +TEST(OpenSslUniformPseudorandomGenerator, + CreateTheGeneratorWithValidSeedSucceeds) { + std::vector key(kBytesPerAes256Key); + std::vector iv(kBytesPerAes256Iv); + RAND_bytes(key.data(), key.size()); + RAND_bytes(iv.data(), iv.size()); + + PrngSeed seed; + *seed.mutable_key() = std::string(key.begin(), key.end()); + *seed.mutable_iv() = std::string(iv.begin(), iv.end()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr prng, + CreatePrngFromSeed(seed)); +} + +TEST(OpenSslUniformPseudorandomGenerator, + CreateTheGeneratorFromSeedWithInvalidKeyFails) { + std::vector key(kBytesPerAes256Key - 1); + std::vector iv(kBytesPerAes256Iv); + RAND_bytes(key.data(), key.size()); + RAND_bytes(iv.data(), iv.size()); + + PrngSeed seed; + *seed.mutable_key() = std::string(key.begin(), key.end()); + *seed.mutable_iv() = std::string(iv.begin(), iv.end()); + + auto prng = CreatePrngFromSeed(seed); + EXPECT_THAT( + 
prng.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + absl::Substitute("The uniform pseudorandom generator key has " + "length of $0 bytes but $1 bytes are required.", + key.size(), kBytesPerAes256Key))); +} + +TEST(OpenSslUniformPseudorandomGenerator, + CreateTheGeneratorFromSeedWithInvalidIVFails) { + std::vector key(kBytesPerAes256Key); + std::vector iv(kBytesPerAes256Iv - 1); + RAND_bytes(key.data(), key.size()); + RAND_bytes(iv.data(), iv.size()); + + PrngSeed seed; + *seed.mutable_key() = std::string(key.begin(), key.end()); + *seed.mutable_iv() = std::string(iv.begin(), iv.end()); + + auto prng = CreatePrngFromSeed(seed); + EXPECT_THAT( + prng.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + absl::Substitute("The uniform pseudorandom generator IV has " + "length of $0 bytes but $1 bytes are required.", + iv.size(), kBytesPerAes256Iv))); +} + TEST(OpenSslUniformPseudorandomGenerator, GeneratingNegativeNumberOfRandomBytesFails) { std::vector key(kBytesPerAes256Key);