From f3cc1f01a1d1682d0d5e3870a9d81550571b1ad4 Mon Sep 17 00:00:00 2001 From: Nick Angelou Date: Sat, 14 Jan 2023 13:39:18 -0600 Subject: [PATCH] Handle intersection when client is larger than server (#151) * use the max of client and server inputs when computing the bits of the bloomfilter * fix gcs * typo * tests * cleanup * Update WASM tests * package lock * use min * fix * fix rollup * revert fpr * revert js tests * use better fpr --- CHANGES.md | 9 +++ package-lock.json | 4 +- package.json | 2 +- private_set_intersection/cpp/bloom_filter.cpp | 8 +- private_set_intersection/cpp/bloom_filter.h | 3 +- private_set_intersection/cpp/gcs.cpp | 8 +- private_set_intersection/cpp/gcs.h | 3 +- private_set_intersection/cpp/gcs_test.cpp | 24 +++--- .../cpp/psi_client_test.cpp | 3 +- private_set_intersection/cpp/psi_server.cpp | 4 +- private_set_intersection/python/tests.py | 78 ++++++++++--------- rollup.config.js | 4 +- tools/package.bzl | 2 +- 13 files changed, 88 insertions(+), 64 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 6e184d6e..6f129197 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,12 @@ +# Version 1.1.0 + +Bugfix: + +- Fixed an issue when computing an intersection with the client's set **larger** + than the server's set. In this case, the intersection reported more values + than possible. e.g. computing an intersection with a client set of 100 values + and server with 10 values sometimes resulted in 12. + # Version 1.0.3 Bugfix: diff --git a/package-lock.json b/package-lock.json index 266caf31..12b3a885 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,6 +1,6 @@ { "name": "@openmined/psi.js", - "version": "1.0.3", + "version": "1.1.0", "lockfileVersion": 2, "requires": true, "packages": { @@ -16929,4 +16929,4 @@ "dev": true } } -} +} \ No newline at end of file diff --git a/package.json b/package.json index 97ac6beb..abbfce6d 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@openmined/psi.js", - "version": "1.0.3", + "version": "1.1.0", "description": "Private Set Intersection for JavaScript", "repository": { "type": "git", diff --git a/private_set_intersection/cpp/bloom_filter.cpp b/private_set_intersection/cpp/bloom_filter.cpp index ebb21d14..6cfc6e51 100644 --- a/private_set_intersection/cpp/bloom_filter.cpp +++ b/private_set_intersection/cpp/bloom_filter.cpp @@ -33,8 +33,12 @@ BloomFilter::BloomFilter( context_(std::move(context)) {} StatusOr> BloomFilter::Create( - double fpr, absl::Span elements) { - ASSIGN_OR_RETURN(auto filter, CreateEmpty(fpr, elements.size())); + double fpr, int64_t num_client_inputs, + absl::Span elements) { + auto num_server_inputs = static_cast(elements.size()); + ASSIGN_OR_RETURN(auto filter, CreateEmpty(fpr, std::max(num_client_inputs, + num_server_inputs))); + filter->Add(elements); // This move seems to be needed for some versions of GCC. See for example this // failing build: diff --git a/private_set_intersection/cpp/bloom_filter.h b/private_set_intersection/cpp/bloom_filter.h index 808205d8..261d933c 100644 --- a/private_set_intersection/cpp/bloom_filter.h +++ b/private_set_intersection/cpp/bloom_filter.h @@ -45,7 +45,8 @@ class BloomFilter { BloomFilter() = delete; static StatusOr> Create( - double fpr, absl::Span elements); + double fpr, int64_t num_client_inputs, + absl::Span elements); // Creates a new Bloom filter. As long as less than `max_elements` are // inserted, the probability of false positives when performing checks diff --git a/private_set_intersection/cpp/gcs.cpp b/private_set_intersection/cpp/gcs.cpp index 36e938c0..c4baedd9 100644 --- a/private_set_intersection/cpp/gcs.cpp +++ b/private_set_intersection/cpp/gcs.cpp @@ -35,12 +35,14 @@ GCS::GCS(std::string golomb, int64_t div, int64_t hash_range, context_(std::move(context)) {} StatusOr> GCS::Create( - double fpr, absl::Span elements) { + double fpr, int64_t num_client_inputs, + absl::Span elements) { if (fpr <= 0 || fpr >= 1) { return absl::InvalidArgumentError("`fpr` must be in (0,1)"); } - - auto hash_range = static_cast(elements.size() / fpr); + auto num_server_inputs = static_cast(elements.size()); + auto hash_range = static_cast( + std::max(num_client_inputs, num_server_inputs) / fpr); std::vector hashes; hashes.reserve(elements.size()); auto context = absl::make_unique<::private_join_and_compute::Context>(); diff --git a/private_set_intersection/cpp/gcs.h b/private_set_intersection/cpp/gcs.h index 47722c5e..07121e65 100644 --- a/private_set_intersection/cpp/gcs.h +++ b/private_set_intersection/cpp/gcs.h @@ -33,7 +33,8 @@ class GCS { GCS() = delete; static StatusOr> Create( - double fpr, absl::Span elements); + double fpr, int64_t num_client_inputs, + absl::Span elements); static StatusOr> CreateFromProtobuf( const psi_proto::ServerSetup& encoded_set); diff --git a/private_set_intersection/cpp/gcs_test.cpp b/private_set_intersection/cpp/gcs_test.cpp index 7790f9a5..7f2bac4d 100644 --- a/private_set_intersection/cpp/gcs_test.cpp +++ b/private_set_intersection/cpp/gcs_test.cpp @@ -30,13 +30,15 @@ namespace { TEST(GCSTest, TestIntersect) { std::vector elements = {"a", "b", "c", "d"}; + std::vector elements2 = {"a", "b", "d", "e"}; + auto max_inputs = + static_cast(std::max(elements.size(), elements2.size())); std::unique_ptr gcs; PSI_ASSERT_OK_AND_ASSIGN( - gcs, - GCS::Create(0.001, absl::MakeConstSpan(&elements[0], elements.size()))); + gcs, GCS::Create(0.001, max_inputs, + absl::MakeConstSpan(&elements[0], elements.size()))); - std::vector elements2 = {"a", "b", "d", "e"}; std::vector intersect = {0, 1, 2}; auto res = @@ -63,7 +65,7 @@ TEST(GCSTest, TestFPR) { std::unique_ptr gcs; PSI_ASSERT_OK_AND_ASSIGN( - gcs, GCS::Create(target_fpr, + gcs, GCS::Create(target_fpr, (int64_t)elements.size(), absl::MakeConstSpan(&elements[0], elements.size()))); // Test 10k elements to measure FPR. @@ -98,8 +100,8 @@ TEST(GCSTest, TestToProtobuf) { std::unique_ptr gcs; PSI_ASSERT_OK_AND_ASSIGN( - gcs, - GCS::Create(fpr, absl::MakeConstSpan(&elements[0], elements.size()))); + gcs, GCS::Create(fpr, (int64_t)elements.size(), + absl::MakeConstSpan(&elements[0], elements.size()))); // Create the protobuf from the GCS and check if it matches. psi_proto::ServerSetup encoded_gcs = gcs->ToProtobuf(); @@ -110,17 +112,19 @@ TEST(GCSTest, TestToProtobuf) { TEST(GCSTest, TestCreateFromProtobuf) { std::vector elements = {"a", "b", "c", "d"}; + std::vector elements2 = {"a", "b", "c", "d", "not present"}; + auto max_inputs = + static_cast(std::max(elements.size(), elements2.size())); std::unique_ptr gcs; PSI_ASSERT_OK_AND_ASSIGN( - gcs, - GCS::Create(0.001, absl::MakeConstSpan(&elements[0], elements.size()))); + gcs, GCS::Create(0.001, max_inputs, + absl::MakeConstSpan(&elements[0], elements.size()))); psi_proto::ServerSetup encoded_gcs = gcs->ToProtobuf(); std::unique_ptr gcs2; PSI_ASSERT_OK_AND_ASSIGN(gcs2, GCS::CreateFromProtobuf(encoded_gcs)); - std::vector elements2 = {"a", "b", "c", "d", "not present"}; auto res = gcs2->Intersect(absl::MakeConstSpan(&elements2[0], elements2.size())); std::vector intersect = {0, 1, 2, 3}; @@ -149,7 +153,7 @@ TEST(GCSTest, TestGolombSize) { for (size_t i = 0; i < sizeof(fpr) / sizeof(double); i++) { std::unique_ptr gcs; PSI_ASSERT_OK_AND_ASSIGN( - gcs, GCS::Create(fpr[i], + gcs, GCS::Create(fpr[i], (int64_t)elements.size(), absl::MakeConstSpan(&elements[0], elements.size()))); auto res = gcs->Intersect(absl::MakeConstSpan(&elements[0], elements.size())); diff --git a/private_set_intersection/cpp/psi_client_test.cpp b/private_set_intersection/cpp/psi_client_test.cpp index 62de6d8b..33b0a208 100644 --- a/private_set_intersection/cpp/psi_client_test.cpp +++ b/private_set_intersection/cpp/psi_client_test.cpp @@ -55,7 +55,8 @@ class PsiClientTest : public ::testing::Test { // Insert server elements into GCS. PSI_ASSERT_OK_AND_ASSIGN( auto gcs, - GCS::Create(fpr, absl::MakeConstSpan(&elements[0], elements.size()))); + GCS::Create(fpr, (int64_t)elements.size(), + absl::MakeConstSpan(&elements[0], elements.size()))); *server_setup = gcs->ToProtobuf(); } diff --git a/private_set_intersection/cpp/psi_server.cpp b/private_set_intersection/cpp/psi_server.cpp index 982036c3..17512c94 100644 --- a/private_set_intersection/cpp/psi_server.cpp +++ b/private_set_intersection/cpp/psi_server.cpp @@ -77,7 +77,7 @@ StatusOr PsiServer::CreateSetupMessage( // Create a GCS and insert elements into it. ASSIGN_OR_RETURN( auto gcs, - GCS::Create(corrected_fpr, + GCS::Create(corrected_fpr, num_client_inputs, absl::MakeConstSpan(&encrypted[0], encrypted.size()))); // Return the GCS as a Protobuf @@ -86,7 +86,7 @@ StatusOr PsiServer::CreateSetupMessage( // Create a Bloom Filter and insert elements into it. ASSIGN_OR_RETURN(auto filter, BloomFilter::Create( - corrected_fpr, + corrected_fpr, num_client_inputs, absl::MakeConstSpan(&encrypted[0], encrypted.size()))); // Return the Bloom Filter as a Protobuf diff --git a/private_set_intersection/python/tests.py b/private_set_intersection/python/tests.py index 24340d7d..1711b0ac 100644 --- a/private_set_intersection/python/tests.py +++ b/private_set_intersection/python/tests.py @@ -26,10 +26,10 @@ def test_client_server(ds, reveal_intersection, duplicate): c = psi.client.CreateWithNewKey(reveal_intersection) s = psi.server.CreateWithNewKey(reveal_intersection) - client_items = ["Element " + str(i) for i in range(1000)] - server_items = ["Element " + str(2 * i) for i in range(10000)] + client_items = ["Element " + str(i) for i in range(100)] + server_items = ["Element " + str(2 * i) for i in range(1000)] - fpr = 1.0 / (1000000000) + fpr = 1.0e-9 setup = dup( duplicate, s.CreateSetupMessage(fpr, len(client_items), server_items, ds), psi.ServerSetup() ) @@ -50,6 +50,36 @@ def test_client_server(ds, reveal_intersection, duplicate): assert intersection <= (1.1 * len(client_items) / 2.0) +@pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER]) +@pytest.mark.parametrize("reveal_intersection", [False, True]) +@pytest.mark.parametrize("duplicate", [False, True]) +def test_large_client_small_server(ds, reveal_intersection, duplicate): + c = psi.client.CreateWithNewKey(reveal_intersection) + s = psi.server.CreateWithNewKey(reveal_intersection) + + client_items = ["Element " + str(i) for i in range(1000)] + server_items = ["Element " + str(2 * i) for i in range(10)] + + fpr = 1.0e-9 + setup = dup( + duplicate, s.CreateSetupMessage(fpr, len(client_items), server_items, ds), psi.ServerSetup() + ) + request = dup(duplicate, c.CreateRequest(client_items), psi.Request()) + resp = dup(duplicate, s.ProcessRequest(request), psi.Response()) + + if reveal_intersection: + intersection = c.GetIntersection(setup, resp) + assert len(intersection) <= len(server_items) + # We should have an entry for every item in the server's set offset by + # two due to the input string names + for idx in intersection: + assert client_items[idx] == server_items[int(idx / 2)] + + else: + intersection = c.GetIntersectionSize(setup, resp) + assert intersection <= len(server_items) + + def test_version(): version = psi.__version__ assert re.match(r"[0-9]+[.][0-9]+[.][0-9]+(-[A-Za-z0-9]+)?", version) @@ -83,34 +113,6 @@ def test_client_sanity(reveal_intersection): assert key == newkey -@pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER]) -@pytest.mark.parametrize("reveal_intersection", [False, True]) -def test_server_client(ds, reveal_intersection): - c = psi.client.CreateWithNewKey(reveal_intersection) - s = psi.server.CreateWithNewKey(reveal_intersection) - - client_items = ["Element " + str(i) for i in range(1000)] - server_items = ["Element " + str(2 * i) for i in range(10000)] - - fpr = 1.0 / (1000000000) - setup = s.CreateSetupMessage(fpr, len(client_items), server_items, ds) - request = c.CreateRequest(client_items) - resp = s.ProcessRequest(request) - - if reveal_intersection: - intersection = c.GetIntersection(setup, resp) - iset = set(intersection) - for idx in range(len(client_items)): - if idx % 2 == 0: - assert idx in iset - else: - assert idx not in iset - else: - intersection = c.GetIntersectionSize(setup, resp) - assert intersection >= (len(client_items) / 2.0) - assert intersection <= (1.1 * len(client_items) / 2.0) - - @pytest.mark.parametrize("ds", [psi.DataStructure.GCS, psi.DataStructure.BLOOM_FILTER]) @pytest.mark.parametrize("reveal_intersection", [False, True]) def test_serialization_setup_msg(ds, reveal_intersection): @@ -118,7 +120,7 @@ def test_serialization_setup_msg(ds, reveal_intersection): server_items = ["Element " + str(2 * i) for i in range(10000)] - fpr = 1.0 / (1000000000) + fpr = 1.0e-9 setup = s.CreateSetupMessage(fpr, 1000, server_items, ds) buff = setup.SerializeToString() @@ -148,10 +150,10 @@ def test_serialization_response(ds, reveal_intersection): c = psi.client.CreateWithNewKey(reveal_intersection) s = psi.server.CreateWithNewKey(reveal_intersection) - client_items = ["Element " + str(i) for i in range(1000)] - server_items = ["Element " + str(2 * i) for i in range(10000)] + client_items = ["Element " + str(i) for i in range(100)] + server_items = ["Element " + str(2 * i) for i in range(1000)] - fpr = 1.0 / (1000000000) + fpr = 1.0e-9 setup = s.CreateSetupMessage(fpr, len(client_items), server_items, ds) req = c.CreateRequest(client_items) resp = s.ProcessRequest(req) @@ -170,10 +172,10 @@ def test_empty_intersection(ds, reveal_intersection): c = psi.client.CreateWithNewKey(reveal_intersection) s = psi.server.CreateWithNewKey(reveal_intersection) - client_items = ["Element " + str(i) for i in range(1000)] - server_items = ["Other " + str(2 * i) for i in range(10000)] + client_items = ["Element " + str(i) for i in range(100)] + server_items = ["Other " + str(2 * i) for i in range(1000)] - fpr = 1.0 / (1000000000) + fpr = 1.0e-9 setup = s.CreateSetupMessage(fpr, len(client_items), server_items, ds) request = c.CreateRequest(client_items) resp = s.ProcessRequest(request) diff --git a/rollup.config.js b/rollup.config.js index e7a62dfd..65620e54 100644 --- a/rollup.config.js +++ b/rollup.config.js @@ -20,8 +20,8 @@ const outputs = [ sourcemap: true, format: 'cjs', name: 'PSI', - exports: 'auto' - // plugins: [terser()] + exports: 'auto', + plugins: [terser()] }, plugins: [ // UGLY hack to 'fix' nodejs complaining about 'window' and 'self' diff --git a/tools/package.bzl b/tools/package.bzl index b56f0541..73f2060a 100644 --- a/tools/package.bzl +++ b/tools/package.bzl @@ -1,2 +1,2 @@ """ Version of the current release """ -VERSION_LABEL = "1.0.3" +VERSION_LABEL = "1.1.0"