Commit

Merge pull request #33 from ilyaraz/master
Python version of the GloVe example
Ilya Razenshteyn committed Jan 16, 2016
2 parents 94660f4 + 53f7968 commit 206b524
Showing 2 changed files with 120 additions and 0 deletions.
4 changes: 4 additions & 0 deletions examples/glove/convert.py
@@ -2,7 +2,9 @@
 import sys
 import struct
+import numpy as np
 
+matrix = []
 with open('dataset/glove.840B.300d.txt', 'r') as inf:
     with open('dataset/glove.840B.300d.dat', 'wb') as ouf:
         counter = 0
@@ -12,5 +14,7 @@
             ouf.write(struct.pack('i', len(row)))
             ouf.write(struct.pack('%sf' % len(row), *row))
             counter += 1
+            matrix.append(np.array(row, dtype=np.float32))
             if counter % 10000 == 0:
                 sys.stdout.write('%d points processed...\n' % counter)
+np.save('dataset/glove.840B.300d', np.array(matrix))
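For reference, the .dat file written above stores each row as an int32 length followed by that many float32 values. A minimal sketch of reading it back (a hypothetical helper, not part of this commit):

import struct

def read_dat(path):
    # Each record: an int32 vector length, then that many float32 values.
    rows = []
    with open(path, 'rb') as f:
        while True:
            header = f.read(4)
            if not header:
                break
            (dim,) = struct.unpack('i', header)
            rows.append(struct.unpack('%sf' % dim, f.read(4 * dim)))
    return rows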
116 changes: 116 additions & 0 deletions examples/glove/glove.py
@@ -0,0 +1,116 @@
from __future__ import print_function
import numpy as np
import falconn
import timeit
import math

if __name__ == '__main__':
    dataset_file = 'dataset/glove.840B.300d.npy'
    number_of_queries = 1000
    # We build only 10 tables; increasing this quantity will improve the query
    # time at the cost of slower preprocessing and a larger memory footprint.
    # Feel free to play with this number.
    number_of_tables = 10

    print('Reading the dataset')
    dataset = np.load(dataset_file)
    print('Done')

    # It's important not to use doubles unless they are strictly necessary.
    # If your dataset consists of doubles, convert it to floats using `astype`.
    assert dataset.dtype == np.float32
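    # For example (a hypothetical case, not needed here since convert.py
    # already writes float32): a float64 array could be converted with
    #   dataset = dataset.astype(np.float32)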

    # Normalize all the lengths, since we care about the cosine similarity.
    print('Normalizing the dataset')
    dataset /= np.linalg.norm(dataset, axis=1).reshape(-1, 1)
    print('Done')

    # Choose random data points to be queries.
    print('Generating queries')
    np.random.seed(4057218)
    np.random.shuffle(dataset)
    queries = dataset[len(dataset) - number_of_queries:]
    dataset = dataset[:len(dataset) - number_of_queries]
    print('Done')

    # Perform a linear scan using NumPy to get answers to the queries. Since
    # the vectors are normalized, the dot product equals the cosine similarity,
    # so argmax gives the true nearest neighbor.
    print('Solving queries using linear scan')
    t1 = timeit.default_timer()
    answers = []
    for query in queries:
        answers.append(np.dot(dataset, query).argmax())
    t2 = timeit.default_timer()
    print('Done')
    print('Linear scan time: {} per query'.format((t2 - t1) / float(len(queries))))

    # Center the dataset and the queries: this improves the performance of LSH quite a bit.
    print('Centering the dataset and queries')
    center = np.mean(dataset, axis=0)
    dataset -= center
    queries -= center
    print('Done')

    params_cp = falconn.LSHConstructionParameters()
    params_cp.dimension = len(dataset[0])
    params_cp.lsh_family = 'cross_polytope'
    params_cp.distance_function = 'euclidean_squared'
    params_cp.l = number_of_tables
    # We set one rotation, since the data is dense enough;
    # for sparse data, set it to 2.
    params_cp.num_rotations = 1
    params_cp.seed = 5721840
    # We build 20-bit hashes so that each table has 2^20 bins;
    # this is a good choice, since 2^20 is of the same order
    # of magnitude as the number of data points.
    falconn.compute_number_of_hash_functions(20, params_cp)
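    # (For scale: glove.840B.300d contains roughly 2.2 million vectors, and
    # 2^20 is about 1.05 million bins, hence "same order of magnitude".)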

    print('Constructing the LSH table')
    t1 = timeit.default_timer()
    table = falconn.LSHIndex(params_cp)
    table.fit(dataset)
    t2 = timeit.default_timer()
    print('Done')
    print('Construction time: {}'.format(t2 - t1))

    # Find the smallest number of probes that achieves accuracy 0.9,
    # using binary search.
    print('Choosing number of probes')
    number_of_probes = number_of_tables

    def evaluate_number_of_probes(number_of_probes):
        table.set_num_probes(number_of_probes)
        score = 0
        for (i, query) in enumerate(queries):
            if answers[i] in table.get_candidates_with_duplicates(query):
                score += 1
        return float(score) / len(queries)

    # First find an upper bound by repeated doubling.
    while True:
        accuracy = evaluate_number_of_probes(number_of_probes)
        print('{} -> {}'.format(number_of_probes, accuracy))
        if accuracy >= 0.9:
            break
        number_of_probes = number_of_probes * 2
    # Then binary-search between the last two values tried.
    if number_of_probes > number_of_tables:
        left = number_of_probes // 2
        right = number_of_probes
        while right - left > 1:
            number_of_probes = (left + right) // 2
            accuracy = evaluate_number_of_probes(number_of_probes)
            print('{} -> {}'.format(number_of_probes, accuracy))
            if accuracy >= 0.9:
                right = number_of_probes
            else:
                left = number_of_probes
        number_of_probes = right
    print('Done')
    print('{} probes'.format(number_of_probes))

    # Final evaluation.
    t1 = timeit.default_timer()
    score = 0
    for (i, query) in enumerate(queries):
        if table.find_nearest_neighbor(query) == answers[i]:
            score += 1
    t2 = timeit.default_timer()

    print('Query time: {}'.format((t2 - t1) / len(queries)))
    print('Precision: {}'.format(float(score) / len(queries)))
