Skip to content

Commit

Permalink
update b_bit_minhash benchmark (ekzhu#238)
Browse files Browse the repository at this point in the history
  • Loading branch information
123epsilon committed Mar 30, 2024
1 parent a532f06 commit f8269f6
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 75 deletions.
185 changes: 110 additions & 75 deletions benchmark/sketches/b_bit_minhash_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,112 @@
'''
Benchmarking the performance and accuracy of b-bi MinHash.
'''
import time, logging, random
logging.basicConfig(level=logging.INFO)
import pyhash
import numpy as np
"""
Benchmarking the performance and accuracy of b-bit MinHash.
"""
import time, logging
from numpy import random
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from datasketch.minhash import MinHash
from datasketch.b_bit_minhash import bBitMinHash
from similarity_benchmark import _get_exact, _gen_data,\
Hash, _b_bit_minhash_jaccard

def _run_minhash(A, B, data, seed, bs, num_perm):
(a_start, a_end), (b_start, b_end) = A, B
hasher = pyhash.murmur3_32()
m1 = MinHash(num_perm=num_perm, hashobj=Hash)
m2 = MinHash(num_perm=num_perm, hashobj=Hash)
for i in xrange(a_start, a_end):
m1.update(hasher(data[i], seed=seed))
for i in xrange(b_start, b_end):
m2.update(hasher(data[i], seed=seed))
return [m1.jaccard(m2)] + \
[_b_bit_minhash_jaccard(m1, m2, b) for b in bs]

def _run_test(A, B, data, n, bs, num_perm):
logging.info("Run tests with A = (%d, %d), B = (%d, %d), n = %d"
% (A[0], A[1], B[0], B[1], n))
runs = np.array([_run_minhash(A, B, data, i, bs, num_perm)
for i in xrange(n)]).T
return runs

def run_full_tests(attr_pairs, data, n, bs, num_perm):
return [_run_test(A, B, data, n, bs, num_perm)
for A, B in attr_pairs]

def plot(result, bs, exact_sims, num_perm, bins, save):
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
num_row = 1
num_col = len(result)
basesize = 5
size = (basesize*num_col, basesize*num_row)
fig, axes = plt.subplots(num_row, num_col, sharey=True,
sharex=True, figsize=size)
for i, runs in enumerate(result):
minhash = sorted(runs[0])
bbits = [sorted(r) for r in runs[1:]]
exact_sim = exact_sims[i]
ax = axes[i]
l = ax.plot(minhash, label='MinHash')
for b, run in zip(bs, bbits):
l = ax.plot(run, label='%d-bit' % b)
ax.axhline(exact_sim, color='black', linestyle='--', label='Exact')
ax.set_title("%d perm funcs, exact = %.2f" % (num_perm, exact_sim))
ax.grid()
ax.set_xlabel("Runs with random hash functions")
if i == 0:
ax.set_ylabel('Jaccard')
if i == num_col - 1:
ax.legend(loc='lower right')
fig.savefig(save)


if __name__ == "__main__":
data = _gen_data(5000)
attr_pairs = [((0, 3000), (2000, 5000)),
((0, 3500), (1500, 5000)),
((0, 4500), (500, 5000))]
num_perm = 128
bs = [1, 2, 3]
n = 100
save = "b_bit_minhash_benchmark.png"
bins = [i*0.02 for i in range(51)]
exact_sims = [_get_exact(A, B) for A, B in attr_pairs]
result = run_full_tests(attr_pairs, data, n, bs, num_perm)
plot(result, bs, exact_sims, num_perm, bins, save)
from datasketch.hashfunc import *

logging.basicConfig(level=logging.INFO)

# Produce some bytes
int_bytes = lambda x: ("a-%d-%d" % (x, x)).encode("utf-8")


def run_perf(card, num_perm, num_bits):
dur = 0
n_trials = 5
for i in range(n_trials):
m = MinHash(num_perm=num_perm)
logging.info("MinHash using %d permutation functions" % num_perm)
start = time.perf_counter()
for i in range(card):
m.update(int_bytes(i))

b = bBitMinHash(m, num_bits)
duration = time.perf_counter() - start
dur += duration
logging.info("Digested %d hashes in %.4f sec" % (card, duration))
return dur / n_trials


def _run_acc(size, seed, num_perm, num_bits):
m = MinHash(num_perm=num_perm)
s = set()
random.seed(seed)
for i in range(size):
v = int_bytes(random.randint(1, size))
m.update(v)
s.add(v)

b = bBitMinHash(m, num_bits)
return (b, s)


def run_acc(size, num_perm, num_bits):
logging.info("MinHash using %d permutation functions" % num_perm)
m1, s1 = _run_acc(size, 1, num_perm, num_bits)
m2, s2 = _run_acc(size, 4, num_perm, num_bits)
j = float(len(s1.intersection(s2))) / float(len(s1.union(s2)))
j_e = m1.jaccard(m2)
err = abs(j - j_e)
return err


num_perms = range(10, 256, 20)
num_bits = [1, 2, 3, 4, 8, 12, 16, 32]
bit_colors = colors = [
"#1f77b4",
"#ff7f0e",
"#2ca02c",
"#d62728",
"#9467bd",
"#8c564b",
"#e377c2",
"#7f7f7f",
]
output = "b_bit_minhash_benchmark.png"

logging.info("> Running performance tests")
card = 5000
perf_times = {}
for b in num_bits:
run_times = [run_perf(card, n, b) for n in num_perms]
perf_times[b] = run_times


logging.info("> Running accuracy tests")
size = 5000
errors = {}
for b in num_bits:
errs = [run_acc(size, n, b) for n in num_perms]
errors[b] = errs

logging.info("> Plotting result")
fig, axe = plt.subplots(1, 2, sharex=True, figsize=(10, 4))
ax = axe[1]
for i, b in enumerate(num_bits):
ax.plot(
num_perms, perf_times[b], marker="+", color=bit_colors[i], label=f"{b} bits"
)
ax.set_xlabel("Number of permutation functions")
ax.set_ylabel("Running time (sec)")
ax.set_title("MinHash performance")
ax.grid()
ax.legend()
ax = axe[0]
for i, b in enumerate(num_bits):
ax.plot(num_perms, errors[b], marker="+", color=bit_colors[i], label=f"{b} bits")
ax.set_xlabel("Number of permutation functions")
ax.set_ylabel("Absolute error in Jaccard estimation")
ax.set_title("MinHash accuracy")
ax.grid()
ax.legend()

plt.tight_layout()
fig.savefig(output)
logging.info("Plot saved to %s" % output)
134 changes: 134 additions & 0 deletions benchmark/sketches/b_bit_minhash_benchmark_wikipedia.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Benchmarking the performance and accuracy of b-bit MinHash.
"""
import time, logging
from numpy import random
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from datasketch.minhash import MinHash
from datasketch.b_bit_minhash import bBitMinHash
from datasketch.hashfunc import *
from datasets import load_dataset

logging.basicConfig(level=logging.INFO)

# Produce some bytes
# int_bytes = lambda x: ("a-%d-%d" % (x, x)).encode("utf-8")

wiki_data = load_dataset("wikipedia", "20220301.simple")
N_DOCS = len(wiki_data['train'])
N_TRIALS = 30

def run_perf(card, num_perm, num_bits):
dur = 0
for i in range(N_TRIALS):
doc = wiki_data['train'][random.randint(0, N_DOCS)]['text']
m = MinHash(num_perm=num_perm)
logging.info("MinHash using %d permutation functions" % num_perm)
start = time.perf_counter()
# get real document data, but upper bound the size to evaluate runtime performance
s = set(doc.split()[:card])
for d in s:
m.update(d.encode("utf8"))

b = bBitMinHash(m, num_bits)
duration = time.perf_counter() - start
dur += duration
logging.info("Digested %d hashes in %.4f sec" % (card, duration))
return dur / N_TRIALS


# takes a document string as first argument
def _run_acc(doc, num_perm, num_bits):
m = MinHash(num_perm=num_perm)
s = set(doc.split())
for d in s:
m.update(d.encode("utf8"))
b = bBitMinHash(m, num_bits)
return (b, s)


def run_acc(num_perm, num_bits):
logging.info("MinHash using %d permutation functions" % num_perm)
avg_err = 0
avg_jaccard = 0
for i in range(N_TRIALS):
random.seed(i+1)
i1 = random.randint(0, N_DOCS)

overlap = random.uniform()
doc1 = wiki_data['train'][i1]['text']
# generate a random overlapping region of text given a start point
# this isn't perfect since we could be cutting off the start/ending words in the region
# but that won't affect the estimation too much
overlap_size = int(len(doc1)*overlap)
overlap_start = random.randint(0, len(doc1)-overlap_size)
doc2 = wiki_data['train'][i1]['text'][overlap_start:overlap_start+overlap_size]
m1, s1 = _run_acc(doc1, num_perm, num_bits)
m2, s2 = _run_acc(doc2, num_perm, num_bits)
j = float(len(s1.intersection(s2))) / float(len(s1.union(s2)))
avg_jaccard += j
j_e = m1.jaccard(m2)
err = abs(j - j_e)
logging.info(f"Jaccard Similarity for identical document with {overlap*100:.2f}% overlap: {wiki_data['train'][i1]['title']}= {j} / estimate={j_e}")
avg_err += err
avg_err /= N_TRIALS
avg_jaccard /= N_TRIALS
logging.info(f"Average True Jaccard Sim: {avg_jaccard}")
return avg_err


num_perms = range(10, 512, 20)
num_bits = [1, 2, 3, 4, 8, 12, 16, 32]
bit_colors = colors = [
"#1f77b4",
"#ff7f0e",
"#2ca02c",
"#d62728",
"#9467bd",
"#8c564b",
"#e377c2",
"#7f7f7f",
]
output = "b_bit_minhash_wikisimple_benchmark.png"

logging.info("> Running performance tests")
card = 200
perf_times = {}
for b in num_bits:
run_times = [run_perf(card, n, b) for n in num_perms]
perf_times[b] = run_times


logging.info("> Running accuracy tests")
errors = {}
for b in num_bits:
errs = [run_acc(n, b) for n in num_perms]
errors[b] = errs

logging.info("> Plotting result")
fig, axe = plt.subplots(1, 2, sharex=True, figsize=(10, 4))
ax = axe[1]
for i, b in enumerate(num_bits):
ax.plot(
num_perms, perf_times[b], marker="+", color=bit_colors[i], label=f"{b} bits"
)
ax.set_xlabel("Number of permutation functions")
ax.set_ylabel("Running time (sec)")
ax.set_title("MinHash performance")
ax.grid()
ax.legend()
ax = axe[0]
for i, b in enumerate(num_bits):
ax.plot(num_perms, errors[b], marker="+", color=bit_colors[i], label=f"{b} bits")
ax.set_xlabel("Number of permutation functions")
ax.set_ylabel("Absolute error in Jaccard estimation")
ax.set_title("MinHash accuracy")
ax.grid()
ax.legend()

plt.tight_layout()
fig.savefig(output)
logging.info("Plot saved to %s" % output)

0 comments on commit f8269f6

Please sign in to comment.