diff --git a/src/main/cc/any_sketch/crypto/BUILD.bazel b/src/main/cc/any_sketch/crypto/BUILD.bazel
index 7786bb8..d23fdb3 100644
--- a/src/main/cc/any_sketch/crypto/BUILD.bazel
+++ b/src/main/cc/any_sketch/crypto/BUILD.bazel
@@ -54,3 +54,20 @@ cc_binary(
         "@com_google_absl//absl/strings",
     ],
 )
+
+cc_library(
+    name = "shuffle",
+    srcs = [
+        "shuffle.cc",
+    ],
+    hdrs = [
+        "shuffle.h",
+    ],
+    strip_include_prefix = _INCLUDE_PREFIX,
+    deps = [
+        "//src/main/cc/math:open_ssl_uniform_random_generator",
+        "//src/main/proto/wfa/any_sketch:secret_share_cc_proto",
+        "@com_google_absl//absl/status:status",
+        "@wfa_common_cpp//src/main/cc/common_cpp/macros",
+    ],
+)
diff --git a/src/main/cc/any_sketch/crypto/shuffle.cc b/src/main/cc/any_sketch/crypto/shuffle.cc
new file mode 100644
index 0000000..76e66c2
--- /dev/null
+++ b/src/main/cc/any_sketch/crypto/shuffle.cc
@@ -0,0 +1,65 @@
+// Copyright 2024 The Cross-Media Measurement Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "any_sketch/crypto/shuffle.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "common_cpp/macros/macros.h"
+#include "math/open_ssl_uniform_random_generator.h"
+
+namespace wfa::measurement::common::crypto {
+
+using math::CreatePrngFromSeed;
+using math::OpenSslUniformPseudorandomGenerator;
+using math::UniformPseudorandomGenerator;
+
+// Shuffles the vector data using Fisher-Yates approach. Let n be the size of
+// data, the Fisher-Yates shuffle is as below.
+// For i = (n-1) to 1: +// Draws a random value in the range [0; i] +// Swaps data[i] and data[j] +absl::Status ShuffleWithSeed(std::vector& data, + const PrngSeed& seed) { + // Does nothing if the input is empty or has size 1. + if (data.size() <= 1) { + return absl::OkStatus(); + } + + // Initializes the pseudorandom generator using the provided seed. + ASSIGN_OR_RETURN(std::unique_ptr prng, + CreatePrngFromSeed(seed)); + + // Samples random values. + ASSIGN_OR_RETURN( + std::vector arr, + prng->GeneratePseudorandomBytes(data.size() * sizeof(unsigned __int128))); + + unsigned __int128* rand = (unsigned __int128*)arr.data(); + for (uint64_t i = data.size() - 1; i >= 1; i--) { + // Ideally, to make sure that the sampled permutation is not biased, rand[i] + // needs to be re-sampled if rand[i] >= 2^128 - (2^128 % (i+1)). However, + // the probability that this happens with any i in [1; data.size() - 1] is + // less than (data.size())^2/2^{128}, which is less than 2^{-40} for any + // input vector of size less than 2^{43}. + int index = rand[i] % (i + 1); + // Swaps the element at current position with the one at position index. + uint64_t temp = data[i]; + data[i] = data[index]; + data[index] = temp; + } + + return absl::OkStatus(); +} + +} // namespace wfa::measurement::common::crypto diff --git a/src/main/cc/any_sketch/crypto/shuffle.h b/src/main/cc/any_sketch/crypto/shuffle.h new file mode 100644 index 0000000..7a43cfd --- /dev/null +++ b/src/main/cc/any_sketch/crypto/shuffle.h @@ -0,0 +1,34 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SHUFFLE_H_ +#define SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SHUFFLE_H_ + +#include +#include + +#include "absl/status/status.h" +#include "math/open_ssl_uniform_random_generator.h" +#include "wfa/any_sketch/secret_share.pb.h" + +namespace wfa::measurement::common::crypto { + +using any_sketch::PrngSeed; +using wfa::math::UniformPseudorandomGenerator; + +absl::Status ShuffleWithSeed(std::vector& data, const PrngSeed& seed); + +} // namespace wfa::measurement::common::crypto + +#endif // SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SHUFFLE_H_ diff --git a/src/main/cc/math/BUILD.bazel b/src/main/cc/math/BUILD.bazel index 5c5ff99..2baf6e3 100644 --- a/src/main/cc/math/BUILD.bazel +++ b/src/main/cc/math/BUILD.bazel @@ -35,6 +35,7 @@ cc_library( strip_include_prefix = _INCLUDE_PREFIX, deps = [ ":uniform_pseudorandom_generator", + "//src/main/proto/wfa/any_sketch:secret_share_cc_proto", "@boringssl//:ssl", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/src/main/cc/math/open_ssl_uniform_random_generator.cc b/src/main/cc/math/open_ssl_uniform_random_generator.cc index fca23b9..a333fae 100644 --- a/src/main/cc/math/open_ssl_uniform_random_generator.cc +++ b/src/main/cc/math/open_ssl_uniform_random_generator.cc @@ -149,4 +149,14 @@ OpenSslUniformPseudorandomGenerator::GenerateUniformRandomRange( return ret; } +absl::StatusOr> +CreatePrngFromSeed(const PrngSeed &seed) { + std::vector key(seed.key().begin(), seed.key().end()); + std::vector iv(seed.iv().begin(), seed.iv().end()); + + 
ASSIGN_OR_RETURN(std::unique_ptr prng, + OpenSslUniformPseudorandomGenerator::Create(key, iv)); + return prng; +} + } // namespace wfa::math diff --git a/src/main/cc/math/open_ssl_uniform_random_generator.h b/src/main/cc/math/open_ssl_uniform_random_generator.h index eae407c..762a4ca 100644 --- a/src/main/cc/math/open_ssl_uniform_random_generator.h +++ b/src/main/cc/math/open_ssl_uniform_random_generator.h @@ -28,9 +28,12 @@ #include "absl/status/statusor.h" #include "absl/strings/substitute.h" #include "math/uniform_pseudorandom_generator.h" +#include "wfa/any_sketch/secret_share.pb.h" namespace wfa::math { +using any_sketch::PrngSeed; + // Key length for EVP_aes_256_ctr. // See https://www.openssl.org/docs/man1.1.1/man3/EVP_aes_256_ctr.html inline constexpr int kBytesPerAes256Key = 32; @@ -98,6 +101,11 @@ class OpenSslUniformPseudorandomGenerator EVP_CIPHER_CTX* ctx_; }; +// Create a pseudorandom generator with a PrngSeed which is used to initialize +// the AES 256 counter mode. +absl::StatusOr> +CreatePrngFromSeed(const PrngSeed& seed); + } // namespace wfa::math #endif // SRC_MAIN_CC_MATH_OPEN_SSL_UNIFORM_RANDOM_GENERATOR_H_ diff --git a/src/main/proto/wfa/any_sketch/secret_share.proto b/src/main/proto/wfa/any_sketch/secret_share.proto index bc9b0e5..609792a 100644 --- a/src/main/proto/wfa/any_sketch/secret_share.proto +++ b/src/main/proto/wfa/any_sketch/secret_share.proto @@ -28,13 +28,13 @@ message SecretShareParameter { uint32 modulus = 1; } -message SecretShare { - // Seed to initialize the AES 256 counter mode. - message ShareSeed { - bytes key = 1; - bytes iv = 2; - } +// Seed to initialize the AES 256 counter mode. 
+message PrngSeed {
+  // AES-256 key. Generator creation fails with InvalidArgument unless this is
+  // exactly kBytesPerAes256Key (32) bytes long.
+  bytes key = 1;
+  // Initialization vector. Generator creation fails with InvalidArgument
+  // unless this is exactly kBytesPerAes256Iv bytes long.
+  bytes iv = 2;
+}
 
-  ShareSeed share_seed = 1;
+message SecretShare {
+  // Same field number and layout as the former nested ShareSeed message.
+  PrngSeed share_seed = 1;
   repeated uint32 share_vector = 2;
 }
diff --git a/src/test/cc/any_sketch/crypto/BUILD.bazel b/src/test/cc/any_sketch/crypto/BUILD.bazel
index 8ea8514..15838bf 100644
--- a/src/test/cc/any_sketch/crypto/BUILD.bazel
+++ b/src/test/cc/any_sketch/crypto/BUILD.bazel
@@ -44,3 +44,17 @@ cc_test(
         "@wfa_common_cpp//src/main/cc/common_cpp/testing:status",
     ],
 )
+
+# Unit tests for ShuffleWithSeed.
+cc_test(
+    name = "shuffle_test",
+    size = "small",
+    srcs = [
+        "shuffle_test.cc",
+    ],
+    deps = [
+        "//src/main/cc/any_sketch/crypto:shuffle",
+        "@com_google_googletest//:gtest",
+        "@com_google_googletest//:gtest_main",
+        "@wfa_common_cpp//src/main/cc/common_cpp/testing:status",
+    ],
+)
diff --git a/src/test/cc/any_sketch/crypto/shuffle_test.cc b/src/test/cc/any_sketch/crypto/shuffle_test.cc
new file mode 100644
index 0000000..5f90c3b
--- /dev/null
+++ b/src/test/cc/any_sketch/crypto/shuffle_test.cc
@@ -0,0 +1,93 @@
+// Copyright 2024 The Cross-Media Measurement Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "any_sketch/crypto/shuffle.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common_cpp/testing/status_macros.h" +#include "common_cpp/testing/status_matchers.h" +#include "gtest/gtest.h" + +namespace wfa::any_sketch::crypto { +namespace { + +using measurement::common::crypto::PrngSeed; +using measurement::common::crypto::ShuffleWithSeed; +using ::wfa::StatusIs; +using ::wfa::math::kBytesPerAes256Iv; +using ::wfa::math::kBytesPerAes256Key; + +TEST(ShuffleWithSeed, EmptyInputSucceeds) { + PrngSeed seed; + std::vector data; + absl::Status ret = ShuffleWithSeed(data, seed); + ASSERT_EQ(ret, absl::OkStatus()); + ASSERT_EQ(data.size(), 0); +} + +TEST(ShuffleWithSeed, InputHasOneElementSucceeds) { + PrngSeed seed; + std::vector data = {1}; + absl::Status ret = ShuffleWithSeed(data, seed); + ASSERT_EQ(ret, absl::OkStatus()); + ASSERT_EQ(data.size(), 1); + EXPECT_EQ(data[0], 1); +} + +TEST(ShuffleWithSeed, InputSizeAtLeastTwoInvalidSeedFails) { + PrngSeed seed; + *seed.mutable_key() = std::string(kBytesPerAes256Key - 1, 'a'); + *seed.mutable_iv() = std::string(kBytesPerAes256Iv, 'b'); + + std::vector data = {1, 2}; + absl::Status ret = ShuffleWithSeed(data, seed); + EXPECT_THAT(ret, StatusIs(absl::StatusCode::kInvalidArgument, + absl::Substitute( + "The uniform pseudorandom generator key has " + "length of $0 bytes but $1 bytes are required.", + seed.key().size(), kBytesPerAes256Key))); +} + +TEST(ShuffleWithSeed, ShufflingSucceeds) { + PrngSeed seed; + *seed.mutable_key() = std::string(kBytesPerAes256Key, 'a'); + *seed.mutable_iv() = std::string(kBytesPerAes256Iv, 'b'); + + int kInputSize = 100; + std::vector data(kInputSize); + + for (int i = 0; i < kInputSize; i++) { + data[i] = i; + } + std::vector input = data; + absl::Status ret = ShuffleWithSeed(data, seed); + ASSERT_EQ(ret, absl::OkStatus()); + ASSERT_EQ(data.size(), kInputSize); + + // Verifies that the output array is different from the input array. 
+ // With a random seed, there is a negligible chance of 1/(kInputsize!) that + // the permutation does not modify the original array and causes this check to + // fail. + EXPECT_NE(input, data); + + // Verifies that the input array and the output array have the same elements. + std::sort(data.begin(), data.end()); + for (int i = 0; i < kInputSize; i++) { + EXPECT_EQ(data[i], i); + } +} + +} // namespace +} // namespace wfa::any_sketch::crypto diff --git a/src/test/cc/math/open_ssl_uniform_random_generator_test.cc b/src/test/cc/math/open_ssl_uniform_random_generator_test.cc index dedf231..7031d92 100644 --- a/src/test/cc/math/open_ssl_uniform_random_generator_test.cc +++ b/src/test/cc/math/open_ssl_uniform_random_generator_test.cc @@ -24,7 +24,7 @@ namespace wfa::math { namespace { TEST(OpenSslUniformPseudorandomGenerator, - CreateTheGeneratorWithValidkeyAndIVSucceeds) { + CreateTheGeneratorWithValidKeyAndIVSucceeds) { std::vector key(kBytesPerAes256Key); std::vector iv(kBytesPerAes256Iv); RAND_bytes(key.data(), key.size()); @@ -35,7 +35,7 @@ TEST(OpenSslUniformPseudorandomGenerator, } TEST(OpenSslUniformPseudorandomGenerator, - CreateTheGeneratorWithInvalidkeySizeFails) { + CreateTheGeneratorWithInvalidKeySizeFails) { std::vector key(kBytesPerAes256Key - 1); std::vector iv(kBytesPerAes256Iv); RAND_bytes(key.data(), key.size()); @@ -66,6 +66,61 @@ TEST(OpenSslUniformPseudorandomGenerator, iv.size(), kBytesPerAes256Iv))); } +TEST(OpenSslUniformPseudorandomGenerator, + CreateTheGeneratorWithValidSeedSucceeds) { + std::vector key(kBytesPerAes256Key); + std::vector iv(kBytesPerAes256Iv); + RAND_bytes(key.data(), key.size()); + RAND_bytes(iv.data(), iv.size()); + + PrngSeed seed; + *seed.mutable_key() = std::string(key.begin(), key.end()); + *seed.mutable_iv() = std::string(iv.begin(), iv.end()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr prng, + CreatePrngFromSeed(seed)); +} + +TEST(OpenSslUniformPseudorandomGenerator, + CreateTheGeneratorFromSeedWithInvalidKeyFails) { + 
std::vector key(kBytesPerAes256Key - 1); + std::vector iv(kBytesPerAes256Iv); + RAND_bytes(key.data(), key.size()); + RAND_bytes(iv.data(), iv.size()); + + PrngSeed seed; + *seed.mutable_key() = std::string(key.begin(), key.end()); + *seed.mutable_iv() = std::string(iv.begin(), iv.end()); + + auto prng = CreatePrngFromSeed(seed); + EXPECT_THAT( + prng.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + absl::Substitute("The uniform pseudorandom generator key has " + "length of $0 bytes but $1 bytes are required.", + key.size(), kBytesPerAes256Key))); +} + +TEST(OpenSslUniformPseudorandomGenerator, + CreateTheGeneratorFromSeedWithInvalidIVFails) { + std::vector key(kBytesPerAes256Key); + std::vector iv(kBytesPerAes256Iv - 1); + RAND_bytes(key.data(), key.size()); + RAND_bytes(iv.data(), iv.size()); + + PrngSeed seed; + *seed.mutable_key() = std::string(key.begin(), key.end()); + *seed.mutable_iv() = std::string(iv.begin(), iv.end()); + + auto prng = CreatePrngFromSeed(seed); + EXPECT_THAT( + prng.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + absl::Substitute("The uniform pseudorandom generator IV has " + "length of $0 bytes but $1 bytes are required.", + iv.size(), kBytesPerAes256Iv))); +} + TEST(OpenSslUniformPseudorandomGenerator, GeneratingNonPositiveNumberOfRandomBytesFails) { std::vector key(kBytesPerAes256Key);