-
Notifications
You must be signed in to change notification settings - Fork 297
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a532f06
commit 3f12c7d
Showing
1 changed file
with
110 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,77 +1,112 @@ | ||
''' | ||
Benchmarking the performance and accuracy of b-bi MinHash. | ||
''' | ||
import time, logging, random | ||
logging.basicConfig(level=logging.INFO) | ||
import pyhash | ||
import numpy as np | ||
""" | ||
Benchmarking the performance and accuracy of b-bit MinHash. | ||
""" | ||
import time, logging | ||
from numpy import random | ||
import matplotlib | ||
|
||
matplotlib.use("Agg") | ||
import matplotlib.pyplot as plt | ||
from datasketch.minhash import MinHash | ||
from datasketch.b_bit_minhash import bBitMinHash | ||
from similarity_benchmark import _get_exact, _gen_data,\ | ||
Hash, _b_bit_minhash_jaccard | ||
|
||
def _run_minhash(A, B, data, seed, bs, num_perm): | ||
(a_start, a_end), (b_start, b_end) = A, B | ||
hasher = pyhash.murmur3_32() | ||
m1 = MinHash(num_perm=num_perm, hashobj=Hash) | ||
m2 = MinHash(num_perm=num_perm, hashobj=Hash) | ||
for i in xrange(a_start, a_end): | ||
m1.update(hasher(data[i], seed=seed)) | ||
for i in xrange(b_start, b_end): | ||
m2.update(hasher(data[i], seed=seed)) | ||
return [m1.jaccard(m2)] + \ | ||
[_b_bit_minhash_jaccard(m1, m2, b) for b in bs] | ||
|
||
def _run_test(A, B, data, n, bs, num_perm): | ||
logging.info("Run tests with A = (%d, %d), B = (%d, %d), n = %d" | ||
% (A[0], A[1], B[0], B[1], n)) | ||
runs = np.array([_run_minhash(A, B, data, i, bs, num_perm) | ||
for i in xrange(n)]).T | ||
return runs | ||
|
||
def run_full_tests(attr_pairs, data, n, bs, num_perm): | ||
return [_run_test(A, B, data, n, bs, num_perm) | ||
for A, B in attr_pairs] | ||
|
||
def plot(result, bs, exact_sims, num_perm, bins, save): | ||
import matplotlib | ||
matplotlib.use('Agg') | ||
import matplotlib.pyplot as plt | ||
num_row = 1 | ||
num_col = len(result) | ||
basesize = 5 | ||
size = (basesize*num_col, basesize*num_row) | ||
fig, axes = plt.subplots(num_row, num_col, sharey=True, | ||
sharex=True, figsize=size) | ||
for i, runs in enumerate(result): | ||
minhash = sorted(runs[0]) | ||
bbits = [sorted(r) for r in runs[1:]] | ||
exact_sim = exact_sims[i] | ||
ax = axes[i] | ||
l = ax.plot(minhash, label='MinHash') | ||
for b, run in zip(bs, bbits): | ||
l = ax.plot(run, label='%d-bit' % b) | ||
ax.axhline(exact_sim, color='black', linestyle='--', label='Exact') | ||
ax.set_title("%d perm funcs, exact = %.2f" % (num_perm, exact_sim)) | ||
ax.grid() | ||
ax.set_xlabel("Runs with random hash functions") | ||
if i == 0: | ||
ax.set_ylabel('Jaccard') | ||
if i == num_col - 1: | ||
ax.legend(loc='lower right') | ||
fig.savefig(save) | ||
|
||
|
||
if __name__ == "__main__": | ||
data = _gen_data(5000) | ||
attr_pairs = [((0, 3000), (2000, 5000)), | ||
((0, 3500), (1500, 5000)), | ||
((0, 4500), (500, 5000))] | ||
num_perm = 128 | ||
bs = [1, 2, 3] | ||
n = 100 | ||
save = "b_bit_minhash_benchmark.png" | ||
bins = [i*0.02 for i in range(51)] | ||
exact_sims = [_get_exact(A, B) for A, B in attr_pairs] | ||
result = run_full_tests(attr_pairs, data, n, bs, num_perm) | ||
plot(result, bs, exact_sims, num_perm, bins, save) | ||
from datasketch.hashfunc import * | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
# Produce some bytes | ||
int_bytes = lambda x: ("a-%d-%d" % (x, x)).encode("utf-8") | ||
|
||
|
||
def run_perf(card, num_perm, num_bits): | ||
dur = 0 | ||
n_trials = 5 | ||
for i in range(n_trials): | ||
m = MinHash(num_perm=num_perm) | ||
logging.info("MinHash using %d permutation functions" % num_perm) | ||
start = time.perf_counter() | ||
for i in range(card): | ||
m.update(int_bytes(i)) | ||
|
||
b = bBitMinHash(m, num_bits) | ||
duration = time.perf_counter() - start | ||
dur += duration | ||
logging.info("Digested %d hashes in %.4f sec" % (card, duration)) | ||
return dur / n_trials | ||
|
||
|
||
def _run_acc(size, seed, num_perm, num_bits): | ||
m = MinHash(num_perm=num_perm) | ||
s = set() | ||
random.seed(seed) | ||
for i in range(size): | ||
v = int_bytes(random.randint(1, size)) | ||
m.update(v) | ||
s.add(v) | ||
|
||
b = bBitMinHash(m, num_bits) | ||
return (b, s) | ||
|
||
|
||
def run_acc(size, num_perm, num_bits): | ||
logging.info("MinHash using %d permutation functions" % num_perm) | ||
m1, s1 = _run_acc(size, 1, num_perm, num_bits) | ||
m2, s2 = _run_acc(size, 4, num_perm, num_bits) | ||
j = float(len(s1.intersection(s2))) / float(len(s1.union(s2))) | ||
j_e = m1.jaccard(m2) | ||
err = abs(j - j_e) | ||
return err | ||
|
||
|
||
num_perms = range(10, 256, 20) | ||
num_bits = [1, 2, 3, 4, 8, 12, 16, 32] | ||
bit_colors = colors = [ | ||
"#1f77b4", | ||
"#ff7f0e", | ||
"#2ca02c", | ||
"#d62728", | ||
"#9467bd", | ||
"#8c564b", | ||
"#e377c2", | ||
"#7f7f7f", | ||
] | ||
output = "b_bit_minhash_benchmark.png" | ||
|
||
logging.info("> Running performance tests") | ||
card = 5000 | ||
perf_times = {} | ||
for b in num_bits: | ||
run_times = [run_perf(card, n, b) for n in num_perms] | ||
perf_times[b] = run_times | ||
|
||
|
||
logging.info("> Running accuracy tests") | ||
size = 5000 | ||
errors = {} | ||
for b in num_bits: | ||
errs = [run_acc(size, n, b) for n in num_perms] | ||
errors[b] = errs | ||
|
||
logging.info("> Plotting result") | ||
fig, axe = plt.subplots(1, 2, sharex=True, figsize=(10, 4)) | ||
ax = axe[1] | ||
for i, b in enumerate(num_bits): | ||
ax.plot( | ||
num_perms, perf_times[b], marker="+", color=bit_colors[i], label=f"{b} bits" | ||
) | ||
ax.set_xlabel("Number of permutation functions") | ||
ax.set_ylabel("Running time (sec)") | ||
ax.set_title("MinHash performance") | ||
ax.grid() | ||
ax.legend() | ||
ax = axe[0] | ||
for i, b in enumerate(num_bits): | ||
ax.plot(num_perms, errors[b], marker="+", color=bit_colors[i], label=f"{b} bits") | ||
ax.set_xlabel("Number of permutation functions") | ||
ax.set_ylabel("Absolute error in Jaccard estimation") | ||
ax.set_title("MinHash accuracy") | ||
ax.grid() | ||
ax.legend() | ||
|
||
plt.tight_layout() | ||
fig.savefig(output) | ||
logging.info("Plot saved to %s" % output) |