-
Notifications
You must be signed in to change notification settings - Fork 0
/
aux_functions.py
60 lines (48 loc) · 1.93 KB
/
aux_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import csv
import os
import argparse
from gensim.models import KeyedVectors
from gensim.scripts.glove2word2vec import glove2word2vec
def paradigmatic_neighbours(word, model_under_evaluation, topn=30):
    """Return the words the model ranks as most similar to ``word``.

    Args:
        word: Query word handed to the model's ``most_similar`` method.
        model_under_evaluation: Object exposing
            ``most_similar(word, topn=...)`` that returns an iterable of
            ``(neighbour, score)`` pairs.
        topn: Number of neighbours to request from the model. Defaults to
            30, the value that was previously hard-coded.

    Returns:
        A list of the neighbour words in the order the model returned
        them; the similarity scores are discarded.

    Any exception raised by ``most_similar`` propagates unchanged.
    """
    scored_neighbours = model_under_evaluation.most_similar(word, topn=topn)
    # A distinct loop name keeps the ``word`` argument from being shadowed.
    return [neighbour for neighbour, _ in scored_neighbours]
def read_csv(filename='ParaLex.csv'):
    """Read a CSV file and return its data rows with blank cells removed.

    Args:
        filename: Path of the CSV file to read. Defaults to
            ``'ParaLex.csv'``, resolved against the current working
            directory.

    Returns:
        A list with one entry per line after the first (the first line is
        treated as a header and dropped). Each entry is a list of the
        line's cells with surrounding whitespace stripped; cells that are
        empty or whitespace-only are omitted, so rows may differ in length
        and a blank line yields an empty list.
    """
    # newline='' lets the csv module do its own newline handling (needed
    # for quoted fields containing line breaks), and the explicit encoding
    # keeps non-ASCII entries readable regardless of the platform default.
    with open(filename, newline='', encoding='utf-8') as csv_file:
        rows = [
            # Filter on the stripped value so whitespace-only cells do not
            # survive as empty strings.
            [cell.strip() for cell in record if cell.strip()]
            for record in csv.reader(csv_file)
        ]
    return rows[1:]
def load_language_specific_data(language):
    """Collect the data-sheet rows that belong to one language.

    The sheet is loaded with ``read_csv()`` using its default file name.

    Args:
        language: Language identifier, compared case-insensitively. A
            string longer than three characters is treated as a language
            name and matched against the second field of each row;
            anything shorter is treated as a language code and matched
            against the first field.

    Returns:
        A dict mapping the third field of every matching row to the list
        of that row's remaining fields (``row[3:]``). If several matching
        rows share the same third field, the last one wins.

    Raises:
        ValueError: If no row matches ``language``. ``ValueError`` is an
            ``Exception`` subclass, so handlers written for the previous
            bare ``Exception`` keep working.
    """
    data_sheet = read_csv()
    # Long identifiers are names (field 1), short ones are codes (field 0).
    selection_column = 1 if len(language) > 3 else 0
    wanted = language.upper()
    rows = [
        row for row in data_sheet
        # Rows need at least three fields to provide both the selector and
        # the key; shorter rows (e.g. blank lines) are skipped instead of
        # raising IndexError.
        if len(row) >= 3 and row[selection_column].upper() == wanted
    ]
    if not rows:
        raise ValueError('Language not found')
    return {row[2]: row[3:] for row in rows}
def load_model(file, format_flag):
    """Load word vectors from ``file`` via ``KeyedVectors.load_word2vec_format``.

    Args:
        file: Path of the vector file.
        format_flag: ``'binary'`` loads ``file`` as a binary word2vec
            file; ``'glove'`` first converts ``file`` with
            ``glove2word2vec`` into a sibling ``<stem>.w2v.txt`` file and
            loads that; any other value (including ``None``) loads
            ``file`` as a text word2vec file.

    Returns:
        The object returned by ``KeyedVectors.load_word2vec_format``.
    """
    if format_flag == 'binary':
        return KeyedVectors.load_word2vec_format(file, binary=True, encoding='utf8')

    if format_flag == 'glove':
        # Derive the converted file's name from the real extension rather
        # than assuming it is exactly four characters long.
        w2v_file = os.path.splitext(file)[0] + '.w2v.txt'
        # Test the actual path: comparing against os.listdir() only sees
        # bare names in the current directory, so a model stored in any
        # other directory was re-converted on every call.
        if not os.path.isfile(w2v_file):
            glove2word2vec(file, w2v_file)
        return KeyedVectors.load_word2vec_format(w2v_file, binary=False, encoding='utf8')

    return KeyedVectors.load_word2vec_format(file, binary=False, encoding='utf8')
def parse_arguments(argv=None):
    """Parse the command-line arguments for a model evaluation run.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what
            the parameterless version of this function always did.

    Returns:
        A tuple ``(path_to_model, format, language)``. ``format`` is
        ``None`` when ``-f``/``--format`` is not supplied.

    On invalid arguments argparse prints a usage message and raises
    ``SystemExit``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("path_to_model", help="the path of the model you wish to evaluate")
    parser.add_argument("-f", "--format", help="the format of the model you are evaluating")
    parser.add_argument("language", help="both name and ISO language code accepted")
    args = parser.parse_args(argv)
    return args.path_to_model, args.format, args.language
# Parse the command line and print the resulting
# (path_to_model, format, language) tuple.
# NOTE(review): this appears to be a module-level statement (indentation was
# lost in this copy) with no `if __name__ == "__main__":` guard, so it would
# also run whenever the file is imported - confirm whether this debug-style
# print is intentional.
print(parse_arguments())