From 3f12c7df0ca4ce06c92c680966882b3b184b66b2 Mon Sep 17 00:00:00 2001 From: Arham Khan Date: Sat, 23 Mar 2024 19:15:05 -0500 Subject: [PATCH] update b_bit_minhash benchmark --- benchmark/sketches/b_bit_minhash_benchmark.py | 185 +++++++++++------- 1 file changed, 110 insertions(+), 75 deletions(-) diff --git a/benchmark/sketches/b_bit_minhash_benchmark.py b/benchmark/sketches/b_bit_minhash_benchmark.py index 08c40f8f..2a98bfbb 100644 --- a/benchmark/sketches/b_bit_minhash_benchmark.py +++ b/benchmark/sketches/b_bit_minhash_benchmark.py @@ -1,77 +1,112 @@ -''' -Benchmarking the performance and accuracy of b-bi MinHash. -''' -import time, logging, random -logging.basicConfig(level=logging.INFO) -import pyhash -import numpy as np +""" +Benchmarking the performance and accuracy of b-bit MinHash. +""" +import time, logging +from numpy import random +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt from datasketch.minhash import MinHash from datasketch.b_bit_minhash import bBitMinHash -from similarity_benchmark import _get_exact, _gen_data,\ - Hash, _b_bit_minhash_jaccard - -def _run_minhash(A, B, data, seed, bs, num_perm): - (a_start, a_end), (b_start, b_end) = A, B - hasher = pyhash.murmur3_32() - m1 = MinHash(num_perm=num_perm, hashobj=Hash) - m2 = MinHash(num_perm=num_perm, hashobj=Hash) - for i in xrange(a_start, a_end): - m1.update(hasher(data[i], seed=seed)) - for i in xrange(b_start, b_end): - m2.update(hasher(data[i], seed=seed)) - return [m1.jaccard(m2)] + \ - [_b_bit_minhash_jaccard(m1, m2, b) for b in bs] - -def _run_test(A, B, data, n, bs, num_perm): - logging.info("Run tests with A = (%d, %d), B = (%d, %d), n = %d" - % (A[0], A[1], B[0], B[1], n)) - runs = np.array([_run_minhash(A, B, data, i, bs, num_perm) - for i in xrange(n)]).T - return runs - -def run_full_tests(attr_pairs, data, n, bs, num_perm): - return [_run_test(A, B, data, n, bs, num_perm) - for A, B in attr_pairs] - -def plot(result, bs, exact_sims, num_perm, bins, save): - import matplotlib - matplotlib.use('Agg') - import matplotlib.pyplot as plt - num_row = 1 - num_col = len(result) - basesize = 5 - size = (basesize*num_col, basesize*num_row) - fig, axes = plt.subplots(num_row, num_col, sharey=True, - sharex=True, figsize=size) - for i, runs in enumerate(result): - minhash = sorted(runs[0]) - bbits = [sorted(r) for r in runs[1:]] - exact_sim = exact_sims[i] - ax = axes[i] - l = ax.plot(minhash, label='MinHash') - for b, run in zip(bs, bbits): - l = ax.plot(run, label='%d-bit' % b) - ax.axhline(exact_sim, color='black', linestyle='--', label='Exact') - ax.set_title("%d perm funcs, exact = %.2f" % (num_perm, exact_sim)) - ax.grid() - ax.set_xlabel("Runs with random hash functions") - if i == 0: - ax.set_ylabel('Jaccard') - if i == num_col - 1: - ax.legend(loc='lower right') - fig.savefig(save) - - -if __name__ == "__main__": - data = _gen_data(5000) - attr_pairs = [((0, 3000), (2000, 5000)), - ((0, 3500), (1500, 5000)), - ((0, 4500), (500, 5000))] - num_perm = 128 - bs = [1, 2, 3] - n = 100 - save = "b_bit_minhash_benchmark.png" - bins = [i*0.02 for i in range(51)] - exact_sims = [_get_exact(A, B) for A, B in attr_pairs] - result = run_full_tests(attr_pairs, data, n, bs, num_perm) - plot(result, bs, exact_sims, num_perm, bins, save) +from datasketch.hashfunc import * + +logging.basicConfig(level=logging.INFO) + +# Produce some bytes +int_bytes = lambda x: ("a-%d-%d" % (x, x)).encode("utf-8") + + +def run_perf(card, num_perm, num_bits): + dur = 0 + n_trials = 5 + for i in range(n_trials): + m = MinHash(num_perm=num_perm) + logging.info("MinHash using %d permutation functions" % num_perm) + start = time.perf_counter() + for i in range(card): + m.update(int_bytes(i)) + + b = bBitMinHash(m, num_bits) + duration = time.perf_counter() - start + dur += duration + logging.info("Digested %d hashes in %.4f sec" % (card, duration)) + return dur / n_trials + + +def _run_acc(size, seed, num_perm, num_bits): + m = MinHash(num_perm=num_perm) + s = set() + random.seed(seed) + for i in range(size): + v = int_bytes(random.randint(1, size)) + m.update(v) + s.add(v) + + b = bBitMinHash(m, num_bits) + return (b, s) + + +def run_acc(size, num_perm, num_bits): + logging.info("MinHash using %d permutation functions" % num_perm) + m1, s1 = _run_acc(size, 1, num_perm, num_bits) + m2, s2 = _run_acc(size, 4, num_perm, num_bits) + j = float(len(s1.intersection(s2))) / float(len(s1.union(s2))) + j_e = m1.jaccard(m2) + err = abs(j - j_e) + return err + + +num_perms = range(10, 256, 20) +num_bits = [1, 2, 3, 4, 8, 12, 16, 32] +bit_colors = colors = [ + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", +] +output = "b_bit_minhash_benchmark.png" + +logging.info("> Running performance tests") +card = 5000 +perf_times = {} +for b in num_bits: + run_times = [run_perf(card, n, b) for n in num_perms] + perf_times[b] = run_times + + +logging.info("> Running accuracy tests") +size = 5000 +errors = {} +for b in num_bits: + errs = [run_acc(size, n, b) for n in num_perms] + errors[b] = errs + +logging.info("> Plotting result") +fig, axe = plt.subplots(1, 2, sharex=True, figsize=(10, 4)) +ax = axe[1] +for i, b in enumerate(num_bits): + ax.plot( + num_perms, perf_times[b], marker="+", color=bit_colors[i], label=f"{b} bits" + ) +ax.set_xlabel("Number of permutation functions") +ax.set_ylabel("Running time (sec)") +ax.set_title("MinHash performance") +ax.grid() +ax.legend() +ax = axe[0] +for i, b in enumerate(num_bits): + ax.plot(num_perms, errors[b], marker="+", color=bit_colors[i], label=f"{b} bits") +ax.set_xlabel("Number of permutation functions") +ax.set_ylabel("Absolute error in Jaccard estimation") +ax.set_title("MinHash accuracy") +ax.grid() +ax.legend() + +plt.tight_layout() +fig.savefig(output) +logging.info("Plot saved to %s" % output)