Skip to content

Commit

Permalink
Handle intersection when client is larger than server (#151)
Browse files Browse the repository at this point in the history
* use the max of client and server inputs when computing the bits of the bloomfilter

* fix gcs

* typo

* tests

* cleanup

* Update WASM tests

* package lock

* use min

* fix

* fix rollup

* revert fpr

* revert js tests

* use better fpr
  • Loading branch information
s0l0ist authored Jan 14, 2023
1 parent 32d6116 commit f3cc1f0
Show file tree
Hide file tree
Showing 13 changed files with 88 additions and 64 deletions.
9 changes: 9 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# Version 1.1.0

Bugfix:

- Fixed an issue when computing an intersection where the client's set is
**larger** than the server's set. In this case, the reported intersection could
contain more values than possible, e.g. computing an intersection with a client
set of 100 values and a server set of 10 values sometimes yielded 12 matches.

# Version 1.0.3

Bugfix:
Expand Down
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@openmined/psi.js",
"version": "1.0.3",
"version": "1.1.0",
"description": "Private Set Intersection for JavaScript",
"repository": {
"type": "git",
Expand Down
8 changes: 6 additions & 2 deletions private_set_intersection/cpp/bloom_filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ BloomFilter::BloomFilter(
context_(std::move(context)) {}

StatusOr<std::unique_ptr<BloomFilter>> BloomFilter::Create(
double fpr, absl::Span<const std::string> elements) {
ASSIGN_OR_RETURN(auto filter, CreateEmpty(fpr, elements.size()));
double fpr, int64_t num_client_inputs,
absl::Span<const std::string> elements) {
auto num_server_inputs = static_cast<int64_t>(elements.size());
ASSIGN_OR_RETURN(auto filter, CreateEmpty(fpr, std::max(num_client_inputs,
num_server_inputs)));

filter->Add(elements);
// This move seems to be needed for some versions of GCC. See for example this
// failing build:
Expand Down
3 changes: 2 additions & 1 deletion private_set_intersection/cpp/bloom_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class BloomFilter {
BloomFilter() = delete;

static StatusOr<std::unique_ptr<BloomFilter>> Create(
double fpr, absl::Span<const std::string> elements);
double fpr, int64_t num_client_inputs,
absl::Span<const std::string> elements);

// Creates a new Bloom filter. As long as less than `max_elements` are
// inserted, the probability of false positives when performing checks
Expand Down
8 changes: 5 additions & 3 deletions private_set_intersection/cpp/gcs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ GCS::GCS(std::string golomb, int64_t div, int64_t hash_range,
context_(std::move(context)) {}

StatusOr<std::unique_ptr<GCS>> GCS::Create(
double fpr, absl::Span<const std::string> elements) {
double fpr, int64_t num_client_inputs,
absl::Span<const std::string> elements) {
if (fpr <= 0 || fpr >= 1) {
return absl::InvalidArgumentError("`fpr` must be in (0,1)");
}

auto hash_range = static_cast<int64_t>(elements.size() / fpr);
auto num_server_inputs = static_cast<int64_t>(elements.size());
auto hash_range = static_cast<int64_t>(
std::max(num_client_inputs, num_server_inputs) / fpr);
std::vector<int64_t> hashes;
hashes.reserve(elements.size());
auto context = absl::make_unique<::private_join_and_compute::Context>();
Expand Down
3 changes: 2 additions & 1 deletion private_set_intersection/cpp/gcs.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class GCS {
GCS() = delete;

static StatusOr<std::unique_ptr<GCS>> Create(
double fpr, absl::Span<const std::string> elements);
double fpr, int64_t num_client_inputs,
absl::Span<const std::string> elements);

static StatusOr<std::unique_ptr<GCS>> CreateFromProtobuf(
const psi_proto::ServerSetup& encoded_set);
Expand Down
24 changes: 14 additions & 10 deletions private_set_intersection/cpp/gcs_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ namespace {

TEST(GCSTest, TestIntersect) {
std::vector<std::string> elements = {"a", "b", "c", "d"};
std::vector<std::string> elements2 = {"a", "b", "d", "e"};
auto max_inputs =
static_cast<int64_t>(std::max(elements.size(), elements2.size()));

std::unique_ptr<GCS> gcs;
PSI_ASSERT_OK_AND_ASSIGN(
gcs,
GCS::Create(0.001, absl::MakeConstSpan(&elements[0], elements.size())));
gcs, GCS::Create(0.001, max_inputs,
absl::MakeConstSpan(&elements[0], elements.size())));

std::vector<std::string> elements2 = {"a", "b", "d", "e"};
std::vector<int64_t> intersect = {0, 1, 2};

auto res =
Expand All @@ -63,7 +65,7 @@ TEST(GCSTest, TestFPR) {

std::unique_ptr<GCS> gcs;
PSI_ASSERT_OK_AND_ASSIGN(
gcs, GCS::Create(target_fpr,
gcs, GCS::Create(target_fpr, (int64_t)elements.size(),
absl::MakeConstSpan(&elements[0], elements.size())));

// Test 10k elements to measure FPR.
Expand Down Expand Up @@ -98,8 +100,8 @@ TEST(GCSTest, TestToProtobuf) {

std::unique_ptr<GCS> gcs;
PSI_ASSERT_OK_AND_ASSIGN(
gcs,
GCS::Create(fpr, absl::MakeConstSpan(&elements[0], elements.size())));
gcs, GCS::Create(fpr, (int64_t)elements.size(),
absl::MakeConstSpan(&elements[0], elements.size())));

// Create the protobuf from the GCS and check if it matches.
psi_proto::ServerSetup encoded_gcs = gcs->ToProtobuf();
Expand All @@ -110,17 +112,19 @@ TEST(GCSTest, TestToProtobuf) {

TEST(GCSTest, TestCreateFromProtobuf) {
std::vector<std::string> elements = {"a", "b", "c", "d"};
std::vector<std::string> elements2 = {"a", "b", "c", "d", "not present"};
auto max_inputs =
static_cast<int64_t>(std::max(elements.size(), elements2.size()));

std::unique_ptr<GCS> gcs;
PSI_ASSERT_OK_AND_ASSIGN(
gcs,
GCS::Create(0.001, absl::MakeConstSpan(&elements[0], elements.size())));
gcs, GCS::Create(0.001, max_inputs,
absl::MakeConstSpan(&elements[0], elements.size())));

psi_proto::ServerSetup encoded_gcs = gcs->ToProtobuf();
std::unique_ptr<GCS> gcs2;
PSI_ASSERT_OK_AND_ASSIGN(gcs2, GCS::CreateFromProtobuf(encoded_gcs));

std::vector<std::string> elements2 = {"a", "b", "c", "d", "not present"};
auto res =
gcs2->Intersect(absl::MakeConstSpan(&elements2[0], elements2.size()));
std::vector<int64_t> intersect = {0, 1, 2, 3};
Expand Down Expand Up @@ -149,7 +153,7 @@ TEST(GCSTest, TestGolombSize) {
for (size_t i = 0; i < sizeof(fpr) / sizeof(double); i++) {
std::unique_ptr<GCS> gcs;
PSI_ASSERT_OK_AND_ASSIGN(
gcs, GCS::Create(fpr[i],
gcs, GCS::Create(fpr[i], (int64_t)elements.size(),
absl::MakeConstSpan(&elements[0], elements.size())));
auto res =
gcs->Intersect(absl::MakeConstSpan(&elements[0], elements.size()));
Expand Down
3 changes: 2 additions & 1 deletion private_set_intersection/cpp/psi_client_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ class PsiClientTest : public ::testing::Test {
// Insert server elements into GCS.
PSI_ASSERT_OK_AND_ASSIGN(
auto gcs,
GCS::Create(fpr, absl::MakeConstSpan(&elements[0], elements.size())));
GCS::Create(fpr, (int64_t)elements.size(),
absl::MakeConstSpan(&elements[0], elements.size())));
*server_setup = gcs->ToProtobuf();
}

Expand Down
4 changes: 2 additions & 2 deletions private_set_intersection/cpp/psi_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ StatusOr<psi_proto::ServerSetup> PsiServer::CreateSetupMessage(
// Create a GCS and insert elements into it.
ASSIGN_OR_RETURN(
auto gcs,
GCS::Create(corrected_fpr,
GCS::Create(corrected_fpr, num_client_inputs,
absl::MakeConstSpan(&encrypted[0], encrypted.size())));

// Return the GCS as a Protobuf
Expand All @@ -86,7 +86,7 @@ StatusOr<psi_proto::ServerSetup> PsiServer::CreateSetupMessage(
// Create a Bloom Filter and insert elements into it.
ASSIGN_OR_RETURN(auto filter,
BloomFilter::Create(
corrected_fpr,
corrected_fpr, num_client_inputs,
absl::MakeConstSpan(&encrypted[0], encrypted.size())));

// Return the Bloom Filter as a Protobuf
Expand Down
78 changes: 40 additions & 38 deletions private_set_intersection/python/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def test_client_server(ds, reveal_intersection, duplicate):
c = psi.client.CreateWithNewKey(reveal_intersection)
s = psi.server.CreateWithNewKey(reveal_intersection)

client_items = ["Element " + str(i) for i in range(1000)]
server_items = ["Element " + str(2 * i) for i in range(10000)]
client_items = ["Element " + str(i) for i in range(100)]
server_items = ["Element " + str(2 * i) for i in range(1000)]

fpr = 1.0 / (1000000000)
fpr = 1.0e-9
setup = dup(
duplicate, s.CreateSetupMessage(fpr, len(client_items), server_items, ds), psi.ServerSetup()
)
Expand All @@ -50,6 +50,36 @@ def test_client_server(ds, reveal_intersection, duplicate):
assert intersection <= (1.1 * len(client_items) / 2.0)


@pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER])
@pytest.mark.parametrize("reveal_intersection", [False, True])
@pytest.mark.parametrize("duplicate", [False, True])
def test_large_client_small_server(ds, reveal_intersection, duplicate):
    """Intersect a 1000-item client set against a 10-item server set.

    The reported intersection must never contain more entries than the
    server's set holds, in both reveal and size-only modes.
    """
    client = psi.client.CreateWithNewKey(reveal_intersection)
    server = psi.server.CreateWithNewKey(reveal_intersection)

    client_items = ["Element " + str(i) for i in range(1000)]
    server_items = ["Element " + str(2 * i) for i in range(10)]
    fpr = 1.0e-9

    # Each protocol message is passed through `dup` together with a fresh
    # message object of the matching protobuf type.
    raw_setup = server.CreateSetupMessage(fpr, len(client_items), server_items, ds)
    setup = dup(duplicate, raw_setup, psi.ServerSetup())
    request = dup(duplicate, client.CreateRequest(client_items), psi.Request())
    response = dup(duplicate, server.ProcessRequest(request), psi.Response())

    if not reveal_intersection:
        # Size-only mode: the count is bounded by the smaller (server) set.
        size = client.GetIntersectionSize(setup, response)
        assert size <= len(server_items)
        return

    matched = client.GetIntersection(setup, response)
    assert len(matched) <= len(server_items)
    # Server item j is named "Element 2j", so the client item at index idx
    # must equal the server item at index idx / 2.
    for idx in matched:
        assert client_items[idx] == server_items[int(idx / 2)]


def test_version():
version = psi.__version__
assert re.match(r"[0-9]+[.][0-9]+[.][0-9]+(-[A-Za-z0-9]+)?", version)
Expand Down Expand Up @@ -83,42 +113,14 @@ def test_client_sanity(reveal_intersection):
assert key == newkey


@pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER])
@pytest.mark.parametrize("reveal_intersection", [False, True])
def test_server_client(ds, reveal_intersection):
    """Run a full client/server exchange with 1000 client and 10000 server items.

    Server items are the even-numbered "Element" names, so exactly the
    even-indexed client items are shared between the two sets.
    """
    client = psi.client.CreateWithNewKey(reveal_intersection)
    server = psi.server.CreateWithNewKey(reveal_intersection)

    client_items = ["Element " + str(i) for i in range(1000)]
    server_items = ["Element " + str(2 * i) for i in range(10000)]
    fpr = 1.0 / (1000000000)

    setup = server.CreateSetupMessage(fpr, len(client_items), server_items, ds)
    request = client.CreateRequest(client_items)
    response = server.ProcessRequest(request)

    if not reveal_intersection:
        # Size-only mode: at least half the client items match, with up to
        # 10% slack above that allowed.
        size = client.GetIntersectionSize(setup, response)
        assert size >= (len(client_items) / 2.0)
        assert size <= (1.1 * len(client_items) / 2.0)
        return

    found = set(client.GetIntersection(setup, response))
    # An index is reported if and only if it is even.
    for idx, _ in enumerate(client_items):
        assert (idx in found) == (idx % 2 == 0)


@pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER])
@pytest.mark.parametrize("reveal_intersection", [False, True])
def test_serialization_setup_msg(ds, reveal_intersection):
s = psi.server.CreateWithNewKey(reveal_intersection)

server_items = ["Element " + str(2 * i) for i in range(10000)]

fpr = 1.0 / (1000000000)
fpr = 1.0e-9
setup = s.CreateSetupMessage(fpr, 1000, server_items, ds)

buff = setup.SerializeToString()
Expand Down Expand Up @@ -148,10 +150,10 @@ def test_serialization_response(ds, reveal_intersection):
c = psi.client.CreateWithNewKey(reveal_intersection)
s = psi.server.CreateWithNewKey(reveal_intersection)

client_items = ["Element " + str(i) for i in range(1000)]
server_items = ["Element " + str(2 * i) for i in range(10000)]
client_items = ["Element " + str(i) for i in range(100)]
server_items = ["Element " + str(2 * i) for i in range(1000)]

fpr = 1.0 / (1000000000)
fpr = 1.0e-9
setup = s.CreateSetupMessage(fpr, len(client_items), server_items, ds)
req = c.CreateRequest(client_items)
resp = s.ProcessRequest(req)
Expand All @@ -170,10 +172,10 @@ def test_empty_intersection(ds, reveal_intersection):
c = psi.client.CreateWithNewKey(reveal_intersection)
s = psi.server.CreateWithNewKey(reveal_intersection)

client_items = ["Element " + str(i) for i in range(1000)]
server_items = ["Other " + str(2 * i) for i in range(10000)]
client_items = ["Element " + str(i) for i in range(100)]
server_items = ["Other " + str(2 * i) for i in range(1000)]

fpr = 1.0 / (1000000000)
fpr = 1.0e-9
setup = s.CreateSetupMessage(fpr, len(client_items), server_items, ds)
request = c.CreateRequest(client_items)
resp = s.ProcessRequest(request)
Expand Down
4 changes: 2 additions & 2 deletions rollup.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ const outputs = [
sourcemap: true,
format: 'cjs',
name: 'PSI',
exports: 'auto'
// plugins: [terser()]
exports: 'auto',
plugins: [terser()]
},
plugins: [
// UGLY hack to 'fix' nodejs complaining about 'window' and 'self'
Expand Down
2 changes: 1 addition & 1 deletion tools/package.bzl
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
""" Version of the current release """
VERSION_LABEL = "1.0.3"
VERSION_LABEL = "1.1.0"

0 comments on commit f3cc1f0

Please sign in to comment.