# dataloading.py: vocabulary construction and batch preparation for the dialogue corpus
import itertools
import os
import random
import re
import unicodedata

import torch

# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

# Generic "dull" responses; added to the vocabulary in load_prepare_data so they
# can always be encoded
DULL_RESPONSES = ["I do not know what you are talking about.", "I do not know.", "You do not know.",
                  "You know what I mean.", "I know what you mean.", "You know what I am saying.",
                  "You do not know anything."]


class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD",
                           SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def add_sentence(self, sentence):
        for word in sentence.split(' '):
            self.add_word(word)

    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True
        keep_words = []
        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)
        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index),
            len(keep_words) / len(self.word2index)
        ))
        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD",
                           SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count default tokens
        for word in keep_words:
            self.add_word(word)
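
# Minimal usage sketch of Voc (toy sentence, assumed purely for illustration):
#   voc = Voc("demo")
#   voc.add_sentence("hello world hello")
#   voc.word2index["hello"]   ->  3   (indices 0-2 are reserved for PAD/SOS/EOS)
#   voc.word2count["hello"]   ->  2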


# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )
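
# Sketch of the effect on an assumed example input: combining accents are stripped, e.g.
#   unicode_to_ascii("Ça va déjà?")  ->  "Ca va deja?"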


# Lowercase, trim, and remove non-letter characters
def normalize_string(s):
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)       # pad . ! ? with a leading space
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)   # drop anything that is not a letter or . ! ?
    s = re.sub(r"\s+", r" ", s).strip()     # collapse runs of whitespace
    return s
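
# Sketch on an assumed example input: punctuation is space-padded and all other
# non-alphabetic characters are dropped, e.g.
#   normalize_string("Hello!!  How's it going?")  ->  "hello ! ! how s it going ?"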


# Read query/response pairs and return a voc object
def read_vocs(datafile, corpus_name):
    print("Reading lines...")
    # Read the file and split into lines
    with open(datafile, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalize_string(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs


# Returns True iff both sentences in a pair 'p' are under the max_length threshold
def filter_pair(p, max_length=15):
    # Input sequences need to reserve the last position for the EOS token
    return len(p[0].split(' ')) < max_length and len(p[1].split(' ')) < max_length


# Filter pairs using the filter_pair condition
def filter_pairs(pairs):
    return [pair for pair in pairs if filter_pair(pair)]


# Using the functions defined above, return a populated voc object and pairs list
def load_prepare_data(corpus, corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = read_vocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filter_pairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    # Add the dull responses first so they are always representable in the vocabulary
    for d in DULL_RESPONSES:
        voc.add_sentence(d)
    for pair in pairs:
        voc.add_sentence(pair[0])
        voc.add_sentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs


def trim_rare_words(voc, pairs, min_count=3):
    # Trim words used fewer than min_count times from the voc
    voc.trim(min_count)
    # Filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break
        # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
        if keep_input and keep_output:
            keep_pairs.append(pair)
    print("Trimmed from {} pairs to {}, {:.4f} of total".format(
        len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs


def indexes_from_sentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]
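
# Sketch (assumes "hello" and "world" are in the vocabulary at indices 3 and 4):
#   indexes_from_sentence(voc, "hello world")  ->  [3, 4, 2]   # 2 is EOS_token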


def zero_padding(l, fillvalue=PAD_token):
    # Pad every sequence to the longest one and transpose to (max_len, batch)
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))
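
# Sketch on an assumed toy batch: pads to the longest sequence and transposes
# from (batch, seq_len) to (max_len, batch), e.g.
#   zero_padding([[5, 6, 2], [7, 2]])  ->  [(5, 7), (6, 2), (2, 0)]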


def binary_matrix(l, value=PAD_token):
    # 0 where the token equals `value` (padding), 1 elsewhere
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == value:
                m[i].append(0)
            else:
                m[i].append(1)
    return m
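
# Sketch, continuing the toy batch above:
#   binary_matrix([(5, 7), (6, 2), (2, 0)])  ->  [[1, 1], [1, 1], [1, 0]]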


# Returns padded input sequence tensor and lengths
def input_var(l, voc):
    indexes_batch = [indexes_from_sentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    pad_list = zero_padding(indexes_batch)
    pad_var = torch.LongTensor(pad_list)
    return pad_var, lengths


# Returns padded target sequence tensor, padding mask, and max target length
def output_var(l, voc):
    indexes_batch = [indexes_from_sentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    pad_list = zero_padding(indexes_batch)
    mask = binary_matrix(pad_list)
    mask = torch.BoolTensor(mask)
    pad_var = torch.LongTensor(pad_list)
    return pad_var, mask, max_target_len
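
# Note: mask is True at real-token positions and False at PAD positions (see
# binary_matrix above); presumably the training loop uses it to ignore padding,
# e.g. when averaging a per-token loss.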


# Returns all items for a given batch of pairs
def batch_2_train_data(voc, pair_batch):
    # Sort by descending input length (the order pack_padded_sequence would
    # typically expect in a downstream training loop)
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = input_var(input_batch, voc)
    output, mask, max_target_len = output_var(output_batch, voc)
    return inp, lengths, output, mask, max_target_len
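
# Shape sketch for a batch of B pairs (follows from zero_padding's transposition):
#   inp:    LongTensor (max_input_len, B)    lengths: (B,)
#   output: LongTensor (max_target_len, B)   mask: BoolTensor (max_target_len, B)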


if __name__ == '__main__':
    corpus_name = "train"
    corpus = os.path.join("data", corpus_name)
    datafile = os.path.join(corpus, "formatted_dialogues_train.txt")
    # Load/Assemble voc and pairs
    save_dir = os.path.join("data", "save")
    voc, pairs = load_prepare_data(corpus, corpus_name, datafile, save_dir)
    # Print some pairs to validate
    print("\npairs:")
    for pair in pairs[:10]:
        print(pair)
    pairs = trim_rare_words(voc, pairs, min_count=3)

    # Example for validation
    small_batch_size = 5
    batches = batch_2_train_data(voc, [random.choice(pairs) for _ in range(small_batch_size)])
    input_variable, lengths, target_variable, mask, max_target_len = batches
    print("input_variable:", input_variable)
    print("lengths:", lengths)
    print("target_variable:", target_variable)
    print("mask:", mask)
    print("max_target_len:", max_target_len)