Skip to content

Commit

Permalink
Use signed integer instead of unsigned integer for the PRNG's size ar…
Browse files Browse the repository at this point in the history
…gument. (#40)
  • Loading branch information
ple13 authored Jan 18, 2024
1 parent 6345767 commit f47bce0
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 27 deletions.
24 changes: 15 additions & 9 deletions src/main/cc/math/open_ssl_uniform_random_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,16 @@ OpenSslUniformPseudorandomGenerator::Create(
}

absl::StatusOr<std::vector<unsigned char>>
OpenSslUniformPseudorandomGenerator::GeneratePseudorandomBytes(uint64_t size) {
if (size == 0) {
OpenSslUniformPseudorandomGenerator::GeneratePseudorandomBytes(int64_t size) {
if (size < 0) {
return absl::InvalidArgumentError(
"Number of pseudorandom bytes must be a positive value.");
"Number of pseudorandom bytes must be a non-negative value.");
}

if (size == 0) {
return std::vector<unsigned char>();
}

std::vector<unsigned char> ret(size, 0);
int length;
if (EVP_EncryptUpdate(ctx_, ret.data(), &length, ret.data(), ret.size()) !=
Expand All @@ -92,16 +97,16 @@ OpenSslUniformPseudorandomGenerator::GeneratePseudorandomBytes(uint64_t size) {
// sampling method.
absl::StatusOr<std::vector<uint32_t>>
OpenSslUniformPseudorandomGenerator::GenerateUniformRandomRange(
uint64_t size, uint32_t modulus) {
if (size == 0) {
return absl::InvalidArgumentError(
"Number of pseudorandom elements must be a positive value.");
}

int64_t size, uint32_t modulus) {
if (modulus <= 1) {
return absl::InvalidArgumentError("The modulus must be greater than 1.");
}

if (size < 0) {
return absl::InvalidArgumentError(
"Number of pseudorandom elements must be a non-negative value.");
}

// Compute the bit length of the modulus.
int bit_length = std::ceil(std::log2(modulus));
// The number of bytes needed per element.
Expand Down Expand Up @@ -146,6 +151,7 @@ OpenSslUniformPseudorandomGenerator::GenerateUniformRandomRange(
}
}
}

return ret;
}

Expand Down
4 changes: 2 additions & 2 deletions src/main/cc/math/open_ssl_uniform_random_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ class OpenSslUniformPseudorandomGenerator

// Generates a vector of pseudorandom bytes with the given size.
absl::StatusOr<std::vector<unsigned char>> GeneratePseudorandomBytes(
uint64_t size) override;
int64_t size) override;

// Generates a vector of `size` pseudorandom values in the range [0, modulus).
absl::StatusOr<std::vector<uint32_t>> GenerateUniformRandomRange(
uint64_t size, uint32_t modulus) override;
int64_t size, uint32_t modulus) override;

private:
explicit OpenSslUniformPseudorandomGenerator(EVP_CIPHER_CTX* ctx)
Expand Down
4 changes: 2 additions & 2 deletions src/main/cc/math/uniform_pseudorandom_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ class UniformPseudorandomGenerator {

// Generates a vector of pseudorandom bytes with the given size.
virtual absl::StatusOr<std::vector<unsigned char>> GeneratePseudorandomBytes(
uint64_t size) = 0;
int64_t size) = 0;

// Generates a vector of pseudorandom values in the range [0, modulus).
virtual absl::StatusOr<std::vector<uint32_t>> GenerateUniformRandomRange(
uint64_t size, uint32_t modulus) = 0;
int64_t size, uint32_t modulus) = 0;

protected:
UniformPseudorandomGenerator() = default;
Expand Down
55 changes: 41 additions & 14 deletions src/test/cc/math/open_ssl_uniform_random_generator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
namespace wfa::math {
namespace {

using ::testing::IsEmpty;

TEST(OpenSslUniformPseudorandomGenerator,
CreateTheGeneratorWithValidkeyAndIVSucceeds) {
std::vector<unsigned char> key(kBytesPerAes256Key);
Expand Down Expand Up @@ -67,19 +69,30 @@ TEST(OpenSslUniformPseudorandomGenerator,
}

TEST(OpenSslUniformPseudorandomGenerator,
GeneratingNonPositiveNumberOfRandomBytesFails) {
GeneratingNegativeNumberOfRandomBytesFails) {
std::vector<unsigned char> key(kBytesPerAes256Key);
std::vector<unsigned char> iv(kBytesPerAes256Iv);
RAND_bytes(key.data(), key.size());
RAND_bytes(iv.data(), iv.size());

ASSERT_OK_AND_ASSIGN(std::unique_ptr<UniformPseudorandomGenerator> prng,
OpenSslUniformPseudorandomGenerator::Create(key, iv));
auto seq = prng->GeneratePseudorandomBytes(0);
EXPECT_THAT(
seq.status(),
StatusIs(absl::StatusCode::kInvalidArgument,
"Number of pseudorandom bytes must be a positive value."));
auto seq = prng->GeneratePseudorandomBytes(-1);
EXPECT_THAT(seq.status(),
StatusIs(absl::StatusCode::kInvalidArgument, "negative"));
}

TEST(OpenSslUniformPseudorandomGenerator,
GeneratingZeroNumberOfRandomBytesSucceeds) {
std::vector<unsigned char> key(kBytesPerAes256Key);
std::vector<unsigned char> iv(kBytesPerAes256Iv);
RAND_bytes(key.data(), key.size());
RAND_bytes(iv.data(), iv.size());

ASSERT_OK_AND_ASSIGN(std::unique_ptr<UniformPseudorandomGenerator> prng,
OpenSslUniformPseudorandomGenerator::Create(key, iv));
ASSERT_OK_AND_ASSIGN(auto seq, prng->GeneratePseudorandomBytes(0));
EXPECT_THAT(seq, IsEmpty());
}

TEST(OpenSslUniformPseudorandomGenerator,
Expand Down Expand Up @@ -251,7 +264,7 @@ TEST(OpenSslUniformPseudorandomGenerator,
}

TEST(OpenSslUniformPseudorandomGenerator,
GeneratingNonPositiveNumberOfRandomElementsFails) {
GeneratingNegativeNumberOfRandomElementsFails) {
std::vector<unsigned char> key(kBytesPerAes256Key);
std::vector<unsigned char> iv(kBytesPerAes256Iv);
RAND_bytes(key.data(), key.size());
Expand All @@ -260,12 +273,26 @@ TEST(OpenSslUniformPseudorandomGenerator,
OpenSslUniformPseudorandomGenerator::Create(key, iv));

uint32_t kModulus = 128;
uint64_t kNumRandomElements = 0;
uint64_t kNumRandomElements = -1;
auto seq = prng->GenerateUniformRandomRange(kNumRandomElements, kModulus);
EXPECT_THAT(
seq.status(),
StatusIs(absl::StatusCode::kInvalidArgument,
"Number of pseudorandom elements must be a positive value."));
EXPECT_THAT(seq.status(),
StatusIs(absl::StatusCode::kInvalidArgument, "negative"));
}

TEST(OpenSslUniformPseudorandomGenerator,
GeneratingZeroNumberOfRandomElementsSucceeds) {
std::vector<unsigned char> key(kBytesPerAes256Key);
std::vector<unsigned char> iv(kBytesPerAes256Iv);
RAND_bytes(key.data(), key.size());
RAND_bytes(iv.data(), iv.size());
ASSERT_OK_AND_ASSIGN(std::unique_ptr<UniformPseudorandomGenerator> prng,
OpenSslUniformPseudorandomGenerator::Create(key, iv));

uint32_t kModulus = 128;
uint64_t kNumRandomElements = 0;
ASSERT_OK_AND_ASSIGN(
auto seq, prng->GenerateUniformRandomRange(kNumRandomElements, kModulus));
EXPECT_THAT(seq, IsEmpty());
}

TEST(OpenSslUniformPseudorandomGenerator,
Expand All @@ -280,8 +307,8 @@ TEST(OpenSslUniformPseudorandomGenerator,
uint32_t kModulus = 1;
uint64_t kNumRandomElements = 1;
auto seq = prng->GenerateUniformRandomRange(kNumRandomElements, kModulus);
EXPECT_THAT(seq.status(), StatusIs(absl::StatusCode::kInvalidArgument,
"The modulus must be greater than 1."));
EXPECT_THAT(seq.status(),
StatusIs(absl::StatusCode::kInvalidArgument, "modulus"));
}

TEST(OpenSslUniformPseudorandomGenerator,
Expand Down

0 comments on commit f47bce0

Please sign in to comment.