-
Notifications
You must be signed in to change notification settings - Fork 38
/
extract.py
95 lines (81 loc) · 3.54 KB
/
extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import sys, os
from process import parse_sentence
from mapper import Map, deduplication
from transformers import AutoTokenizer, BertModel, GPT2Model
import argparse
import en_core_web_md
from tqdm import tqdm
import json
def str2bool(v):
    """Parse a command-line flag value into a bool.

    Accepts an actual bool (returned unchanged) or a string such as
    'yes'/'no', 'true'/'false', 't'/'f', 'y'/'n', '1'/'0' in any case.
    Raises argparse.ArgumentTypeError for anything else, so argparse
    reports a clean usage error instead of a traceback.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in {'yes', 'true', 't', 'y', '1'}:
        return True
    if lowered in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
# ---------------------------------------------------------------------------
# Command-line interface and module-level configuration.
# NOTE(review): this runs at import time (parse_args, spaCy model load) —
# intentional for a script, but it makes the module unsuitable for importing
# as a library without refactoring behind the __main__ guard.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Process lines of text corpus into a knowledge graph')
parser.add_argument('input_filename', type=str, help='text file as input')
parser.add_argument('output_filename', type=str, help='output text file')
parser.add_argument('--language_model', default='bert-base-cased',
                    choices=['bert-large-uncased', 'bert-large-cased', 'bert-base-uncased', 'bert-base-cased', 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'],
                    help='which language model to use')
parser.add_argument('--use_cuda', default=True,
                    type=str2bool, nargs='?',
                    help="Use cuda?")
parser.add_argument('--include_text_output', default=False,
                    type=str2bool, nargs='?',
                    help="Include original sentence in output")
parser.add_argument('--threshold', default=0.003,
                    type=float, help="Any attention score lower than this is removed")
args = parser.parse_args()
use_cuda = args.use_cuda

# spaCy pipeline: used below for sentence segmentation and passed into
# parse_sentence. Loading the medium English model takes a few seconds.
nlp = en_core_web_md.load()

# Tested language models:
#   1. bert-base-cased
#   2. gpt2-medium
# Basically any model that belongs to this family should work.
language_model = args.language_model
if __name__ == '__main__':
    # Load the tokenizer and the matching attention encoder. GPT-2 checkpoints
    # need GPT2Model; everything else here is a BERT variant.
    tokenizer = AutoTokenizer.from_pretrained(language_model)
    if 'gpt2' in language_model:
        encoder = GPT2Model.from_pretrained(language_model)
    else:
        encoder = BertModel.from_pretrained(language_model)
    encoder.eval()  # inference only — disables dropout etc.
    if use_cuda:
        encoder = encoder.cuda()

    input_filename = args.input_filename
    output_filename = args.output_filename
    include_sentence = args.include_text_output
    threshold = args.threshold  # hoisted: loop-invariant attribute lookup

    # Explicit UTF-8 so corpus decoding does not depend on the platform locale
    # (the default encoding varies across OSes and would silently corrupt or
    # reject non-ASCII text on some systems).
    with open(input_filename, 'r', encoding='utf-8') as f, \
         open(output_filename, 'w', encoding='utf-8') as g:
        # One input line = one document; output is JSON-lines, one record per
        # input line that yielded at least one surviving triplet.
        for idx, line in enumerate(tqdm(f)):
            sentence = line.strip()
            if not sentence:
                continue  # skip blank lines

            # Match: extract candidate triplets sentence-by-sentence.
            valid_triplets = []
            for sent in nlp(sentence).sents:
                for triplet in parse_sentence(sent.text, tokenizer, encoder, nlp, use_cuda=use_cuda):
                    valid_triplets.append(triplet)
            if not valid_triplets:
                continue

            # Map: link raw head/relation/tail spans; drop low-confidence ones.
            mapped_triplets = []
            for triplet in valid_triplets:
                conf = triplet['c']
                if conf < threshold:
                    continue
                mapped_triplet = Map(triplet['h'], triplet['r'], triplet['t'])
                # Map() signals a successful link by populating 'h'.
                if 'h' in mapped_triplet:
                    mapped_triplet['c'] = conf
                    mapped_triplets.append(mapped_triplet)

            output = {'line': idx, 'tri': deduplication(mapped_triplets)}
            if include_sentence:
                output['sent'] = sentence
            if output['tri']:
                g.write(json.dumps(output) + '\n')