forked from ekzhu/datasketch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
update b_bit_minhash benchmark (ekzhu#238)
- Loading branch information
1 parent
a532f06
commit f8269f6
Showing
2 changed files
with
244 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,77 +1,112 @@ | ||
''' | ||
Benchmarking the performance and accuracy of b-bi MinHash. | ||
''' | ||
import time, logging, random | ||
logging.basicConfig(level=logging.INFO) | ||
import pyhash | ||
import numpy as np | ||
""" | ||
Benchmarking the performance and accuracy of b-bit MinHash. | ||
""" | ||
import time, logging | ||
from numpy import random | ||
import matplotlib | ||
|
||
matplotlib.use("Agg") | ||
import matplotlib.pyplot as plt | ||
from datasketch.minhash import MinHash | ||
from datasketch.b_bit_minhash import bBitMinHash | ||
from similarity_benchmark import _get_exact, _gen_data,\ | ||
Hash, _b_bit_minhash_jaccard | ||
|
||
def _run_minhash(A, B, data, seed, bs, num_perm):
    """Estimate the Jaccard similarity of two slices of ``data`` for one seed.

    Args:
        A: ``(start, end)`` index pair selecting the first slice of ``data``
            (``end`` is excluded).
        B: ``(start, end)`` index pair selecting the second slice.
        data: Indexable collection whose items are fed to the murmur3 hasher.
        seed: Seed handed to the hasher for every item in this run.
        bs: Iterable of b-bit widths to evaluate.
        num_perm: Number of permutation functions for each MinHash.

    Returns:
        A list whose first element is the plain MinHash estimate, followed by
        one ``_b_bit_minhash_jaccard`` estimate per entry of ``bs``, in order.
    """
    (a_start, a_end), (b_start, b_end) = A, B
    hasher = pyhash.murmur3_32()
    m1 = MinHash(num_perm=num_perm, hashobj=Hash)
    m2 = MinHash(num_perm=num_perm, hashobj=Hash)
    # range() rather than xrange(): xrange does not exist on Python 3.
    for i in range(a_start, a_end):
        m1.update(hasher(data[i], seed=seed))
    for i in range(b_start, b_end):
        m2.update(hasher(data[i], seed=seed))
    return [m1.jaccard(m2)] + \
        [_b_bit_minhash_jaccard(m1, m2, b) for b in bs]
|
||
def _run_test(A, B, data, n, bs, num_perm):
    """Repeat ``_run_minhash`` ``n`` times, using seeds ``0 .. n-1``.

    Args:
        A: ``(start, end)`` index pair for the first slice of ``data``.
        B: ``(start, end)`` index pair for the second slice of ``data``.
        data: Items forwarded unchanged to ``_run_minhash``.
        n: Number of runs; run ``k`` uses seed ``k``.
        bs: Iterable of b-bit widths to evaluate.
        num_perm: Number of permutation functions for each MinHash.

    Returns:
        The per-run result lists stacked into an array and transposed, so
        each row holds one estimator's values and each column is one run.
    """
    logging.info("Run tests with A = (%d, %d), B = (%d, %d), n = %d"
                 % (A[0], A[1], B[0], B[1], n))
    # range() rather than xrange(): xrange does not exist on Python 3.
    per_run = [_run_minhash(A, B, data, seed, bs, num_perm)
               for seed in range(n)]
    return np.array(per_run).T
|
||
def run_full_tests(attr_pairs, data, n, bs, num_perm):
    """Run ``_run_test`` for every ``(A, B)`` pair and return the results in order."""
    results = []
    for first, second in attr_pairs:
        results.append(_run_test(first, second, data, n, bs, num_perm))
    return results
|
||
def plot(result, bs, exact_sims, num_perm, bins, save):
    """Draw one panel per test case comparing MinHash and b-bit estimates.

    Each panel plots the sorted estimates of every estimator across runs and
    marks the exact similarity with a dashed horizontal line.

    Args:
        result: Sequence of per-test arrays; row 0 of each is taken as the
            plain MinHash estimates and the remaining rows are paired with
            ``bs`` in order.
        bs: b-bit widths, used for the legend labels.
        exact_sims: Exact similarity for each entry of ``result``.
        num_perm: Number of permutation functions, shown in panel titles.
        bins: Unused; kept so existing callers do not break.
        save: Path the figure is written to.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    num_row = 1
    num_col = len(result)
    basesize = 5
    size = (basesize*num_col, basesize*num_row)
    # squeeze=False keeps `axes` two-dimensional; without it a single-panel
    # figure yields a bare Axes object and indexing it fails.
    fig, axes = plt.subplots(num_row, num_col, sharey=True,
                             sharex=True, figsize=size, squeeze=False)
    for i, runs in enumerate(result):
        minhash = sorted(runs[0])
        bbits = [sorted(r) for r in runs[1:]]
        exact_sim = exact_sims[i]
        ax = axes[0][i]
        ax.plot(minhash, label='MinHash')
        for b, run in zip(bs, bbits):
            ax.plot(run, label='%d-bit' % b)
        ax.axhline(exact_sim, color='black', linestyle='--', label='Exact')
        ax.set_title("%d perm funcs, exact = %.2f" % (num_perm, exact_sim))
        ax.grid()
        ax.set_xlabel("Runs with random hash functions")
        # Shared y axis: label it once on the left, legend once on the right.
        if i == 0:
            ax.set_ylabel('Jaccard')
        if i == num_col - 1:
            ax.legend(loc='lower right')
    fig.savefig(save)
|
||
|
||
if __name__ == "__main__":
    # Benchmark parameters.
    num_perm = 128
    bs = [1, 2, 3]
    n = 100
    save = "b_bit_minhash_benchmark.png"
    bins = [i*0.02 for i in range(51)]
    # Three pairs of index ranges overlapping by 1000, 2000 and 4000 positions.
    attr_pairs = [
        ((0, 3000), (2000, 5000)),
        ((0, 3500), (1500, 5000)),
        ((0, 4500), (500, 5000)),
    ]

    data = _gen_data(5000)
    exact_sims = [_get_exact(first, second) for first, second in attr_pairs]
    result = run_full_tests(attr_pairs, data, n, bs, num_perm)
    plot(result, bs, exact_sims, num_perm, bins, save)
from datasketch.hashfunc import * | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
# Produce some bytes
def int_bytes(x):
    """Return the UTF-8 encoding of ``"a-<x>-<x>"`` for the integer ``x``."""
    return ("a-%d-%d" % (x, x)).encode("utf-8")
|
||
|
||
def run_perf(card, num_perm, num_bits):
    """Return the mean time to digest ``card`` values into a b-bit MinHash.

    Each of five trials times ``card`` MinHash updates plus the construction
    of the ``bBitMinHash`` wrapper.

    Args:
        card: Number of values ``int_bytes(0) .. int_bytes(card - 1)`` to hash.
        num_perm: Number of permutation functions for the MinHash.
        num_bits: Bit width passed to ``bBitMinHash``.

    Returns:
        Average duration of one trial, in seconds.
    """
    n_trials = 5
    total = 0
    for _ in range(n_trials):
        sketch = MinHash(num_perm=num_perm)
        logging.info("MinHash using %d permutation functions" % num_perm)
        began = time.perf_counter()
        for value in range(card):
            sketch.update(int_bytes(value))

        # Built only so its cost is included in the measurement.
        compact = bBitMinHash(sketch, num_bits)
        elapsed = time.perf_counter() - began
        total += elapsed
        logging.info("Digested %d hashes in %.4f sec" % (card, elapsed))
    return total / n_trials
|
||
|
||
def _run_acc(size, seed, num_perm, num_bits):
    """Build a b-bit MinHash and the exact set for ``size`` random values.

    Args:
        size: Number of draws; each draw is ``random.randint(1, size)``.
        seed: Seed applied to the global numpy RNG before drawing.
        num_perm: Number of permutation functions for the MinHash.
        num_bits: Bit width passed to ``bBitMinHash``.

    Returns:
        ``(sketch, values)`` where ``sketch`` is the ``bBitMinHash`` and
        ``values`` is the set of distinct byte strings that were hashed.
    """
    sketch = MinHash(num_perm=num_perm)
    seen = set()
    random.seed(seed)
    for _ in range(size):
        value = int_bytes(random.randint(1, size))
        sketch.update(value)
        seen.add(value)

    return bBitMinHash(sketch, num_bits), seen
|
||
|
||
def run_acc(size, num_perm, num_bits):
    """Return the absolute error of the b-bit Jaccard estimate.

    Two sketches are built with ``_run_acc`` using seeds 1 and 4; the
    estimate from the sketches is compared with the exact Jaccard similarity
    of the underlying sets.
    """
    logging.info("MinHash using %d permutation functions" % num_perm)
    sketch_a, set_a = _run_acc(size, 1, num_perm, num_bits)
    sketch_b, set_b = _run_acc(size, 4, num_perm, num_bits)
    common = set_a.intersection(set_b)
    combined = set_a.union(set_b)
    exact = float(len(common)) / float(len(combined))
    estimate = sketch_a.jaccard(sketch_b)
    return abs(exact - estimate)
|
||
|
||
# Sweep configuration: permutation counts on the x axis, one curve per bit width.
num_perms = range(10, 256, 20)
num_bits = [1, 2, 3, 4, 8, 12, 16, 32]
# One colour per entry of num_bits.
bit_colors = colors = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
]
output = "b_bit_minhash_benchmark.png"

logging.info("> Running performance tests")
card = 5000
perf_times = {
    b: [run_perf(card, n, b) for n in num_perms]
    for b in num_bits
}

logging.info("> Running accuracy tests")
size = 5000
errors = {
    b: [run_acc(size, n, b) for n in num_perms]
    for b in num_bits
}

logging.info("> Plotting result")
fig, axe = plt.subplots(1, 2, sharex=True, figsize=(10, 4))
acc_ax, perf_ax = axe

# Right panel: running time.
for color, b in zip(bit_colors, num_bits):
    perf_ax.plot(
        num_perms, perf_times[b], marker="+", color=color, label=f"{b} bits"
    )
perf_ax.set_xlabel("Number of permutation functions")
perf_ax.set_ylabel("Running time (sec)")
perf_ax.set_title("MinHash performance")
perf_ax.grid()
perf_ax.legend()

# Left panel: estimation error.
for color, b in zip(bit_colors, num_bits):
    acc_ax.plot(num_perms, errors[b], marker="+", color=color, label=f"{b} bits")
acc_ax.set_xlabel("Number of permutation functions")
acc_ax.set_ylabel("Absolute error in Jaccard estimation")
acc_ax.set_title("MinHash accuracy")
acc_ax.grid()
acc_ax.legend()

plt.tight_layout()
fig.savefig(output)
logging.info("Plot saved to %s" % output)
134 changes: 134 additions & 0 deletions
134
benchmark/sketches/b_bit_minhash_benchmark_wikipedia.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
""" | ||
Benchmarking the performance and accuracy of b-bit MinHash. | ||
""" | ||
import time, logging | ||
from numpy import random | ||
import matplotlib | ||
|
||
matplotlib.use("Agg") | ||
import matplotlib.pyplot as plt | ||
from datasketch.minhash import MinHash | ||
from datasketch.b_bit_minhash import bBitMinHash | ||
from datasketch.hashfunc import * | ||
from datasets import load_dataset | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
# Produce some bytes | ||
# int_bytes = lambda x: ("a-%d-%d" % (x, x)).encode("utf-8") | ||
|
||
# Load the "20220301.simple" configuration of the "wikipedia" dataset once;
# the benchmark functions below read it as a module-level global.
wiki_data = load_dataset("wikipedia", "20220301.simple")
# Number of trials averaged per benchmark configuration.
N_TRIALS = 30
# Number of records in the "train" split.
N_DOCS = len(wiki_data["train"])
|
||
def run_perf(card, num_perm, num_bits):
    """Return the mean time to sketch up to ``card`` words of a random document.

    Each of ``N_TRIALS`` trials picks a random record from
    ``wiki_data['train']`` and times three steps together: building the set
    of distinct words among the first ``card`` whitespace-separated tokens,
    updating a MinHash with each word, and constructing the ``bBitMinHash``.

    Args:
        card: Upper bound on the number of tokens taken from the document.
        num_perm: Number of permutation functions for the MinHash.
        num_bits: Bit width passed to ``bBitMinHash``.

    Returns:
        Average duration of one trial, in seconds.
    """
    dur = 0
    for _ in range(N_TRIALS):
        doc = wiki_data['train'][random.randint(0, N_DOCS)]['text']
        m = MinHash(num_perm=num_perm)
        logging.info("MinHash using %d permutation functions" % num_perm)
        start = time.perf_counter()
        # get real document data, but upper bound the size to evaluate runtime performance
        s = set(doc.split()[:card])
        for d in s:
            m.update(d.encode("utf8"))

        b = bBitMinHash(m, num_bits)
        duration = time.perf_counter() - start
        dur += duration
        # Report what was really hashed: duplicates are removed and short
        # documents have fewer than `card` tokens, so this is at most `card`.
        logging.info("Digested %d hashes in %.4f sec" % (len(s), duration))
    return dur / N_TRIALS
|
||
|
||
# takes a document string as first argument
def _run_acc(doc, num_perm, num_bits):
    """Sketch the distinct whitespace-separated words of ``doc``.

    Returns:
        ``(sketch, words)`` where ``sketch`` is the ``bBitMinHash`` built
        from a MinHash of the UTF-8 encoded words and ``words`` is the set
        of distinct words.
    """
    sketch = MinHash(num_perm=num_perm)
    words = set(doc.split())
    for word in words:
        sketch.update(word.encode("utf8"))
    return bBitMinHash(sketch, num_bits), words
|
||
|
||
def run_acc(num_perm, num_bits):
    """Return the mean absolute error of the b-bit Jaccard estimate.

    For each of ``N_TRIALS`` trials a random document is compared with a
    random contiguous excerpt of itself. The exact Jaccard similarity of
    their word sets is compared with the estimate from the two sketches.

    Args:
        num_perm: Number of permutation functions for each MinHash.
        num_bits: Bit width passed to ``bBitMinHash``.

    Returns:
        Absolute estimation error averaged over all trials.
    """
    logging.info("MinHash using %d permutation functions" % num_perm)
    avg_err = 0
    avg_jaccard = 0
    for i in range(N_TRIALS):
        # Seeding from the trial index makes every (num_perm, num_bits)
        # setting draw the same documents and excerpts.
        random.seed(i+1)
        i1 = random.randint(0, N_DOCS)

        overlap = random.uniform()
        # Fetch the record once; text and title are both read from it.
        record = wiki_data['train'][i1]
        doc1 = record['text']
        # generate a random overlapping region of text given a start point
        # this isn't perfect since we could be cutting off the start/ending words in the region
        # but that won't affect the estimation too much
        overlap_size = int(len(doc1)*overlap)
        # numpy's randint excludes its upper bound; the +1 lets the excerpt
        # end at the last character and keeps the range non-empty when the
        # document is empty.
        overlap_start = random.randint(0, len(doc1)-overlap_size+1)
        doc2 = doc1[overlap_start:overlap_start+overlap_size]
        m1, s1 = _run_acc(doc1, num_perm, num_bits)
        m2, s2 = _run_acc(doc2, num_perm, num_bits)
        union = s1.union(s2)
        if union:
            j = float(len(s1.intersection(s2))) / float(len(union))
        else:
            # Both word sets are empty, hence identical; avoid dividing 0 by 0.
            j = 1.0
        avg_jaccard += j
        j_e = m1.jaccard(m2)
        err = abs(j - j_e)
        logging.info(f"Jaccard Similarity for identical document with {overlap*100:.2f}% overlap: {record['title']}= {j} / estimate={j_e}")
        avg_err += err
    avg_err /= N_TRIALS
    avg_jaccard /= N_TRIALS
    logging.info(f"Average True Jaccard Sim: {avg_jaccard}")
    return avg_err
|
||
|
||
# Sweep configuration: permutation counts on the x axis, one curve per bit width.
num_perms = range(10, 512, 20)
num_bits = [1, 2, 3, 4, 8, 12, 16, 32]
# One colour per entry of num_bits.
bit_colors = colors = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
]
output = "b_bit_minhash_wikisimple_benchmark.png"

logging.info("> Running performance tests")
card = 200
perf_times = {
    b: [run_perf(card, n, b) for n in num_perms]
    for b in num_bits
}

logging.info("> Running accuracy tests")
errors = {
    b: [run_acc(n, b) for n in num_perms]
    for b in num_bits
}

logging.info("> Plotting result")
fig, axe = plt.subplots(1, 2, sharex=True, figsize=(10, 4))
acc_ax, perf_ax = axe

# Right panel: running time.
for color, b in zip(bit_colors, num_bits):
    perf_ax.plot(
        num_perms, perf_times[b], marker="+", color=color, label=f"{b} bits"
    )
perf_ax.set_xlabel("Number of permutation functions")
perf_ax.set_ylabel("Running time (sec)")
perf_ax.set_title("MinHash performance")
perf_ax.grid()
perf_ax.legend()

# Left panel: estimation error.
for color, b in zip(bit_colors, num_bits):
    acc_ax.plot(num_perms, errors[b], marker="+", color=color, label=f"{b} bits")
acc_ax.set_xlabel("Number of permutation functions")
acc_ax.set_ylabel("Absolute error in Jaccard estimation")
acc_ax.set_title("MinHash accuracy")
acc_ax.grid()
acc_ax.legend()

plt.tight_layout()
fig.savefig(output)
logging.info("Plot saved to %s" % output)