-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add code for testing if a VID is a member of a sampling bucket. (#8)
* Adds code for testing if a VID is a member of a sampling bucket. * Updates .bazelversion to work around apple_common error * Reformats build file * Adds unit test for checking that hashed VIDs are between 0 and 1 * Updates in response to comments from Sanjay * Updates in response to comments from Craig Wright. * Updated in response to comments from Craig Wright and Sanjay Vasandani. * Changed a stray 52 to 23. * Updates per comments from Craig Wright
- Loading branch information
1 parent
c8ab0f3
commit a89f9cc
Showing
5 changed files
with
204 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
4.2.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
load("@rules_java//java:defs.bzl", "java_library") | ||
|
||
package(default_visibility = ["//visibility:public"]) | ||
|
||
java_library( | ||
name = "vid_sampler", | ||
srcs = [ | ||
"VidSampler.java", | ||
], | ||
deps = [ | ||
"@maven//:com_google_guava_guava", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
// Copyright 2022 The Cross-Media Measurement Authors | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package org.wfanet.estimation; | ||
|
||
import com.google.common.hash.HashFunction; | ||
import com.google.common.hash.Hasher; | ||
|
||
/** | ||
* Filter VIDs by sampling bucket. | ||
* | ||
* <p>NOTE: IT IS IMPORTANT THAT THE SAME HASHING FUNCTION BE USED FOR ANY TWO SKETCHES THAT MIGHT | ||
* BE COMBINED BY THE SMPC, INCLUDING THE CASE WHEN THOSE SKETCHES ARE COMPUTED BY DIFFERENT EDPS. | ||
*/ | ||
public final class VidSampler { | ||
|
||
// The mantissa of an IEEE 754 float is 23 bits. | ||
private static final int Ieee754MantissaMask = 0x7f_ffff; | ||
|
||
// A divisor that is used to convert the lower 23 bits of the | ||
// fingerprinted value to a floating point value between 0 and 1. | ||
private static final float maskDivisor = 1.0f / (float) Ieee754MantissaMask; | ||
|
||
// The hash function that will be used for this VidSampler. | ||
private final HashFunction hashFunction; | ||
|
||
/** | ||
* Constructs a new VidSampler. | ||
* | ||
* <p>Example usage: vidSampler = new VidSampler(Hashing.farmHashFingerprint64); | ||
* | ||
* @param hashFunction: The HashFunction that will be used for hashing VIDs. It is assumed that | ||
* the hash function generates hashed values whose lower 23 bits are approximately uniformly | ||
* distributed in the unit interval. Farmhash Fingerprint64 is a reasonable choice because it | ||
* can be computed efficiently and the hash value for a given input is guaranteed to be the | ||
* same across platforms. | ||
*/ | ||
public VidSampler(HashFunction hashFunction) { | ||
this.hashFunction = hashFunction; | ||
} | ||
|
||
/** Hashes a vid to a real number in the interval [0, 1]. */ | ||
public float hashVidToUnitInterval(long vid) { | ||
Hasher vidHasher = hashFunction.newHasher(); | ||
return maskDivisor * (float) (vidHasher.putLong(vid).hash().asInt() & Ieee754MantissaMask); | ||
} | ||
|
||
/** | ||
* Returns true if the hashed VID is in the range from samplingIntervalStart to | ||
* samplingIntervalStart + samplingIntervalWidth, including wrap-around. | ||
* | ||
* @param vid The VID that is to be checked. | ||
* @param samplingIntervalStart The left endpoint of the VID sampling interval. | ||
* @param samplingIntervalEnd The right endpoint of the VID sampling interval. | ||
* @return True if the hashed VID is in the interval from samplingIntervalStart | ||
*/ | ||
public boolean vidIsInSamplingBucket( | ||
long vid, float samplingIntervalStart, float samplingIntervalWidth) { | ||
float hashedVid = hashVidToUnitInterval(vid); | ||
final float samplingIntervalEnd = samplingIntervalStart + samplingIntervalWidth; | ||
|
||
return ((samplingIntervalStart <= hashedVid && hashedVid < samplingIntervalEnd) | ||
|| (hashedVid < samplingIntervalEnd - 1.0)); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
load("@rules_java//java:defs.bzl", "java_test") | ||
|
||
java_test( | ||
name = "VidSamplerTest", | ||
srcs = [ | ||
"VidSamplerTest.java", | ||
], | ||
test_class = "org.wfanet.estimation.VidSamplerTest", | ||
deps = [ | ||
"//src/main/java/org/wfanet/sampling:vid_sampler", | ||
"@maven//:com_google_truth_truth", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
// Copyright 2022 The Cross-Media Measurement Authors | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package org.wfanet.estimation; | ||
|
||
import static com.google.common.truth.Truth.assertThat; | ||
import static com.google.common.truth.Truth.assertWithMessage; | ||
|
||
import com.google.common.hash.Hashing; | ||
import java.io.IOException; | ||
import org.junit.Test; | ||
import org.junit.runner.RunWith; | ||
import org.junit.runners.JUnit4; | ||
|
||
@RunWith(JUnit4.class) | ||
public class VidSamplerTest { | ||
|
||
@Test | ||
public void testHashVidToUnitIntervalSpecificValues() throws IOException { | ||
VidSampler vidTester = new VidSampler(Hashing.farmHashFingerprint64()); | ||
assertThat(vidTester.hashVidToUnitInterval(1L)).isWithin(0.001f).of(0.4637f); | ||
assertThat(vidTester.hashVidToUnitInterval(2L)).isWithin(0.001f).of(0.9473f); | ||
assertThat(vidTester.hashVidToUnitInterval(3L)).isWithin(0.001f).of(0.5410f); | ||
} | ||
|
||
@Test | ||
public void testHashVidToUnitIntervalConsistency() throws IOException { | ||
// Tests that if the same value is hashed at different times, | ||
// the same result is returned. | ||
VidSampler vidTester = new VidSampler(Hashing.farmHashFingerprint64()); | ||
assertThat(vidTester.hashVidToUnitInterval(1L)).isWithin(0.001f).of(0.4637f); | ||
assertThat(vidTester.hashVidToUnitInterval(2L)).isWithin(0.001f).of(0.9473f); | ||
assertThat(vidTester.hashVidToUnitInterval(3L)).isWithin(0.001f).of(0.5410f); | ||
|
||
assertThat(vidTester.hashVidToUnitInterval(1L)).isWithin(0.001f).of(0.4637f); | ||
assertThat(vidTester.hashVidToUnitInterval(2L)).isWithin(0.001f).of(0.9473f); | ||
assertThat(vidTester.hashVidToUnitInterval(3L)).isWithin(0.001f).of(0.5410f); | ||
|
||
assertThat(vidTester.hashVidToUnitInterval(1L)).isWithin(0.001f).of(0.4637f); | ||
assertThat(vidTester.hashVidToUnitInterval(2L)).isWithin(0.001f).of(0.9473f); | ||
assertThat(vidTester.hashVidToUnitInterval(3L)).isWithin(0.001f).of(0.5410f); | ||
} | ||
|
||
@Test | ||
public void testHashVidToUnitIntervalValuesInRange() throws IOException { | ||
// Tests that values returned by VID hasher are between 0 and 1. | ||
VidSampler vidTester = new VidSampler(Hashing.farmHashFingerprint64()); | ||
for (long i = 0L; i < 1000; i++) { | ||
float vid = vidTester.hashVidToUnitInterval(i); | ||
assertWithMessage("vid %s hashes to %s", i, vid).that((0.0 <= vid) && (vid <= 1.0)).isTrue(); | ||
} | ||
} | ||
|
||
@Test | ||
public void testHashVidToUnitIntervalChiSquaredDistribution() throws IOException { | ||
// Tests that when a large number of samples are drawn, the distribution | ||
// passes the chi-squared goodness of fit test. | ||
VidSampler vidTester = new VidSampler(Hashing.farmHashFingerprint64()); | ||
final int NSAMPLES = 1000; | ||
int[] buckets = new int[10]; | ||
for (long i = 0L; i < NSAMPLES; i++) { | ||
float vid = vidTester.hashVidToUnitInterval(i); | ||
int bucket_id = (int) (vid * buckets.length); | ||
buckets[bucket_id]++; | ||
} | ||
|
||
double chi_square_statistic = 0.0; | ||
double expected = (double) NSAMPLES / (double) buckets.length; | ||
for (int i = 0; i < buckets.length; i++) { | ||
double error = buckets[i] - expected; | ||
chi_square_statistic += error * error / expected; | ||
} | ||
|
||
// 16.91 is the 95th percentile of the chi-squared distribution with | ||
// 9 degrees of freedom. | ||
assertThat(chi_square_statistic).isLessThan(16.91); | ||
} | ||
|
||
@Test | ||
public void testVidIsInSamplingBucket() throws IOException { | ||
VidSampler vidTester = new VidSampler(Hashing.farmHashFingerprint64()); | ||
|
||
assertThat(vidTester.hashVidToUnitInterval(3L)).isWithin(0.001f).of(0.5410f); | ||
assertThat(vidTester.vidIsInSamplingBucket(3L, 0.5f, 0.1f)).isTrue(); | ||
assertThat(vidTester.vidIsInSamplingBucket(3L, 0.5f, 0.01f)).isFalse(); | ||
assertThat(vidTester.vidIsInSamplingBucket(3L, 0.55f, 0.1f)).isFalse(); | ||
assertThat(vidTester.vidIsInSamplingBucket(3L, 0.9f, 0.6f)).isFalse(); | ||
assertThat(vidTester.vidIsInSamplingBucket(3L, 0.9f, 0.7f)).isTrue(); | ||
} | ||
} |