Merge pull request #33 from ilyaraz/master
Python version of the GloVe example
Showing 2 changed files with 120 additions and 0 deletions.
@@ -0,0 +1,116 @@
from __future__ import print_function
import numpy as np
import falconn
import timeit
import math

if __name__ == '__main__':
    dataset_file = 'dataset/glove.840B.300d.npy'
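    # The script expects the GloVe 840B.300d vectors already converted to a
    # NumPy .npy matrix at the path above, with one row per word vector.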
    number_of_queries = 1000
    # We build only 10 tables; increasing this number improves the query time
    # at the cost of slower preprocessing and a larger memory footprint.
    # Feel free to play with this parameter.
    number_of_tables = 10

    print('Reading the dataset')
    dataset = np.load(dataset_file)
    print('Done')

    # It's important not to use doubles unless they are strictly necessary.
    # If your dataset consists of doubles, convert it to floats using `astype`.
    assert dataset.dtype == np.float32
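    # For example (illustrative only, not needed for this file):
    # dataset = dataset.astype(np.float32)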
    # Normalize all the lengths, since we care about the cosine similarity.
    print('Normalizing the dataset')
    dataset /= np.linalg.norm(dataset, axis=1).reshape(-1, 1)
    print('Done')

    # Choose random data points to be queries.
    print('Generating queries')
    np.random.seed(4057218)
    np.random.shuffle(dataset)
    queries = dataset[len(dataset) - number_of_queries:]
    dataset = dataset[:len(dataset) - number_of_queries]
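    # The last `number_of_queries` rows become queries and are removed from
    # the dataset, so the query set is disjoint from the indexed points.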
    print('Done')

    # Perform linear scan using NumPy to get answers to the queries.
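    # Since all vectors are unit-length, the largest inner product corresponds
    # to the highest cosine similarity, so the argmax below is the exact
    # nearest neighbor.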
    print('Solving queries using linear scan')
    t1 = timeit.default_timer()
    answers = []
    for query in queries:
        answers.append(np.dot(dataset, query).argmax())
    t2 = timeit.default_timer()
    print('Done')
    print('Linear scan time: {} per query'.format((t2 - t1) / float(len(queries))))

    # Center the dataset and the queries: this improves the performance
    # of LSH quite a bit.
    print('Centering the dataset and queries')
    center = np.mean(dataset, axis=0)
    dataset -= center
    queries -= center
    print('Done')

    params_cp = falconn.LSHConstructionParameters()
    params_cp.dimension = len(dataset[0])
    params_cp.lsh_family = 'cross_polytope'
    params_cp.distance_function = 'euclidean_squared'
    params_cp.l = number_of_tables
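    # ('l' is FALCONN's name for the number of hash tables.)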
    # We set one rotation, since the data is dense enough;
    # for sparse data, set this to 2.
    params_cp.num_rotations = 1
    params_cp.seed = 5721840
    # We build 20-bit hashes so that each table has
    # 2^20 bins; this is a good choice since 2^20 is of the same
    # order of magnitude as the number of data points.
    falconn.compute_number_of_hash_functions(20, params_cp)

    print('Constructing the LSH table')
    t1 = timeit.default_timer()
    table = falconn.LSHIndex(params_cp)
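    # fit() below indexes the dataset, building all `number_of_tables`
    # hash tables at once.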
    table.fit(dataset)
    t2 = timeit.default_timer()
    print('Done')
    print('Construction time: {}'.format(t2 - t1))

    # Find the smallest number of probes that achieves accuracy 0.9:
    # first double the number of probes until the target is reached,
    # then binary-search over the last interval.
    print('Choosing number of probes')
    number_of_probes = number_of_tables
    def evaluate_number_of_probes(number_of_probes):
        table.set_num_probes(number_of_probes)
        score = 0
        for (i, query) in enumerate(queries):
            if answers[i] in table.get_candidates_with_duplicates(query):
                score += 1
        return float(score) / len(queries)
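    # Here "accuracy" is the fraction of queries whose true nearest neighbor
    # (from the linear scan) shows up among the LSH candidates.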
    while True:
        accuracy = evaluate_number_of_probes(number_of_probes)
        print('{} -> {}'.format(number_of_probes, accuracy))
        if accuracy >= 0.9:
            break
        number_of_probes = number_of_probes * 2
    if number_of_probes > number_of_tables:
        left = number_of_probes // 2
        right = number_of_probes
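        # Invariant for the binary search: `right` probes meet the accuracy
        # target, `left` probes do not.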
        while right - left > 1:
            number_of_probes = (left + right) // 2
            accuracy = evaluate_number_of_probes(number_of_probes)
            print('{} -> {}'.format(number_of_probes, accuracy))
            if accuracy >= 0.9:
                right = number_of_probes
            else:
                left = number_of_probes
        number_of_probes = right
    print('Done')
    print('{} probes'.format(number_of_probes))

    # Final evaluation.
    t1 = timeit.default_timer()
    score = 0
    for (i, query) in enumerate(queries):
        if table.find_nearest_neighbor(query) == answers[i]:
            score += 1
    t2 = timeit.default_timer()

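    # "Precision" below is the fraction of queries where the LSH answer
    # matches the exact nearest neighbor found by the linear scan.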
    print('Query time: {}'.format((t2 - t1) / len(queries)))
    print('Precision: {}'.format(float(score) / len(queries)))