Skip to content

Commit

Permalink
Add function to sample non zero random vector. (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
ple13 authored May 2, 2024
1 parent 24b6a4c commit f6568de
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/main/cc/math/open_ssl_uniform_random_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,46 @@ OpenSslUniformPseudorandomGenerator::GenerateUniformRandomRange(
return ret;
}

// Generates uniformly random values in the range [1, modulus).
absl::StatusOr<std::vector<uint32_t>>
OpenSslUniformPseudorandomGenerator::GenerateNonZeroUniformRandomRange(
int64_t size, uint32_t modulus) {
if (modulus <= 1) {
return absl::InvalidArgumentError("The modulus must be greater than 1.");
}

if (size < 0) {
return absl::InvalidArgumentError(
"Number of pseudorandom elements must be a non-negative value.");
}

std::vector<uint32_t> ret;
ret.reserve(size);
// Compute the failure chance, which happens when the sampled value is 0.
double failure_rate = 1.0 / static_cast<double>(modulus);
while (ret.size() < size) {
int64_t current_size = size - ret.size();
// To get current_size `good` elements, it is expected to sample
// 1 + current_size*(1 + failure_rate/(1-failure_rate)) elements.
int64_t sample_size = static_cast<int64_t>(
current_size + 1.0 + failure_rate * current_size / (1 - failure_rate));
ASSIGN_OR_RETURN(std::vector<uint32_t> arr,
GenerateUniformRandomRange(current_size, modulus));
// Rejection sampling step.
for (int64_t i = 0; i < sample_size; i++) {
if (ret.size() >= size) {
break;
}

// Accept the value if the element is not zero.
if (arr[i] > 0) {
ret.push_back(arr[i]);
}
}
}
return ret;
}

absl::StatusOr<std::unique_ptr<UniformPseudorandomGenerator>>
CreatePrngFromSeed(const PrngSeed &seed) {
// TODO(@ple13): use absl::Cord instead of making vector copies once we can
Expand Down
4 changes: 4 additions & 0 deletions src/main/cc/math/open_ssl_uniform_random_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ class OpenSslUniformPseudorandomGenerator
absl::StatusOr<std::vector<uint32_t>> GenerateUniformRandomRange(
int64_t size, uint32_t modulus) override;

// Generates a vector of `size` pseudorandom values in the range [1, modulus).
absl::StatusOr<std::vector<uint32_t>> GenerateNonZeroUniformRandomRange(
int64_t size, uint32_t modulus) override;

private:
explicit OpenSslUniformPseudorandomGenerator(EVP_CIPHER_CTX* ctx)
: ctx_(std::move(ctx)) {}
Expand Down
4 changes: 4 additions & 0 deletions src/main/cc/math/uniform_pseudorandom_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class UniformPseudorandomGenerator {
virtual absl::StatusOr<std::vector<uint32_t>> GenerateUniformRandomRange(
int64_t size, uint32_t modulus) = 0;

// Generates a vector of pseudorandom values in the range [1, modulus).
virtual absl::StatusOr<std::vector<uint32_t>>
GenerateNonZeroUniformRandomRange(int64_t size, uint32_t modulus) = 0;

protected:
UniformPseudorandomGenerator() = default;
};
Expand Down
92 changes: 92 additions & 0 deletions src/test/cc/math/open_ssl_uniform_random_generator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -406,5 +406,97 @@ TEST(OpenSslUniformPseudorandomGenerator,
}
}

TEST(OpenSslUniformPseudorandomGenerator,
GeneratingNegativeNumberOfNonZeroRandomElementsFails) {
std::vector<unsigned char> key(kBytesPerAes256Key);
std::vector<unsigned char> iv(kBytesPerAes256Iv);
RAND_bytes(key.data(), key.size());
RAND_bytes(iv.data(), iv.size());
ASSERT_OK_AND_ASSIGN(std::unique_ptr<UniformPseudorandomGenerator> prng,
OpenSslUniformPseudorandomGenerator::Create(key, iv));

uint32_t kModulus = 128;
uint64_t kNumRandomElements = -1;
auto seq =
prng->GenerateNonZeroUniformRandomRange(kNumRandomElements, kModulus);
EXPECT_THAT(seq.status(),
StatusIs(absl::StatusCode::kInvalidArgument, "negative"));
}

TEST(OpenSslUniformPseudorandomGenerator,
GeneratingZeroNumberOfNonZeroRandomElementsSucceeds) {
std::vector<unsigned char> key(kBytesPerAes256Key);
std::vector<unsigned char> iv(kBytesPerAes256Iv);
RAND_bytes(key.data(), key.size());
RAND_bytes(iv.data(), iv.size());
ASSERT_OK_AND_ASSIGN(std::unique_ptr<UniformPseudorandomGenerator> prng,
OpenSslUniformPseudorandomGenerator::Create(key, iv));

uint32_t kModulus = 128;
uint64_t kNumRandomElements = 0;
ASSERT_OK_AND_ASSIGN(auto seq, prng->GenerateNonZeroUniformRandomRange(
kNumRandomElements, kModulus));
EXPECT_THAT(seq, IsEmpty());
}

TEST(OpenSslUniformPseudorandomGenerator,
GeneratingNonZeroUniformlyRandomWithInvalidModulusFails) {
std::vector<unsigned char> key(kBytesPerAes256Key);
std::vector<unsigned char> iv(kBytesPerAes256Iv);
RAND_bytes(key.data(), key.size());
RAND_bytes(iv.data(), iv.size());
ASSERT_OK_AND_ASSIGN(std::unique_ptr<UniformPseudorandomGenerator> prng,
OpenSslUniformPseudorandomGenerator::Create(key, iv));

uint32_t kModulus = 1;
uint64_t kNumRandomElements = 1;
auto seq =
prng->GenerateNonZeroUniformRandomRange(kNumRandomElements, kModulus);
EXPECT_THAT(seq.status(),
StatusIs(absl::StatusCode::kInvalidArgument, "modulus"));
}

TEST(OpenSslUniformPseudorandomGenerator,
SampleNonZeroUniformlyRandomOverRingZ2kElementsSucceeds) {
std::vector<unsigned char> key(kBytesPerAes256Key);
std::vector<unsigned char> iv(kBytesPerAes256Iv);
RAND_bytes(key.data(), key.size());
RAND_bytes(iv.data(), iv.size());

ASSERT_OK_AND_ASSIGN(std::unique_ptr<UniformPseudorandomGenerator> prng,
OpenSslUniformPseudorandomGenerator::Create(key, iv));
uint32_t kModulus = 128;
uint64_t kNumRandomElements = 111;
ASSERT_OK_AND_ASSIGN(
std::vector<uint32_t> seq,
prng->GenerateNonZeroUniformRandomRange(kNumRandomElements, kModulus));
ASSERT_EQ(seq.size(), kNumRandomElements);
for (int i = 0; i < kNumRandomElements; i++) {
ASSERT_GT(seq[i], 0);
ASSERT_LT(seq[i], kModulus);
}
}

TEST(OpenSslUniformPseudorandomGenerator,
SampleNonZeroUniformlyRandomOverPrimeFieldElementsSucceeds) {
std::vector<unsigned char> key(kBytesPerAes256Key);
std::vector<unsigned char> iv(kBytesPerAes256Iv);
RAND_bytes(key.data(), key.size());
RAND_bytes(iv.data(), iv.size());

ASSERT_OK_AND_ASSIGN(std::unique_ptr<UniformPseudorandomGenerator> prng,
OpenSslUniformPseudorandomGenerator::Create(key, iv));
uint32_t kModulus = 127;
uint64_t kNumRandomElements = 111;
ASSERT_OK_AND_ASSIGN(
std::vector<uint32_t> seq,
prng->GenerateNonZeroUniformRandomRange(kNumRandomElements, kModulus));
ASSERT_EQ(seq.size(), kNumRandomElements);
for (int i = 0; i < kNumRandomElements; i++) {
ASSERT_GT(seq[i], 0);
ASSERT_LT(seq[i], kModulus);
}
}

} // namespace
} // namespace wfa::math

0 comments on commit f6568de

Please sign in to comment.