Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the secret share generator function. #37

Merged
merged 12 commits into from
Jan 8, 2024
12 changes: 12 additions & 0 deletions src/main/cc/any_sketch/crypto/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ cc_library(
],
)

cc_library(
name = "secret_share_generator",
srcs = [":secret_share_generator.cc"],
hdrs = [":secret_share_generator.h"],
strip_include_prefix = _INCLUDE_PREFIX,
deps = [
"//src/main/cc/math:open_ssl_uniform_random_generator",
"//src/main/proto/wfa/any_sketch:secret_share_cc_proto",
"@wfa_common_cpp//src/main/cc/common_cpp/macros",
],
)

cc_library(
name = "sketch_encrypter_adapter",
srcs = [":sketch_encrypter_adapter.cc"],
Expand Down
106 changes: 106 additions & 0 deletions src/main/cc/any_sketch/crypto/secret_share_generator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright 2023 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "any_sketch/crypto/secret_share_generator.h"

#include <memory>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "common_cpp/macros/macros.h"
#include "math/open_ssl_uniform_random_generator.h"

namespace wfa::any_sketch::crypto {

using wfa::math::kBytesPerAes256Iv;
using wfa::math::kBytesPerAes256Key;
using wfa::math::OpenSslUniformPseudorandomGenerator;
using wfa::math::UniformPseudorandomGenerator;

namespace {

// This function computes (x + y) mod modulus and returns the result.
// The input values x, y, and the output are all in [0, modulus)
absl::StatusOr<uint32_t> SubMod(uint32_t x, uint32_t y, uint32_t modulus) {
if (x >= modulus || y >= modulus) {
return absl::InvalidArgumentError(absl::Substitute(
"Inputs must be less than the modulus, which is $0.", modulus));
}
if (x >= y) {
return x - y;
} else {
return x + (modulus - y);
}
}

} // namespace

absl::StatusOr<SecretShare> GenerateSecretShares(
const SecretShareParameter& secret_share_parameter,
const absl::Span<const uint32_t> input) {
if (input.size() == 0) {
return absl::InvalidArgumentError("Input must be a non-empty vector.");
}

if (secret_share_parameter.modulus() <= 1) {
return absl::InvalidArgumentError("The modulus must be greater than 1.");
}

// Verify OpenSSL random generator seed has been seeded with enough entropy.
if (RAND_status() != 1) {
return absl::InternalError(
"OpenSSL random generator has been seeded with enough entropy.");
}

// Sample random seed as the first share.
std::vector<unsigned char> key(kBytesPerAes256Key);
std::vector<unsigned char> iv(kBytesPerAes256Iv);

if (RAND_bytes(key.data(), kBytesPerAes256Key) != 1) {
return absl::InternalError("Failed to sample the AES 256 key.");
}

if (RAND_bytes(iv.data(), kBytesPerAes256Iv) != 1) {
return absl::InternalError("Failed to sample the AES 256 IV.");
}

// Expanse the seed to get a random string.
ASSIGN_OR_RETURN(std::unique_ptr<UniformPseudorandomGenerator> prng,
OpenSslUniformPseudorandomGenerator::Create(key, iv));

// share_vector_1 holds the random share generated from the random seed.
ASSIGN_OR_RETURN(std::vector<uint32_t> share_vector_1,
prng->GenerateUniformRandomRange(
input.size(), secret_share_parameter.modulus()));
// share_vector_2 is the share computed from the input and share_vector_1.
std::vector<uint32_t> share_vector_2(input.size());
for (int i = 0; i < input.size(); i++) {
// share_vector_2[i] = (input[i] - share_vector_1[i]) mod modulus.
ASSIGN_OR_RETURN(
share_vector_2[i],
SubMod(input[i], share_vector_1[i], secret_share_parameter.modulus()));
}
SecretShare secret_share;
secret_share.mutable_share_vector()->Add(share_vector_2.begin(),
share_vector_2.end());
std::string key_str(key.begin(), key.end());
std::string iv_str(iv.begin(), iv.end());
secret_share.mutable_share_seed()->set_key(key_str);
secret_share.mutable_share_seed()->set_iv(iv_str);

return secret_share;
}

} // namespace wfa::any_sketch::crypto
34 changes: 34 additions & 0 deletions src/main/cc/any_sketch/crypto/secret_share_generator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2023 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SECRET_SHARE_GENERATOR_H_
#define SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SECRET_SHARE_GENERATOR_H_

#include <vector>

#include "absl/status/statusor.h"
#include "wfa/any_sketch/secret_share.pb.h"

using wfa::any_sketch::SecretShare;
using wfa::any_sketch::SecretShareParameter;

namespace wfa::any_sketch::crypto {

absl::StatusOr<SecretShare> GenerateSecretShares(
const SecretShareParameter& secret_share_parameter,
const absl::Span<const uint32_t> input);

} // namespace wfa::any_sketch::crypto

#endif // SRC_MAIN_CC_ANY_SKETCH_CRYPTO_SECRET_SHARE_GENERATOR_H_
1 change: 1 addition & 0 deletions src/main/cc/math/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ cc_library(
"@boringssl//:ssl",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@wfa_common_cpp//src/main/cc/common_cpp/macros",
],
)

Expand Down
65 changes: 64 additions & 1 deletion src/main/cc/math/open_ssl_uniform_random_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

#include "math/open_ssl_uniform_random_generator.h"

#include <functional>
#include <stdexcept>

#include "common_cpp/macros/macros.h"
#include "openssl/rand.h"

namespace wfa::math {
Expand Down Expand Up @@ -68,7 +70,7 @@ OpenSslUniformPseudorandomGenerator::Create(
}

absl::StatusOr<std::vector<unsigned char>>
OpenSslUniformPseudorandomGenerator::GetPseudorandomBytes(uint64_t size) {
OpenSslUniformPseudorandomGenerator::GeneratePseudorandomBytes(uint64_t size) {
if (size == 0) {
return absl::InvalidArgumentError(
"Number of pseudorandom bytes must be a positive value.");
Expand All @@ -86,4 +88,65 @@ OpenSslUniformPseudorandomGenerator::GetPseudorandomBytes(uint64_t size) {
return ret;
}

// Rejection sampling is used to generate uniformly random values in the range
// [0, modulus).
absl::StatusOr<std::vector<uint32_t>>
OpenSslUniformPseudorandomGenerator::GenerateUniformRandomRange(
uint64_t size, uint32_t modulus) {
if (size == 0) {
return absl::InvalidArgumentError(
"Number of pseudorandom elements must be a positive value.");
}

if (modulus <= 1) {
return absl::InvalidArgumentError("The modulus must be greater than 1.");
}

// Compute the bit length of the modulus.
int bit_length = std::ceil(std::log2(modulus));
// The number of bytes needed per element.
int bytes_per_value = (bit_length + 7) / 8;
// The mask to extract the last bit_length bits.
uint32_t mask = (1 << bit_length) - 1;

// Compute the failure probability, which happens when the sampled value is
// greater than or equal to modulus. As 2^{bit_length - 1} < modulus <=
// 2^{bit_length}, the failure probability is guaranteed to be less than 0.5.
double failure_rate = static_cast<double>((1 << bit_length) - modulus) /
static_cast<double>(1 << bit_length);

std::vector<uint32_t> ret;
ret.reserve(size);

while (ret.size() < size) {
uint64_t current_size = size - ret.size();
// To get current_size `good` elements, it is expected to sample
// 1 + current_size*(1 + failure_rate/(1-failure_rate)) elements in
// [0, 2^{bit_length}).
uint64_t sample_size = static_cast<uint64_t>(
current_size + 1.0 + failure_rate * current_size / (1 - failure_rate));

ASSIGN_OR_RETURN(std::vector<unsigned char> arr,
GeneratePseudorandomBytes(sample_size * bytes_per_value));

// Rejection sampling step.
for (uint64_t i = 0; i < sample_size; i++) {
if (ret.size() >= size) {
break;
}
uint32_t temp = 0;
for (int j = 0; j < bytes_per_value; j++) {
temp = (temp << 8) + arr[i * bytes_per_value + j];
}
temp &= mask;

// Accept the value if it is less than modulus.
if (temp < modulus) {
ret.push_back(temp);
}
}
}
return ret;
}

} // namespace wfa::math
14 changes: 9 additions & 5 deletions src/main/cc/math/open_ssl_uniform_random_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,23 @@ class OpenSslUniformRandomGenerator {
class OpenSslUniformPseudorandomGenerator
: public UniformPseudorandomGenerator {
public:
// Create a uniform pseudorandom generator from a key and an IV.
// Creates a uniform pseudorandom generator from a key and an IV.
// The key and IV needs to have the length of kBytesPerAes256Key and
// kBytesPerAes256Iv respectively.
static absl::StatusOr<std::unique_ptr<UniformPseudorandomGenerator>> Create(
const std::vector<unsigned char>& key,
const std::vector<unsigned char>& iv);

// Destructor.
~OpenSslUniformPseudorandomGenerator() { EVP_CIPHER_CTX_free(ctx_); }
~OpenSslUniformPseudorandomGenerator() override { EVP_CIPHER_CTX_free(ctx_); }

// Generate a vector of pseudorandom bytes with the given size.
absl::StatusOr<std::vector<unsigned char>> GetPseudorandomBytes(
uint64_t size);
// Generates a vector of pseudorandom bytes with the given size.
absl::StatusOr<std::vector<unsigned char>> GeneratePseudorandomBytes(
uint64_t size) override;

// Generates a vector of `size` pseudorandom values in the range [0, modulus).
absl::StatusOr<std::vector<uint32_t>> GenerateUniformRandomRange(
uint64_t size, uint32_t modulus) override;

private:
explicit OpenSslUniformPseudorandomGenerator(EVP_CIPHER_CTX* ctx)
Expand Down
8 changes: 6 additions & 2 deletions src/main/cc/math/uniform_pseudorandom_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@ class UniformPseudorandomGenerator {
UniformPseudorandomGenerator(UniformPseudorandomGenerator&& other) = delete;
virtual ~UniformPseudorandomGenerator() = default;

// Generate a vector of pseudorandom bytes with the given size.
virtual absl::StatusOr<std::vector<unsigned char>> GetPseudorandomBytes(
// Generates a vector of pseudorandom bytes with the given size.
virtual absl::StatusOr<std::vector<unsigned char>> GeneratePseudorandomBytes(
uint64_t size) = 0;

// Generates a vector of pseudorandom values in the range [0, modulus).
virtual absl::StatusOr<std::vector<uint32_t>> GenerateUniformRandomRange(
uint64_t size, uint32_t modulus) = 0;

protected:
UniformPseudorandomGenerator() = default;
};
Expand Down
11 changes: 11 additions & 0 deletions src/main/proto/wfa/any_sketch/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,14 @@ cc_proto_library(
name = "differential_privacy_cc_proto",
deps = [":differential_privacy_proto"],
)

proto_library(
name = "secret_share_proto",
srcs = ["secret_share.proto"],
strip_import_prefix = IMPORT_PREFIX,
)

cc_proto_library(
name = "secret_share_cc_proto",
deps = [":secret_share_proto"],
)
40 changes: 40 additions & 0 deletions src/main/proto/wfa/any_sketch/secret_share.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright 2023 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Protobuffer for the generalized enriched cardinality sketch.
// SketchConfig is configuration of the sketch.
// Sketch is the sketch itself.

syntax = "proto3";

package wfa.any_sketch;

option java_package = "org.wfanet.anysketch";
option java_multiple_files = true;
option java_outer_classname = "SecretShareProto";

message SecretShareParameter {
uint32 modulus = 1;
}

message SecretShare {
// Seed to initialize the AES 256 counter mode.
message ShareSeed {
bytes key = 1;
bytes iv = 2;
}

ShareSeed share_seed = 1;
repeated uint32 share_vector = 2;
}
16 changes: 16 additions & 0 deletions src/test/cc/any_sketch/crypto/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,19 @@ cc_test(
"@wfa_common_cpp//src/main/cc/common_cpp/testing:status",
],
)

cc_test(
name = "secret_share_generator_test",
size = "small",
srcs = [
":secret_share_generator_test.cc",
],
deps = [
"//src/main/cc/any_sketch/crypto:secret_share_generator",
"//src/main/cc/math:open_ssl_uniform_random_generator",
"//src/main/proto/wfa/any_sketch:secret_share_cc_proto",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
"@wfa_common_cpp//src/main/cc/common_cpp/testing:status",
],
)
Loading
Loading