diff --git a/src/main/cc/any_sketch/crypto/BUILD.bazel b/src/main/cc/any_sketch/crypto/BUILD.bazel index d8421e2..7786bb8 100644 --- a/src/main/cc/any_sketch/crypto/BUILD.bazel +++ b/src/main/cc/any_sketch/crypto/BUILD.bazel @@ -21,6 +21,18 @@ cc_library( ], ) +cc_library( + name = "secret_share_generator", + srcs = [":secret_share_generator.cc"], + hdrs = [":secret_share_generator.h"], + strip_include_prefix = _INCLUDE_PREFIX, + deps = [ + "//src/main/cc/math:open_ssl_uniform_random_generator", + "//src/main/proto/wfa/any_sketch:secret_share_cc_proto", + "@wfa_common_cpp//src/main/cc/common_cpp/macros", + ], +) + cc_library( name = "sketch_encrypter_adapter", srcs = [":sketch_encrypter_adapter.cc"], diff --git a/src/main/cc/any_sketch/crypto/secret_share_generator.cc b/src/main/cc/any_sketch/crypto/secret_share_generator.cc new file mode 100644 index 0000000..2a39082 --- /dev/null +++ b/src/main/cc/any_sketch/crypto/secret_share_generator.cc @@ -0,0 +1,104 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "any_sketch/crypto/secret_share_generator.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common_cpp/macros/macros.h" +#include "math/open_ssl_uniform_random_generator.h" + +namespace wfa::any_sketch::crypto { + +using wfa::math::kBytesPerAes256Iv; +using wfa::math::kBytesPerAes256Key; +using wfa::math::OpenSslUniformPseudorandomGenerator; +using wfa::math::UniformPseudorandomGenerator; + +namespace { + +// Computes (x + y) mod modulus and returns the result with constant time. +// The input values x, y, and the output are all in [0, modulus) +absl::StatusOr SubMod(uint32_t x, uint32_t y, uint32_t modulus) { + if (x >= modulus || y >= modulus) { + return absl::InvalidArgumentError(absl::Substitute( + "Inputs must be less than the modulus, which is $0.", modulus)); + } + + uint32_t cmp = (x < y); + return x - y + cmp * modulus; +} + +} // namespace + +absl::StatusOr GenerateSecretShares( + const SecretShareParameter& secret_share_parameter, + const absl::Span input) { + if (input.size() == 0) { + return absl::InvalidArgumentError("Input must be a non-empty vector."); + } + + if (secret_share_parameter.modulus() <= 1) { + return absl::InvalidArgumentError("The modulus must be greater than 1."); + } + + // Verify OpenSSL random generator seed has been seeded with enough entropy. + if (RAND_status() != 1) { + return absl::InternalError( + "OpenSSL random generator has not been seeded with enough entropy."); + } + + // Sample random seed as the first share. + std::vector key(kBytesPerAes256Key); + std::vector iv(kBytesPerAes256Iv); + + if (RAND_bytes(key.data(), kBytesPerAes256Key) != 1) { + return absl::InternalError("Failed to sample the AES 256 key."); + } + + if (RAND_bytes(iv.data(), kBytesPerAes256Iv) != 1) { + return absl::InternalError("Failed to sample the AES 256 IV."); + } + + // Expanse the seed to get a random string. + ASSIGN_OR_RETURN(std::unique_ptr prng, + OpenSslUniformPseudorandomGenerator::Create(key, iv)); + + // share_vector_1 holds the random share generated from the random seed. + ASSIGN_OR_RETURN(std::vector share_vector_1, + prng->GenerateUniformRandomRange( + input.size(), secret_share_parameter.modulus())); + // share_vector_2 is the share computed from the input and share_vector_1. + std::vector share_vector_2(input.size()); + for (int i = 0; i < input.size(); i++) { + // share_vector_2[i] = (input[i] - share_vector_1[i]) mod modulus. + ASSIGN_OR_RETURN( + share_vector_2[i], + SubMod(input[i], share_vector_1[i], secret_share_parameter.modulus())); + } + SecretShare secret_share; + secret_share.mutable_share_vector()->Add(share_vector_2.begin(), + share_vector_2.end()); + std::string key_str(key.begin(), key.end()); + std::string iv_str(iv.begin(), iv.end()); + secret_share.mutable_share_seed()->set_key(key_str); + secret_share.mutable_share_seed()->set_iv(iv_str); + + return secret_share; +} + +} // namespace wfa::any_sketch::crypto diff --git a/src/main/cc/any_sketch/crypto/secret_share_generator.h b/src/main/cc/any_sketch/crypto/secret_share_generator.h new file mode 100644 index 0000000..d0a4377 --- /dev/null +++ b/src/main/cc/any_sketch/crypto/secret_share_generator.h @@ -0,0 +1,34 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SECRET_SHARE_GENERATOR_H_ +#define SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SECRET_SHARE_GENERATOR_H_ + +#include + +#include "absl/status/statusor.h" +#include "wfa/any_sketch/secret_share.pb.h" + +using wfa::any_sketch::SecretShare; +using wfa::any_sketch::SecretShareParameter; + +namespace wfa::any_sketch::crypto { + +absl::StatusOr GenerateSecretShares( + const SecretShareParameter& secret_share_parameter, + const absl::Span input); + +} // namespace wfa::any_sketch::crypto + +#endif // SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SECRET_SHARE_GENERATOR_H_ diff --git a/src/main/cc/math/BUILD.bazel b/src/main/cc/math/BUILD.bazel index 81ab935..5c5ff99 100644 --- a/src/main/cc/math/BUILD.bazel +++ b/src/main/cc/math/BUILD.bazel @@ -38,6 +38,7 @@ cc_library( "@boringssl//:ssl", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@wfa_common_cpp//src/main/cc/common_cpp/macros", ], ) diff --git a/src/main/cc/math/open_ssl_uniform_random_generator.cc b/src/main/cc/math/open_ssl_uniform_random_generator.cc index 3fbd15e..fca23b9 100644 --- a/src/main/cc/math/open_ssl_uniform_random_generator.cc +++ b/src/main/cc/math/open_ssl_uniform_random_generator.cc @@ -14,8 +14,10 @@ #include "math/open_ssl_uniform_random_generator.h" +#include #include +#include "common_cpp/macros/macros.h" #include "openssl/rand.h" namespace wfa::math { @@ -68,7 +70,7 @@ OpenSslUniformPseudorandomGenerator::Create( } absl::StatusOr> -OpenSslUniformPseudorandomGenerator::GetPseudorandomBytes(uint64_t size) { +OpenSslUniformPseudorandomGenerator::GeneratePseudorandomBytes(uint64_t size) { if (size == 0) { return absl::InvalidArgumentError( "Number of pseudorandom bytes must be a positive value."); @@ -86,4 +88,65 @@ OpenSslUniformPseudorandomGenerator::GetPseudorandomBytes(uint64_t size) { return ret; } +// Generates uniformly random values in the range [0, modulus) using rejection +// sampling method. +absl::StatusOr> +OpenSslUniformPseudorandomGenerator::GenerateUniformRandomRange( + uint64_t size, uint32_t modulus) { + if (size == 0) { + return absl::InvalidArgumentError( + "Number of pseudorandom elements must be a positive value."); + } + + if (modulus <= 1) { + return absl::InvalidArgumentError("The modulus must be greater than 1."); + } + + // Compute the bit length of the modulus. + int bit_length = std::ceil(std::log2(modulus)); + // The number of bytes needed per element. + int bytes_per_value = (bit_length + 7) / 8; + // The mask to extract the last bit_length bits. + uint32_t mask = (1 << bit_length) - 1; + + // Compute the failure probability, which happens when the sampled value is + // greater than or equal to modulus. As 2^{bit_length - 1} < modulus <= + // 2^{bit_length}, the failure probability is guaranteed to be less than 0.5. + double failure_rate = static_cast((1 << bit_length) - modulus) / + static_cast(1 << bit_length); + + std::vector ret; + ret.reserve(size); + + while (ret.size() < size) { + uint64_t current_size = size - ret.size(); + // To get current_size `good` elements, it is expected to sample + // 1 + current_size*(1 + failure_rate/(1-failure_rate)) elements in + // [0, 2^{bit_length}). + uint64_t sample_size = static_cast( + current_size + 1.0 + failure_rate * current_size / (1 - failure_rate)); + + ASSIGN_OR_RETURN(std::vector arr, + GeneratePseudorandomBytes(sample_size * bytes_per_value)); + + // Rejection sampling step. + for (uint64_t i = 0; i < sample_size; i++) { + if (ret.size() >= size) { + break; + } + uint32_t temp = 0; + for (int j = 0; j < bytes_per_value; j++) { + temp = (temp << 8) + arr[i * bytes_per_value + j]; + } + temp &= mask; + + // Accept the value if it is less than modulus. + if (temp < modulus) { + ret.push_back(temp); + } + } + } + return ret; +} + } // namespace wfa::math diff --git a/src/main/cc/math/open_ssl_uniform_random_generator.h b/src/main/cc/math/open_ssl_uniform_random_generator.h index c5eb91f..eae407c 100644 --- a/src/main/cc/math/open_ssl_uniform_random_generator.h +++ b/src/main/cc/math/open_ssl_uniform_random_generator.h @@ -74,7 +74,7 @@ class OpenSslUniformRandomGenerator { class OpenSslUniformPseudorandomGenerator : public UniformPseudorandomGenerator { public: - // Create a uniform pseudorandom generator from a key and an IV. + // Creates a uniform pseudorandom generator from a key and an IV. // The key and IV needs to have the length of kBytesPerAes256Key and // kBytesPerAes256Iv respectively. static absl::StatusOr> Create( @@ -82,11 +82,15 @@ class OpenSslUniformPseudorandomGenerator const std::vector& iv); // Destructor. - ~OpenSslUniformPseudorandomGenerator() { EVP_CIPHER_CTX_free(ctx_); } + ~OpenSslUniformPseudorandomGenerator() override { EVP_CIPHER_CTX_free(ctx_); } - // Generate a vector of pseudorandom bytes with the given size. - absl::StatusOr> GetPseudorandomBytes( - uint64_t size); + // Generates a vector of pseudorandom bytes with the given size. + absl::StatusOr> GeneratePseudorandomBytes( + uint64_t size) override; + + // Generates a vector of `size` pseudorandom values in the range [0, modulus). + absl::StatusOr> GenerateUniformRandomRange( + uint64_t size, uint32_t modulus) override; private: explicit OpenSslUniformPseudorandomGenerator(EVP_CIPHER_CTX* ctx) diff --git a/src/main/cc/math/uniform_pseudorandom_generator.h b/src/main/cc/math/uniform_pseudorandom_generator.h index 17dd0cb..3c1520b 100644 --- a/src/main/cc/math/uniform_pseudorandom_generator.h +++ b/src/main/cc/math/uniform_pseudorandom_generator.h @@ -39,10 +39,14 @@ class UniformPseudorandomGenerator { UniformPseudorandomGenerator(UniformPseudorandomGenerator&& other) = delete; virtual ~UniformPseudorandomGenerator() = default; - // Generate a vector of pseudorandom bytes with the given size. - virtual absl::StatusOr> GetPseudorandomBytes( + // Generates a vector of pseudorandom bytes with the given size. + virtual absl::StatusOr> GeneratePseudorandomBytes( uint64_t size) = 0; + // Generates a vector of pseudorandom values in the range [0, modulus). + virtual absl::StatusOr> GenerateUniformRandomRange( + uint64_t size, uint32_t modulus) = 0; + protected: UniformPseudorandomGenerator() = default; }; diff --git a/src/main/proto/wfa/any_sketch/BUILD.bazel b/src/main/proto/wfa/any_sketch/BUILD.bazel index 25aa93f..6d973d7 100644 --- a/src/main/proto/wfa/any_sketch/BUILD.bazel +++ b/src/main/proto/wfa/any_sketch/BUILD.bazel @@ -26,3 +26,14 @@ cc_proto_library( name = "differential_privacy_cc_proto", deps = [":differential_privacy_proto"], ) + +proto_library( + name = "secret_share_proto", + srcs = ["secret_share.proto"], + strip_import_prefix = IMPORT_PREFIX, +) + +cc_proto_library( + name = "secret_share_cc_proto", + deps = [":secret_share_proto"], +) diff --git a/src/main/proto/wfa/any_sketch/secret_share.proto b/src/main/proto/wfa/any_sketch/secret_share.proto new file mode 100644 index 0000000..bc9b0e5 --- /dev/null +++ b/src/main/proto/wfa/any_sketch/secret_share.proto @@ -0,0 +1,40 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Protobuffer for the generalized enriched cardinality sketch. +// SketchConfig is configuration of the sketch. +// Sketch is the sketch itself. + +syntax = "proto3"; + +package wfa.any_sketch; + +option java_package = "org.wfanet.anysketch"; +option java_multiple_files = true; +option java_outer_classname = "SecretShareProto"; + +message SecretShareParameter { + uint32 modulus = 1; +} + +message SecretShare { + // Seed to initialize the AES 256 counter mode. + message ShareSeed { + bytes key = 1; + bytes iv = 2; + } + + ShareSeed share_seed = 1; + repeated uint32 share_vector = 2; +} diff --git a/src/test/cc/any_sketch/crypto/BUILD.bazel b/src/test/cc/any_sketch/crypto/BUILD.bazel index 5544abc..8ea8514 100644 --- a/src/test/cc/any_sketch/crypto/BUILD.bazel +++ b/src/test/cc/any_sketch/crypto/BUILD.bazel @@ -28,3 +28,19 @@ cc_test( "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", ], ) + +cc_test( + name = "secret_share_generator_test", + size = "small", + srcs = [ + ":secret_share_generator_test.cc", + ], + deps = [ + "//src/main/cc/any_sketch/crypto:secret_share_generator", + "//src/main/cc/math:open_ssl_uniform_random_generator", + "//src/main/proto/wfa/any_sketch:secret_share_cc_proto", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", + ], +) diff --git a/src/test/cc/any_sketch/crypto/secret_share_generator_test.cc b/src/test/cc/any_sketch/crypto/secret_share_generator_test.cc new file mode 100644 index 0000000..d7983f1 --- /dev/null +++ b/src/test/cc/any_sketch/crypto/secret_share_generator_test.cc @@ -0,0 +1,125 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "any_sketch/crypto/secret_share_generator.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common_cpp/testing/status_macros.h" +#include "common_cpp/testing/status_matchers.h" +#include "gtest/gtest.h" +#include "math/open_ssl_uniform_random_generator.h" + +namespace wfa::any_sketch::crypto { +namespace { + +using math::OpenSslUniformPseudorandomGenerator; +using math::UniformPseudorandomGenerator; + +TEST(AdditiveSecretSharing, EmptyInputVectorFails) { + std::vector input(0); + SecretShareParameter param; + param.set_modulus(128); + auto ret = GenerateSecretShares(param, input); + EXPECT_THAT(ret.status(), StatusIs(absl::StatusCode::kInvalidArgument, + "Input must be a non-empty vector.")); +} + +TEST(AdditiveSecretSharing, InvalidModulusFails) { + std::vector input = {0, 1, 2}; + SecretShareParameter param; + param.set_modulus(1); + auto ret = GenerateSecretShares(param, input); + EXPECT_THAT(ret.status(), StatusIs(absl::StatusCode::kInvalidArgument, + "The modulus must be greater than 1.")); +} + +TEST(AdditiveSecretSharing, SecretShareOverRingZ2kElementsSucceeds) { + std::vector input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + SecretShareParameter param; + param.set_modulus(128); + ASSERT_OK_AND_ASSIGN(SecretShare secret_share, + GenerateSecretShares(param, input)); + + std::string key = secret_share.share_seed().key(); + std::string iv = secret_share.share_seed().iv(); + + std::vector key_vec, iv_vec; + key_vec.insert(key_vec.end(), key.begin(), key.end()); + iv_vec.insert(iv_vec.end(), iv.begin(), iv.end()); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr prng, + OpenSslUniformPseudorandomGenerator::Create(key_vec, iv_vec)); + + ASSERT_OK_AND_ASSIGN( + std::vector share_vector_from_seed, + prng->GenerateUniformRandomRange(input.size(), param.modulus())); + + ASSERT_EQ(secret_share.share_vector().size(), input.size()); + ASSERT_EQ(share_vector_from_seed.size(), input.size()); + + for (int i = 0; i < input.size(); i++) { + ASSERT_EQ(input[i], + (share_vector_from_seed[i] + secret_share.share_vector(i)) % + param.modulus()); + } +} + +TEST(AdditiveSecretSharing, SecretShareOverPrimeFieldElementsSucceeds) { + std::vector input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + SecretShareParameter param; + param.set_modulus(127); + ASSERT_OK_AND_ASSIGN(SecretShare secret_share, + GenerateSecretShares(param, input)); + + std::string key = secret_share.share_seed().key(); + std::string iv = secret_share.share_seed().iv(); + + std::vector key_vec, iv_vec; + key_vec.insert(key_vec.end(), key.begin(), key.end()); + iv_vec.insert(iv_vec.end(), iv.begin(), iv.end()); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr prng, + OpenSslUniformPseudorandomGenerator::Create(key_vec, iv_vec)); + + ASSERT_OK_AND_ASSIGN( + std::vector share_vector_from_seed, + prng->GenerateUniformRandomRange(input.size(), param.modulus())); + + ASSERT_EQ(secret_share.share_vector().size(), input.size()); + ASSERT_EQ(share_vector_from_seed.size(), input.size()); + + for (int i = 0; i < input.size(); i++) { + ASSERT_EQ(input[i], + (share_vector_from_seed[i] + secret_share.share_vector(i)) % + param.modulus()); + } +} + +TEST(AdditiveSecretSharing, InputOutOfBoundFails) { + std::vector input = {0, 1, 2, 3, 4, 5, 6, 7}; + SecretShareParameter param; + param.set_modulus(7); + auto secret_share = GenerateSecretShares(param, input); + EXPECT_THAT(secret_share.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + absl::Substitute( + "Inputs must be less than the modulus, which is $0.", + param.modulus()))); +} + +} // namespace +} // namespace wfa::any_sketch::crypto diff --git a/src/test/cc/math/open_ssl_uniform_random_generator_test.cc b/src/test/cc/math/open_ssl_uniform_random_generator_test.cc index af8ab2f..dedf231 100644 --- a/src/test/cc/math/open_ssl_uniform_random_generator_test.cc +++ b/src/test/cc/math/open_ssl_uniform_random_generator_test.cc @@ -75,7 +75,7 @@ TEST(OpenSslUniformPseudorandomGenerator, ASSERT_OK_AND_ASSIGN(std::unique_ptr prng, OpenSslUniformPseudorandomGenerator::Create(key, iv)); - auto seq = prng->GetPseudorandomBytes(0); + auto seq = prng->GeneratePseudorandomBytes(0); EXPECT_THAT( seq.status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -97,9 +97,9 @@ TEST(OpenSslUniformPseudorandomGenerator, int kNumRandomBytes = 100; ASSERT_OK_AND_ASSIGN(std::vector seq1, - prng1->GetPseudorandomBytes(kNumRandomBytes)); + prng1->GeneratePseudorandomBytes(kNumRandomBytes)); ASSERT_OK_AND_ASSIGN(std::vector seq2, - prng2->GetPseudorandomBytes(kNumRandomBytes)); + prng2->GeneratePseudorandomBytes(kNumRandomBytes)); ASSERT_EQ(seq1.size(), kNumRandomBytes); ASSERT_EQ(seq2.size(), kNumRandomBytes); @@ -121,13 +121,13 @@ TEST(OpenSslUniformPseudorandomGenerator, int kNumRandomBytes = 100; ASSERT_OK_AND_ASSIGN(std::vector seq10, - prng1->GetPseudorandomBytes(kNumRandomBytes)); + prng1->GeneratePseudorandomBytes(kNumRandomBytes)); ASSERT_OK_AND_ASSIGN(std::vector seq20, - prng2->GetPseudorandomBytes(kNumRandomBytes)); + prng2->GeneratePseudorandomBytes(kNumRandomBytes)); ASSERT_OK_AND_ASSIGN(std::vector seq11, - prng1->GetPseudorandomBytes(kNumRandomBytes)); + prng1->GeneratePseudorandomBytes(kNumRandomBytes)); ASSERT_OK_AND_ASSIGN(std::vector seq21, - prng2->GetPseudorandomBytes(kNumRandomBytes)); + prng2->GeneratePseudorandomBytes(kNumRandomBytes)); ASSERT_EQ(seq10.size(), kNumRandomBytes); ASSERT_EQ(seq20.size(), kNumRandomBytes); ASSERT_EQ(seq11.size(), kNumRandomBytes); @@ -151,13 +151,13 @@ TEST(OpenSslUniformPseudorandomGenerator, int kNumRandomBytes = 1; ASSERT_OK_AND_ASSIGN(std::vector seq10, - prng1->GetPseudorandomBytes(kNumRandomBytes)); + prng1->GeneratePseudorandomBytes(kNumRandomBytes)); ASSERT_OK_AND_ASSIGN(std::vector seq20, - prng2->GetPseudorandomBytes(kNumRandomBytes)); + prng2->GeneratePseudorandomBytes(kNumRandomBytes)); ASSERT_OK_AND_ASSIGN(std::vector seq11, - prng1->GetPseudorandomBytes(kNumRandomBytes)); + prng1->GeneratePseudorandomBytes(kNumRandomBytes)); ASSERT_OK_AND_ASSIGN(std::vector seq21, - prng2->GetPseudorandomBytes(kNumRandomBytes)); + prng2->GeneratePseudorandomBytes(kNumRandomBytes)); ASSERT_EQ(seq10.size(), kNumRandomBytes); ASSERT_EQ(seq20.size(), kNumRandomBytes); ASSERT_EQ(seq11.size(), kNumRandomBytes); @@ -182,13 +182,13 @@ TEST(OpenSslUniformPseudorandomGenerator, int kNumRandomBytes1 = 45; int kNumRandomBytes2 = 55; ASSERT_OK_AND_ASSIGN(std::vector seq10, - prng1->GetPseudorandomBytes(kNumRandomBytes1)); + prng1->GeneratePseudorandomBytes(kNumRandomBytes1)); ASSERT_OK_AND_ASSIGN(std::vector seq20, - prng2->GetPseudorandomBytes(kNumRandomBytes2)); + prng2->GeneratePseudorandomBytes(kNumRandomBytes2)); ASSERT_OK_AND_ASSIGN(std::vector seq11, - prng1->GetPseudorandomBytes(kNumRandomBytes2)); + prng1->GeneratePseudorandomBytes(kNumRandomBytes2)); ASSERT_OK_AND_ASSIGN(std::vector seq21, - prng2->GetPseudorandomBytes(kNumRandomBytes1)); + prng2->GeneratePseudorandomBytes(kNumRandomBytes1)); ASSERT_EQ(seq10.size(), kNumRandomBytes1); ASSERT_EQ(seq20.size(), kNumRandomBytes2); ASSERT_EQ(seq11.size(), kNumRandomBytes2); @@ -231,13 +231,13 @@ TEST(OpenSslUniformPseudorandomGenerator, int kBlockSize = 16; ASSERT_OK_AND_ASSIGN(std::vector seq1, - prng->GetPseudorandomBytes(kBlockSize)); + prng->GeneratePseudorandomBytes(kBlockSize)); ASSERT_OK_AND_ASSIGN(std::vector seq2, - prng->GetPseudorandomBytes(kBlockSize)); + prng->GeneratePseudorandomBytes(kBlockSize)); ASSERT_OK_AND_ASSIGN(std::vector seq3, - prng->GetPseudorandomBytes(kBlockSize)); + prng->GeneratePseudorandomBytes(kBlockSize)); ASSERT_OK_AND_ASSIGN(std::vector seq4, - prng->GetPseudorandomBytes(kBlockSize)); + prng->GeneratePseudorandomBytes(kBlockSize)); ASSERT_EQ(seq1.size(), kBlockSize); ASSERT_EQ(seq2.size(), kBlockSize); @@ -250,5 +250,79 @@ TEST(OpenSslUniformPseudorandomGenerator, ASSERT_EQ(seq4, kOutputBlock4); } +TEST(OpenSslUniformPseudorandomGenerator, + GeneratingNonPositiveNumberOfRandomElementsFails) { + std::vector key(kBytesPerAes256Key); + std::vector iv(kBytesPerAes256Iv); + RAND_bytes(key.data(), key.size()); + RAND_bytes(iv.data(), iv.size()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr prng, + OpenSslUniformPseudorandomGenerator::Create(key, iv)); + + uint32_t kModulus = 128; + uint64_t kNumRandomElements = 0; + auto seq = prng->GenerateUniformRandomRange(kNumRandomElements, kModulus); + EXPECT_THAT( + seq.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + "Number of pseudorandom elements must be a positive value.")); +} + +TEST(OpenSslUniformPseudorandomGenerator, + GeneratingUniformlyRandomWithInvalidModulusFails) { + std::vector key(kBytesPerAes256Key); + std::vector iv(kBytesPerAes256Iv); + RAND_bytes(key.data(), key.size()); + RAND_bytes(iv.data(), iv.size()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr prng, + OpenSslUniformPseudorandomGenerator::Create(key, iv)); + + uint32_t kModulus = 1; + uint64_t kNumRandomElements = 1; + auto seq = prng->GenerateUniformRandomRange(kNumRandomElements, kModulus); + EXPECT_THAT(seq.status(), StatusIs(absl::StatusCode::kInvalidArgument, + "The modulus must be greater than 1.")); +} + +TEST(OpenSslUniformPseudorandomGenerator, + SampleUniformlyRandomOverRingZ2kElementsSucceeds) { + std::vector key(kBytesPerAes256Key); + std::vector iv(kBytesPerAes256Iv); + RAND_bytes(key.data(), key.size()); + RAND_bytes(iv.data(), iv.size()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr prng, + OpenSslUniformPseudorandomGenerator::Create(key, iv)); + uint32_t kModulus = 128; + uint64_t kNumRandomElements = 111; + ASSERT_OK_AND_ASSIGN( + std::vector seq, + prng->GenerateUniformRandomRange(kNumRandomElements, kModulus)); + ASSERT_EQ(seq.size(), kNumRandomElements); + for (int i = 0; i < kNumRandomElements; i++) { + ASSERT_LT(seq[i], kModulus); + } +} + +TEST(OpenSslUniformPseudorandomGenerator, + SampleUniformlyRandomOverPrimeFieldElementsSucceeds) { + std::vector key(kBytesPerAes256Key); + std::vector iv(kBytesPerAes256Iv); + RAND_bytes(key.data(), key.size()); + RAND_bytes(iv.data(), iv.size()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr prng, + OpenSslUniformPseudorandomGenerator::Create(key, iv)); + uint32_t kModulus = 127; + uint64_t kNumRandomElements = 111; + ASSERT_OK_AND_ASSIGN( + std::vector seq, + prng->GenerateUniformRandomRange(kNumRandomElements, kModulus)); + ASSERT_EQ(seq.size(), kNumRandomElements); + for (int i = 0; i < kNumRandomElements; i++) { + ASSERT_LT(seq[i], kModulus); + } +} + } // namespace } // namespace wfa::math