From f47bce03c589439964081a665aaa6c5c0ef90e94 Mon Sep 17 00:00:00 2001 From: Phi Date: Thu, 18 Jan 2024 12:09:06 -0500 Subject: [PATCH] Use signed integer instead of unsigned integer for the PRNG's size argument. (#40) --- .../math/open_ssl_uniform_random_generator.cc | 24 +++++--- .../math/open_ssl_uniform_random_generator.h | 4 +- .../cc/math/uniform_pseudorandom_generator.h | 4 +- .../open_ssl_uniform_random_generator_test.cc | 55 ++++++++++++++----- 4 files changed, 60 insertions(+), 27 deletions(-) diff --git a/src/main/cc/math/open_ssl_uniform_random_generator.cc b/src/main/cc/math/open_ssl_uniform_random_generator.cc index fca23b9..faed781 100644 --- a/src/main/cc/math/open_ssl_uniform_random_generator.cc +++ b/src/main/cc/math/open_ssl_uniform_random_generator.cc @@ -70,11 +70,16 @@ OpenSslUniformPseudorandomGenerator::Create( } absl::StatusOr> -OpenSslUniformPseudorandomGenerator::GeneratePseudorandomBytes(uint64_t size) { - if (size == 0) { +OpenSslUniformPseudorandomGenerator::GeneratePseudorandomBytes(int64_t size) { + if (size < 0) { return absl::InvalidArgumentError( - "Number of pseudorandom bytes must be a positive value."); + "Number of pseudorandom bytes must be a non-negative value."); + } + + if (size == 0) { + return std::vector(); } + std::vector ret(size, 0); int length; if (EVP_EncryptUpdate(ctx_, ret.data(), &length, ret.data(), ret.size()) != @@ -92,16 +97,16 @@ OpenSslUniformPseudorandomGenerator::GeneratePseudorandomBytes(uint64_t size) { // sampling method. absl::StatusOr> OpenSslUniformPseudorandomGenerator::GenerateUniformRandomRange( - uint64_t size, uint32_t modulus) { - if (size == 0) { - return absl::InvalidArgumentError( - "Number of pseudorandom elements must be a positive value."); - } - + int64_t size, uint32_t modulus) { if (modulus <= 1) { return absl::InvalidArgumentError("The modulus must be greater than 1."); } + if (size < 0) { + return absl::InvalidArgumentError( + "Number of pseudorandom elements must be a non-negative value."); + } + // Compute the bit length of the modulus. int bit_length = std::ceil(std::log2(modulus)); // The number of bytes needed per element. @@ -146,6 +151,7 @@ OpenSslUniformPseudorandomGenerator::GenerateUniformRandomRange( } } } + return ret; } diff --git a/src/main/cc/math/open_ssl_uniform_random_generator.h b/src/main/cc/math/open_ssl_uniform_random_generator.h index eae407c..acdd6ed 100644 --- a/src/main/cc/math/open_ssl_uniform_random_generator.h +++ b/src/main/cc/math/open_ssl_uniform_random_generator.h @@ -86,11 +86,11 @@ class OpenSslUniformPseudorandomGenerator // Generates a vector of pseudorandom bytes with the given size. absl::StatusOr> GeneratePseudorandomBytes( - uint64_t size) override; + int64_t size) override; // Generates a vector of `size` pseudorandom values in the range [0, modulus). absl::StatusOr> GenerateUniformRandomRange( - uint64_t size, uint32_t modulus) override; + int64_t size, uint32_t modulus) override; private: explicit OpenSslUniformPseudorandomGenerator(EVP_CIPHER_CTX* ctx) diff --git a/src/main/cc/math/uniform_pseudorandom_generator.h b/src/main/cc/math/uniform_pseudorandom_generator.h index 3c1520b..0ff884f 100644 --- a/src/main/cc/math/uniform_pseudorandom_generator.h +++ b/src/main/cc/math/uniform_pseudorandom_generator.h @@ -41,11 +41,11 @@ class UniformPseudorandomGenerator { // Generates a vector of pseudorandom bytes with the given size. virtual absl::StatusOr> GeneratePseudorandomBytes( - uint64_t size) = 0; + int64_t size) = 0; // Generates a vector of pseudorandom values in the range [0, modulus). virtual absl::StatusOr> GenerateUniformRandomRange( - uint64_t size, uint32_t modulus) = 0; + int64_t size, uint32_t modulus) = 0; protected: UniformPseudorandomGenerator() = default; diff --git a/src/test/cc/math/open_ssl_uniform_random_generator_test.cc b/src/test/cc/math/open_ssl_uniform_random_generator_test.cc index dedf231..c0be0a6 100644 --- a/src/test/cc/math/open_ssl_uniform_random_generator_test.cc +++ b/src/test/cc/math/open_ssl_uniform_random_generator_test.cc @@ -23,6 +23,8 @@ namespace wfa::math { namespace { +using ::testing::IsEmpty; + TEST(OpenSslUniformPseudorandomGenerator, CreateTheGeneratorWithValidkeyAndIVSucceeds) { std::vector key(kBytesPerAes256Key); @@ -67,7 +69,7 @@ TEST(OpenSslUniformPseudorandomGenerator, } TEST(OpenSslUniformPseudorandomGenerator, - GeneratingNonPositiveNumberOfRandomBytesFails) { + GeneratingNegativeNumberOfRandomBytesFails) { std::vector key(kBytesPerAes256Key); std::vector iv(kBytesPerAes256Iv); RAND_bytes(key.data(), key.size()); @@ -75,11 +77,22 @@ TEST(OpenSslUniformPseudorandomGenerator, ASSERT_OK_AND_ASSIGN(std::unique_ptr prng, OpenSslUniformPseudorandomGenerator::Create(key, iv)); - auto seq = prng->GeneratePseudorandomBytes(0); - EXPECT_THAT( - seq.status(), - StatusIs(absl::StatusCode::kInvalidArgument, - "Number of pseudorandom bytes must be a positive value.")); + auto seq = prng->GeneratePseudorandomBytes(-1); + EXPECT_THAT(seq.status(), + StatusIs(absl::StatusCode::kInvalidArgument, "negative")); +} + +TEST(OpenSslUniformPseudorandomGenerator, + GeneratingZeroNumberOfRandomBytesSucceeds) { + std::vector key(kBytesPerAes256Key); + std::vector iv(kBytesPerAes256Iv); + RAND_bytes(key.data(), key.size()); + RAND_bytes(iv.data(), iv.size()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr prng, + OpenSslUniformPseudorandomGenerator::Create(key, iv)); + ASSERT_OK_AND_ASSIGN(auto seq, prng->GeneratePseudorandomBytes(0)); + EXPECT_THAT(seq, IsEmpty()); } TEST(OpenSslUniformPseudorandomGenerator, @@ -251,7 +264,7 @@ TEST(OpenSslUniformPseudorandomGenerator, } TEST(OpenSslUniformPseudorandomGenerator, - GeneratingNonPositiveNumberOfRandomElementsFails) { + GeneratingNegativeNumberOfRandomElementsFails) { std::vector key(kBytesPerAes256Key); std::vector iv(kBytesPerAes256Iv); RAND_bytes(key.data(), key.size()); @@ -260,12 +273,26 @@ TEST(OpenSslUniformPseudorandomGenerator, OpenSslUniformPseudorandomGenerator::Create(key, iv)); uint32_t kModulus = 128; - uint64_t kNumRandomElements = 0; + uint64_t kNumRandomElements = -1; auto seq = prng->GenerateUniformRandomRange(kNumRandomElements, kModulus); - EXPECT_THAT( - seq.status(), - StatusIs(absl::StatusCode::kInvalidArgument, - "Number of pseudorandom elements must be a positive value.")); + EXPECT_THAT(seq.status(), + StatusIs(absl::StatusCode::kInvalidArgument, "negative")); +} + +TEST(OpenSslUniformPseudorandomGenerator, + GeneratingZeroNumberOfRandomElementsSucceeds) { + std::vector key(kBytesPerAes256Key); + std::vector iv(kBytesPerAes256Iv); + RAND_bytes(key.data(), key.size()); + RAND_bytes(iv.data(), iv.size()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr prng, + OpenSslUniformPseudorandomGenerator::Create(key, iv)); + + uint32_t kModulus = 128; + uint64_t kNumRandomElements = 0; + ASSERT_OK_AND_ASSIGN( + auto seq, prng->GenerateUniformRandomRange(kNumRandomElements, kModulus)); + EXPECT_THAT(seq, IsEmpty()); } TEST(OpenSslUniformPseudorandomGenerator, @@ -280,8 +307,8 @@ TEST(OpenSslUniformPseudorandomGenerator, uint32_t kModulus = 1; uint64_t kNumRandomElements = 1; auto seq = prng->GenerateUniformRandomRange(kNumRandomElements, kModulus); - EXPECT_THAT(seq.status(), StatusIs(absl::StatusCode::kInvalidArgument, - "The modulus must be greater than 1.")); + EXPECT_THAT(seq.status(), + StatusIs(absl::StatusCode::kInvalidArgument, "modulus")); } TEST(OpenSslUniformPseudorandomGenerator,