Skip to content

Commit

Permalink
Add code for testing if a VID is a member of a sampling bucket. (#8)
Browse files Browse the repository at this point in the history
* Adds code for testing if a VID is a member of a sampling bucket.

* Updates .bazelversion to work around apple_common error

* Reformats build file

* Adds unit test for checking that hashed VIDs are between 0 and 1

* Updates in response to comments from Sanjay

* Updates in response to comments from Craig Wright.

* Updated in response to comments from Craig Wright and Sanjay Vasandani.

* Changed a stray 52 to 23.

* Updates per comments from Craig Wright
  • Loading branch information
matthewclegg authored Feb 18, 2022
1 parent c8ab0f3 commit a89f9cc
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 0 deletions.
1 change: 1 addition & 0 deletions .bazelversion
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
4.2.2
13 changes: 13 additions & 0 deletions src/main/java/org/wfanet/sampling/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Build definitions for the VID sampling library.
load("@rules_java//java:defs.bzl", "java_library")

# Targets in this package are visible to all other packages unless they set
# their own visibility.
package(default_visibility = ["//visibility:public"])

# Library containing VidSampler. Guava is its only declared dependency; it
# supplies the HashFunction and Hasher types that VidSampler imports.
java_library(
name = "vid_sampler",
srcs = [
"VidSampler.java",
],
deps = [
"@maven//:com_google_guava_guava",
],
)
76 changes: 76 additions & 0 deletions src/main/java/org/wfanet/sampling/VidSampler.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright 2022 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.wfanet.estimation;

import com.google.common.hash.HashFunction;
import com.google.common.hash.Hasher;

/**
* Filter VIDs by sampling bucket.
*
* <p>NOTE: IT IS IMPORTANT THAT THE SAME HASHING FUNCTION BE USED FOR ANY TWO SKETCHES THAT MIGHT
* BE COMBINED BY THE SMPC, INCLUDING THE CASE WHEN THOSE SKETCHES ARE COMPUTED BY DIFFERENT EDPS.
*/
public final class VidSampler {

  // An IEEE 754 single-precision float has 23 explicit mantissa bits, so every integer that fits
  // in this mask converts to float without rounding.
  private static final int IEEE_754_MANTISSA_MASK = 0x7f_ffff;

  // Scale factor that maps a masked 23-bit value onto [0, 1]. Because the divisor is the mask
  // itself (2^23 - 1), a value with all 23 bits set maps to exactly 1.0.
  private static final float MASK_DIVISOR = 1.0f / (float) IEEE_754_MANTISSA_MASK;

  // The hash function that will be used for this VidSampler.
  private final HashFunction hashFunction;

  /**
   * Constructs a new VidSampler.
   *
   * <p>Example usage: {@code vidSampler = new VidSampler(Hashing.farmHashFingerprint64());}
   *
   * @param hashFunction The HashFunction that will be used for hashing VIDs. It is assumed that
   *     the lower 23 bits of the hashed values are approximately uniformly distributed. Farmhash
   *     Fingerprint64 is a reasonable choice because it can be computed efficiently and the hash
   *     value for a given input is guaranteed to be the same across platforms.
   * @throws NullPointerException if {@code hashFunction} is null
   */
  public VidSampler(HashFunction hashFunction) {
    // Fail at construction time instead of on the first hashing call.
    if (hashFunction == null) {
      throw new NullPointerException("hashFunction");
    }
    this.hashFunction = hashFunction;
  }

  /**
   * Hashes a vid to a real number in the interval [0, 1].
   *
   * <p>The VID is fed to a fresh {@link Hasher}, the int view of the resulting hash code is masked
   * to its lower 23 bits, and the masked value is scaled into the unit interval. Both endpoints
   * are attainable: 0.0 when all masked bits are clear and 1.0 when all are set.
   *
   * @param vid The VID to hash.
   * @return The hashed VID as a float in [0, 1].
   */
  public float hashVidToUnitInterval(long vid) {
    Hasher vidHasher = hashFunction.newHasher();
    int maskedHash = vidHasher.putLong(vid).hash().asInt() & IEEE_754_MANTISSA_MASK;
    return MASK_DIVISOR * (float) maskedHash;
  }

  /**
   * Returns true if the hashed VID is in the half-open range from samplingIntervalStart
   * (inclusive) to samplingIntervalStart + samplingIntervalWidth (exclusive), including
   * wrap-around: when the right endpoint exceeds 1, hashed values below (right endpoint - 1) are
   * also considered to be in the bucket.
   *
   * <p>NOTE(review): a hashed value of exactly 1.0 is possible (see {@link
   * #hashVidToUnitInterval}); it is excluded by a bucket whose right endpoint is exactly 1.0 and
   * is not picked up by the wrap-around test either. Confirm that this is acceptable.
   *
   * @param vid The VID that is to be checked.
   * @param samplingIntervalStart The left endpoint of the VID sampling interval.
   * @param samplingIntervalWidth The width of the VID sampling interval.
   * @return True if the hashed VID is in the interval that starts at samplingIntervalStart and has
   *     width samplingIntervalWidth, taking wrap-around into account.
   */
  public boolean vidIsInSamplingBucket(
      long vid, float samplingIntervalStart, float samplingIntervalWidth) {
    float hashedVid = hashVidToUnitInterval(vid);
    final float samplingIntervalEnd = samplingIntervalStart + samplingIntervalWidth;

    // First clause: the hashed VID lies inside [start, end).
    // Second clause: the interval runs past 1 and the hashed VID lies in the wrapped part.
    return ((samplingIntervalStart <= hashedVid && hashedVid < samplingIntervalEnd)
        || (hashedVid < samplingIntervalEnd - 1.0));
  }
}
13 changes: 13 additions & 0 deletions src/test/java/org/wfanet/sampling/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Build definitions for the VID sampling library tests.
load("@rules_java//java:defs.bzl", "java_test")

# Test target for VidSampler.
# test_class is set explicitly: the test declares package org.wfanet.estimation
# while this BUILD file sits under org/wfanet/sampling, which is presumably why
# the class name is not left to be inferred from the path (confirm).
java_test(
name = "VidSamplerTest",
srcs = [
"VidSamplerTest.java",
],
test_class = "org.wfanet.estimation.VidSamplerTest",
deps = [
"//src/main/java/org/wfanet/sampling:vid_sampler",
"@maven//:com_google_truth_truth",
],
)
101 changes: 101 additions & 0 deletions src/test/java/org/wfanet/sampling/VidSamplerTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright 2022 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.wfanet.estimation;

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;

import com.google.common.hash.Hashing;
import java.io.IOException;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Unit tests for {@link VidSampler} using the Farmhash Fingerprint64 hash function. */
@RunWith(JUnit4.class)
public class VidSamplerTest {

  /** Pins the hashed values of a few VIDs so that a change of hash mapping is detected. */
  @Test
  public void testHashVidToUnitIntervalSpecificValues() throws IOException {
    VidSampler vidTester = new VidSampler(Hashing.farmHashFingerprint64());
    assertThat(vidTester.hashVidToUnitInterval(1L)).isWithin(0.001f).of(0.4637f);
    assertThat(vidTester.hashVidToUnitInterval(2L)).isWithin(0.001f).of(0.9473f);
    assertThat(vidTester.hashVidToUnitInterval(3L)).isWithin(0.001f).of(0.5410f);
  }

  /**
   * Tests that if the same value is hashed at different times, exactly the same result is
   * returned, even when other values are hashed in between.
   */
  @Test
  public void testHashVidToUnitIntervalConsistency() throws IOException {
    VidSampler vidTester = new VidSampler(Hashing.farmHashFingerprint64());
    long[] vids = {1L, 2L, 3L};

    // Record the value produced by the first pass over the VIDs.
    float[] firstPass = new float[vids.length];
    for (int i = 0; i < vids.length; i++) {
      firstPass[i] = vidTester.hashVidToUnitInterval(vids[i]);
    }

    // Later passes must reproduce the first pass bit for bit, not merely within a tolerance.
    for (int pass = 1; pass <= 2; pass++) {
      for (int i = 0; i < vids.length; i++) {
        assertWithMessage("pass %s, vid %s", pass, vids[i])
            .that(vidTester.hashVidToUnitInterval(vids[i]))
            .isEqualTo(firstPass[i]);
      }
    }
  }

  /** Tests that values returned by the VID hasher are between 0 and 1. */
  @Test
  public void testHashVidToUnitIntervalValuesInRange() throws IOException {
    VidSampler vidTester = new VidSampler(Hashing.farmHashFingerprint64());
    for (long i = 0L; i < 1000; i++) {
      float vid = vidTester.hashVidToUnitInterval(i);
      assertWithMessage("vid %s hashes to %s", i, vid).that((0.0 <= vid) && (vid <= 1.0)).isTrue();
    }
  }

  /**
   * Tests that when a large number of samples are drawn, the distribution passes the chi-squared
   * goodness of fit test.
   */
  @Test
  public void testHashVidToUnitIntervalChiSquaredDistribution() throws IOException {
    VidSampler vidTester = new VidSampler(Hashing.farmHashFingerprint64());
    final int numSamples = 1000;
    int[] buckets = new int[10];
    for (long i = 0L; i < numSamples; i++) {
      float vid = vidTester.hashVidToUnitInterval(i);
      // A hashed value of exactly 1.0 is inside the documented range [0, 1]; clamp it into the
      // last bucket instead of indexing one past the end of the array.
      int bucketId = Math.min((int) (vid * buckets.length), buckets.length - 1);
      buckets[bucketId]++;
    }

    double chiSquareStatistic = 0.0;
    double expected = (double) numSamples / (double) buckets.length;
    for (int i = 0; i < buckets.length; i++) {
      double error = buckets[i] - expected;
      chiSquareStatistic += error * error / expected;
    }

    // 16.91 is the 95th percentile of the chi-squared distribution with
    // 9 degrees of freedom.
    assertThat(chiSquareStatistic).isLessThan(16.91);
  }

  /** Tests bucket membership for plain intervals and for intervals that wrap around past 1. */
  @Test
  public void testVidIsInSamplingBucket() throws IOException {
    VidSampler vidTester = new VidSampler(Hashing.farmHashFingerprint64());

    // VID 3 hashes to roughly 0.541; the cases below are chosen around that value.
    assertThat(vidTester.hashVidToUnitInterval(3L)).isWithin(0.001f).of(0.5410f);
    assertThat(vidTester.vidIsInSamplingBucket(3L, 0.5f, 0.1f)).isTrue();
    assertThat(vidTester.vidIsInSamplingBucket(3L, 0.5f, 0.01f)).isFalse();
    assertThat(vidTester.vidIsInSamplingBucket(3L, 0.55f, 0.1f)).isFalse();
    // Wrap-around: [0.9, 1.5) covers [0, 0.5) and [0.9, 1.6) covers [0, 0.6) after wrapping.
    assertThat(vidTester.vidIsInSamplingBucket(3L, 0.9f, 0.6f)).isFalse();
    assertThat(vidTester.vidIsInSamplingBucket(3L, 0.9f, 0.7f)).isTrue();
  }
}

0 comments on commit a89f9cc

Please sign in to comment.