LSH Ensemble Optimal Partitioning (#79)
Use optimal partitioning instead of equi-depth partitioning so the index can adapt to any set size distribution.
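For context, this is the datasketch API the benchmark below exercises; a minimal sketch with toy data (the set contents and parameter values here are illustrative only, not from the benchmark):

```python
from datasketch import MinHash, MinHashLSHEnsemble

sets = {"s1": {"a", "b", "c"}, "s2": {"b", "c", "d", "e"}}
minhashes = {}
for key, s in sets.items():
    m = MinHash(num_perm=128)
    for d in s:
        m.update(d.encode("utf8"))
    minhashes[key] = m

# num_part controls how the index partitions sets by size -- the subject
# of this change; the threshold is on containment, not Jaccard.
lsh = MinHashLSHEnsemble(threshold=0.5, num_perm=128, num_part=16)
lsh.index([(key, minhashes[key], len(sets[key])) for key in sets])

query = {"b", "c"}
mq = MinHash(num_perm=128)
for d in query:
    mq.update(d.encode("utf8"))
print(list(lsh.query(mq, len(query))))  # candidate keys, e.g. ["s1", "s2"]
```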
Showing 47 changed files with 975 additions and 577 deletions.
The benchmark script, after the change:

"""
Benchmark dataset from:
https://github.com/ekzhu/set-similarity-search-benchmark.

Use "Canada US and UK Open Data":
    Indexed sets: canada_us_uk_opendata.inp.gz
    Query sets (10 stratified samples from 10 percentile intervals):
        Size from 10 - 1k: canada_us_uk_opendata_queries_1k.inp.gz
        Size from 10 - 10k: canada_us_uk_opendata_queries_10k.inp.gz
        Size from 10 - 100k: canada_us_uk_opendata_queries_100k.inp.gz
"""
import time, argparse, sys, json
import collections
import gzip
import os
import pickle
import random

import numpy as np
import pandas as pd
from SetSimilaritySearch import SearchIndex

from datasketch import MinHashLSHEnsemble, MinHash

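The exact ground truth comes from SetSimilaritySearch's SearchIndex. A minimal sketch of how that index behaves (toy sets, not benchmark data; the exact result list shown in the comment is an assumption):

```python
from SetSimilaritySearch import SearchIndex

# Index two small sets; query() returns (set_index, similarity) pairs.
index = SearchIndex([[1, 2, 3], [2, 3, 4, 5]],
                    similarity_func_name="containment",
                    similarity_threshold=0.5)
# containment(Q, X) = |Q intersect X| / |Q|; both sets fully contain {2, 3}.
print(index.query([2, 3]))  # e.g. [(0, 1.0), (1, 1.0)]
```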
def bootstrap_sets(sets_file, sample_ratio, num_perms, skip=1):
    print("Creating sets...")
    sets = collections.deque([])
    random.seed(41)
    with gzip.open(sets_file, "rt") as f:
        for i, line in enumerate(f):
            if i < skip:
                # Skip the first `skip` lines.
                continue
            if random.random() > sample_ratio:
                continue
            s = np.array([int(d) for d in
                    line.strip().split("\t")[1].split(",")])
            sets.append(s)
            sys.stdout.write("\rRead {} sets".format(len(sets)))
        sys.stdout.write("\n")
    sets = list(sets)
    keys = list(range(len(sets)))
    print("Creating MinHash...")
    minhashes = dict()
    for num_perm in num_perms:
        print("Using num_perm = {}".format(num_perm))
        ms = []
        for s in sets:
            m = MinHash(num_perm)
            for word in s:
                m.update(str(word).encode("utf8"))
            ms.append(m)
            sys.stdout.write("\rMinhashed {} sets".format(len(ms)))
        sys.stdout.write("\n")
        minhashes[num_perm] = ms
    return (minhashes, sets, keys)

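Each input line pairs a set size with a tab-separated, comma-delimited element list, as described in the argparse help below. A quick sketch of how one line parses (made-up values):

```python
line = "4\t11,42,7,19\n"  # "<set_size>\t<e1>,<e2>,..." format
size_field, elements_field = line.strip().split("\t")
s = [int(d) for d in elements_field.split(",")]
assert int(size_field) == len(s)  # 4 elements: [11, 42, 7, 19]
```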
def benchmark_lshensemble(threshold, num_perm, num_part, m, index_data,
        query_data):
    print("Building LSH Ensemble index")
    (minhashes, indexed_sets, keys) = index_data
    lsh = MinHashLSHEnsemble(threshold=threshold, num_perm=num_perm,
            num_part=num_part, m=m)
    lsh.index((key, minhash, len(s))
            for key, minhash, s in
                    zip(keys, minhashes[num_perm], indexed_sets))
    print("Querying")
    (minhashes, query_sets, keys) = query_data
    times = []
    results = []
    for qs, minhash in zip(query_sets, minhashes[num_perm]):
        start = time.perf_counter()
        result = list(lsh.query(minhash, len(qs)))
        duration = time.perf_counter() - start
        times.append(duration)
        results.append(result)
        # results.append(sorted([[key, _compute_containment(qs, indexed_sets[key])]
        #         for key in result],
        #         key=lambda x: x[1], reverse=True))
        sys.stdout.write("\rQueried {} sets".format(len(results)))
    sys.stdout.write("\n")
    return times, results

def benchmark_ground_truth(threshold, index, query_data):
    (_, query_sets, _) = query_data
    times = []
    results = []
    for q in query_sets:
        start = time.perf_counter()
        # The exact index returns (key, similarity) pairs; keep only keys.
        result = [key for key, _ in index.query(q)]
        duration = time.perf_counter() - start
        times.append(duration)
        results.append(result)
        sys.stdout.write("\rQueried {} sets".format(len(results)))
    sys.stdout.write("\n")
    return times, results

def _compute_containment(x, y):
    if len(x) == 0 or len(y) == 0:
        return 0.0
    intersection = len(np.intersect1d(x, y, assume_unique=True))
    return float(intersection) / float(len(x))

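Containment is asymmetric: it is the fraction of the query set found in the candidate. A worked example (toy arrays):

```python
import numpy as np

x = np.array([1, 2, 3, 4])        # query set
y = np.array([2, 3, 4, 5, 6, 7])  # candidate set
# |x intersect y| = |{2, 3, 4}| = 3, so containment(x, y) = 3/4 = 0.75,
# while containment(y, x) = 3/6 = 0.5 -- the direction matters.
print(len(np.intersect1d(x, y, assume_unique=True)) / len(x))  # 0.75
```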
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run LSH Ensemble benchmark using data sets obtained "
            "from https://github.com/ekzhu/set-similarity-search-benchmark.")
    parser.add_argument("--indexed-sets", type=str, required=True,
            help="Input indexed set file (gzipped), each line is a set: "
                "<set_size> <1>,<2>,<3>..., where each <?> is an element.")
    parser.add_argument("--query-sets", type=str, required=True,
            help="Input query set file (gzipped), each line is a set: "
                "<set_size> <1>,<2>,<3>..., where each <?> is an element.")
    parser.add_argument("--query-results", type=str,
            default="lshensemble_benchmark_query_results.csv")
    parser.add_argument("--ground-truth-results", type=str,
            default="lshensemble_benchmark_ground_truth_results.csv")
    parser.add_argument("--skip-ground-truth", action="store_true")
    args = parser.parse_args(sys.argv[1:])

    thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    num_parts = [8, 16, 32]
    #num_perms = [32, 64, 96, 128, 160, 192, 224, 256]
    num_perms = [256,]
    m = 8

    index_data, query_data = None, None
    index_data_cache = "{}.pickle".format(args.indexed_sets)
    query_data_cache = "{}.pickle".format(args.query_sets)
    if os.path.exists(index_data_cache):
        print("Using cached indexed sets {}".format(index_data_cache))
        with open(index_data_cache, "rb") as d:
            index_data = pickle.load(d)
    else:
        print("Using indexed sets {}".format(args.indexed_sets))
        index_data = bootstrap_sets(args.indexed_sets, 0.1, num_perms)
        with open(index_data_cache, "wb") as d:
            pickle.dump(index_data, d)
    if os.path.exists(query_data_cache):
        print("Using cached query sets {}".format(query_data_cache))
        with open(query_data_cache, "rb") as d:
            query_data = pickle.load(d)
    else:
        print("Using query sets {}".format(args.query_sets))
        query_data = bootstrap_sets(args.query_sets, 1.0, num_perms, skip=0)
        with open(query_data_cache, "wb") as d:
            pickle.dump(query_data, d)

    if not args.skip_ground_truth:
        rows = []
        # Build search index separately, only works for containment.
        print("Building search index...")
        index = SearchIndex(index_data[1], similarity_func_name="containment",
                similarity_threshold=0.1)
        for threshold in thresholds:
            index.similarity_threshold = threshold
            print("Running ground truth benchmark threshold = {}".format(threshold))
            ground_truth_times, ground_truth_results = \
                    benchmark_ground_truth(threshold, index, query_data)
            for t, r, query_set, query_key in zip(ground_truth_times,
                    ground_truth_results, query_data[1], query_data[2]):
                rows.append((query_key, len(query_set), threshold, t,
                    ",".join(str(k) for k in r)))
        df_groundtruth = pd.DataFrame.from_records(rows,
                columns=["query_key", "query_size", "threshold",
                    "query_time", "results"])
        df_groundtruth.to_csv(args.ground_truth_results)

    rows = []
    for threshold in thresholds:
        for num_part in num_parts:
            for num_perm in num_perms:
                print("Running LSH Ensemble benchmark "
                        "threshold = {}, num_part = {}, num_perm = {}".format(
                            threshold, num_part, num_perm))
                lsh_times, lsh_results = benchmark_lshensemble(
                        threshold, num_perm, num_part, m, index_data, query_data)
                for t, r, query_set, query_key in zip(lsh_times, lsh_results,
                        query_data[1], query_data[2]):
                    rows.append((query_key, len(query_set), threshold,
                        num_part, num_perm, t, ",".join(str(k) for k in r)))
    df = pd.DataFrame.from_records(rows,
            columns=["query_key", "query_size", "threshold", "num_part",
                "num_perm", "query_time", "results"])
    df.to_csv(args.query_results)
The plotting script, after the change:

import json, sys, argparse
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from lsh_benchmark_plot import get_precision_recall, fscore, average_fscore

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("query_results")
    parser.add_argument("ground_truth_results")
    args = parser.parse_args(sys.argv[1:])
    df = pd.read_csv(args.query_results,
            converters={"results": lambda x: x.split(",")})
    df_groundtruth = pd.read_csv(args.ground_truth_results,
            converters={"results": lambda x: x.split(",")})
    df_groundtruth["has_result"] = [len(r) > 0
            for r in df_groundtruth["results"]]
    df_groundtruth = df_groundtruth[df_groundtruth["has_result"]]
    df = pd.merge(df, df_groundtruth, on=["query_key", "threshold"],
            suffixes=("", "_ground_truth"))
    prs = [get_precision_recall(result, ground_truth)
            for result, ground_truth in
                    zip(df["results"], df["results_ground_truth"])]
    df["precision"] = [p for p, _ in prs]
    df["recall"] = [r for _, r in prs]
    df["fscore"] = [fscore(*pr) for pr in prs]

    thresholds = sorted(list(set(df["threshold"])))
    num_perms = sorted(list(set(df["num_perm"])))
    num_parts = sorted(list(set(df["num_part"])))

    for num_perm in num_perms:
        # Plot precisions.
        for num_part in num_parts:
            sub = df[(df["num_part"] == num_part) &
                    (df["num_perm"] == num_perm)].groupby("threshold")
            precisions = sub["precision"].mean()
            plt.plot(thresholds, precisions, label="num_part = {}".format(num_part))
        plt.ylim(0.0, 1.0)
        plt.xlabel("Thresholds")
        plt.ylabel("Average Precisions")
        plt.grid()
        plt.legend()
        plt.savefig("lshensemble_num_perm_{}_precision.png".format(num_perm))
        plt.close()
        # Plot recalls.
        for num_part in num_parts:
            sub = df[(df["num_part"] == num_part) &
                    (df["num_perm"] == num_perm)].groupby("threshold")
            recalls = sub["recall"].mean()
            plt.plot(thresholds, recalls, label="num_part = {}".format(num_part))
        plt.ylim(0.0, 1.0)
        plt.xlabel("Thresholds")
        plt.ylabel("Average Recalls")
        plt.grid()
        plt.legend()
        plt.savefig("lshensemble_num_perm_{}_recall.png".format(num_perm))
        plt.close()
        # Plot fscores.
        for num_part in num_parts:
            sub = df[(df["num_part"] == num_part) &
                    (df["num_perm"] == num_perm)].groupby("threshold")
            fscores = sub["fscore"].mean()
            plt.plot(thresholds, fscores, label="num_part = {}".format(num_part))
        plt.ylim(0.0, 1.0)
        plt.xlabel("Thresholds")
        plt.ylabel("Average F-Scores")
        plt.grid()
        plt.legend()
        plt.savefig("lshensemble_num_perm_{}_fscore.png".format(num_perm))
        plt.close()
        # Plot query time.
        times = []
        for num_part in num_parts:
            t = np.percentile(df[df["num_part"] == num_part]["query_time"], 90)
            times.append(t * 1000.0)
        plt.bar(num_parts, times)
        plt.xlabel("Number of Partitions")
        plt.ylabel("90 Percentile Query Time (ms)")
        plt.grid()
        plt.savefig("lshensemble_num_perm_{}_query_time.png".format(num_perm))
        plt.close()
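The helpers imported from lsh_benchmark_plot are not part of this commit. A minimal sketch of what such helpers presumably compute, using the standard definitions (an assumption; the repo's actual implementation may differ):

```python
def get_precision_recall(found, reference):
    # Precision: fraction of returned keys that are true results;
    # recall: fraction of true results that were returned.
    found, reference = set(found), set(reference)
    intersection = len(found & reference)
    precision = intersection / len(found) if found else 0.0
    recall = intersection / len(reference) if reference else 0.0
    return precision, recall

def fscore(precision, recall):
    # Harmonic mean of precision and recall (F1 score).
    if precision + recall == 0.0:
        return 0.0
    return 2.0 * precision * recall / (precision + recall)
```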