forked from princewen/tensorflow_practice
Showing 4 changed files with 529 additions and 0 deletions.
data_utils.py:
import os
import re

import numpy as np
def load_task(data_dir, task_id, only_supporting=False):
    '''Load the nth task. There are 20 tasks in total.
    Returns a tuple containing the training and testing data for the task.
    '''
    assert 0 < task_id < 21

    files = os.listdir(data_dir)
    files = [os.path.join(data_dir, f) for f in files]
    s = 'qa{}_'.format(task_id)
    train_file = [f for f in files if s in f and 'train' in f][0]
    test_file = [f for f in files if s in f and 'test' in f][0]
    train_data = get_stories(train_file, only_supporting)
    test_data = get_stories(test_file, only_supporting)
    return train_data, test_data
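# A minimal usage sketch (assuming the bAbI v1.2 English release, whose files
# follow the 'qa{id}_..._train.txt' / 'qa{id}_..._test.txt' naming that the
# filename filters above rely on):
#
#   train, test = load_task('data/tasks_1-20_v1-2/en/', 1)
#   # train[0] -> (story_sentences, question_tokens, [answer_word])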
def tokenize(sent):
    '''Return the tokens of a sentence including punctuation.
    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]
def parse_stories(lines, only_supporting=False):
    '''Parse stories provided in the bAbI tasks format.
    If only_supporting is true, only the sentences that support the answer are kept.
    '''
    data = []
    story = []
    for line in lines:
        line = line.lower()
        nid, line = line.split(' ', 1)
        nid = int(nid)
        if nid == 1:
            # sentence ids restart at 1 for each new story
            story = []
        if '\t' in line:  # question line: question, answer, supporting fact ids
            q, a, supporting = line.split('\t')
            q = tokenize(q)
            # a = tokenize(a)
            # answer is one vocab word even if it's actually multiple words
            a = [a]
            substory = None

            # remove question marks
            if q[-1] == "?":
                q = q[:-1]

            if only_supporting:
                # Only select the related substory
                supporting = map(int, supporting.split())
                substory = [story[i - 1] for i in supporting]
            else:
                # Provide all the substories
                substory = [x for x in story if x]

            data.append((substory, q, a))
            story.append('')
        else:  # regular sentence
            # remove periods
            sent = tokenize(line)
            if sent[-1] == ".":
                sent = sent[:-1]
            story.append(sent)
    return data
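# For reference, a bAbI task file interleaves numbered story sentences with
# tab-separated question lines, e.g.:
#
#   1 mary moved to the bathroom.
#   2 john went to the hallway.
#   3 where is mary?<TAB>bathroom<TAB>1
#
# which parse_stories turns into
#   ([['mary', 'moved', 'to', 'the', 'bathroom'],
#     ['john', 'went', 'to', 'the', 'hallway']],
#    ['where', 'is', 'mary'], ['bathroom'])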
def get_stories(f, only_supporting=False):
    '''Given a file name, read the file, retrieve the stories,
    and then convert the sentences into a single story.
    '''
    with open(f) as fh:
        return parse_stories(fh.readlines(), only_supporting=only_supporting)
def vectorize_data(data, word_idx, sentence_size, memory_size):
    """
    Vectorize stories and queries.

    If a sentence length < sentence_size, the sentence will be padded with 0's.
    If a story length < memory_size, the story will be padded with empty memories.
    Empty memories are 1-D arrays of length sentence_size filled with 0's.
    The answer array is returned as a one-hot encoding.
    """
    S = []
    Q = []
    A = []
    for story, query, answer in data:
        ss = []
        for sentence in story:
            ls = max(0, sentence_size - len(sentence))
            ss.append([word_idx[w] for w in sentence] + [0] * ls)

        # take only the most recent sentences that fit in memory
        ss = ss[::-1][:memory_size][::-1]

        # Overwrite the last slot of each sentence (reserved padding; see
        # sentence_size += 1 in the training script) with a time-word index,
        # so the model can learn a temporal-encoding embedding for it. The
        # most recent sentence gets the first time-word index.
        for i in range(len(ss)):
            ss[i][-1] = len(word_idx) - memory_size - i + len(ss)

        # pad the story to memory_size with empty memories
        lm = max(0, memory_size - len(ss))
        for _ in range(lm):
            ss.append([0] * sentence_size)

        lq = max(0, sentence_size - len(query))
        q = [word_idx[w] for w in query] + [0] * lq

        y = np.zeros(len(word_idx) + 1)  # 0 is reserved for the nil word
        for a in answer:
            y[word_idx[a]] = 1

        S.append(ss)
        Q.append(q)
        A.append(y)
    return np.array(S), np.array(Q), np.array(A)
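# A minimal shape check (hypothetical numbers; assumes word_idx already
# includes the time words, as set up in the training script below):
#
#   S, Q, A = vectorize_data(train, word_idx, sentence_size=7, memory_size=10)
#   # S.shape == (n_examples, 10, 7)   one row of word indices per memory slot
#   # Q.shape == (n_examples, 7)       padded query
#   # A.shape == (n_examples, len(word_idx) + 1)   one-hot answers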
Training script (single-task MemN2N):
from itertools import chain

import numpy as np
import tensorflow as tf
from six.moves import range, reduce
from sklearn import metrics
from sklearn.model_selection import train_test_split

from data_utils import load_task, vectorize_data
from memn2n import MemN2N

tf.flags.DEFINE_float("learning_rate", 0.01, "Learning rate for SGD.")
tf.flags.DEFINE_integer("anneal_rate", 25, "Number of epochs between halvings of the learning rate.")
tf.flags.DEFINE_integer("anneal_stop_epoch", 100, "Epoch number to end annealed lr schedule.")
tf.flags.DEFINE_float("max_grad_norm", 40.0, "Clip gradients to this norm.")
tf.flags.DEFINE_integer("evaluation_interval", 10, "Evaluate and print results every x epochs.")
tf.flags.DEFINE_integer("batch_size", 32, "Batch size for training.")
tf.flags.DEFINE_integer("hops", 3, "Number of hops in the Memory Network.")
tf.flags.DEFINE_integer("epochs", 100, "Number of epochs to train for.")
tf.flags.DEFINE_integer("embedding_size", 20, "Embedding size for embedding matrices.")
tf.flags.DEFINE_integer("memory_size", 50, "Maximum size of memory.")
tf.flags.DEFINE_integer("task_id", 1, "bAbI task id, 1 <= id <= 20.")
tf.flags.DEFINE_integer("random_state", None, "Random state.")
tf.flags.DEFINE_string("data_dir", "data/tasks_1-20_v1-2/en/", "Directory containing bAbI tasks.")
FLAGS = tf.flags.FLAGS

print("Started Task:", FLAGS.task_id) | ||
|
||
# task data
train, test = load_task(FLAGS.data_dir, FLAGS.task_id)
data = train + test

# build the vocabulary from every story, query, and answer word
vocab = sorted(reduce(lambda x, y: x | y, (set(list(chain.from_iterable(s)) + q + a) for s, q, a in data)))
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))

max_story_size = max(map(len, (s for s, _, _ in data)))
mean_story_size = int(np.mean([len(s) for s, _, _ in data]))
sentence_size = max(map(len, chain.from_iterable(s for s, _, _ in data)))
query_size = max(map(len, (q for _, q, _ in data)))
memory_size = min(FLAGS.memory_size, max_story_size)

# Add time words/indexes: give each of the memory_size time words the next
# free index after the real vocabulary (vectorize_data only depends on the
# size of word_idx here, but integer ids keep the mapping consistent)
for i in range(memory_size):
    word_idx['time{}'.format(i + 1)] = len(word_idx) + 1

vocab_size = len(word_idx) + 1  # +1 for nil word
sentence_size = max(query_size, sentence_size)  # queries are padded to sentence_size too
sentence_size += 1  # +1 slot for the time word appended to each sentence

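# Concretely (hypothetical numbers): with 20 distinct task words and
# memory_size = 50, word_idx maps the task words to 1..20 and time1..time50
# to 21..70, so vocab_size = 71 and index 0 stays reserved for padding.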
print("Longest sentence length", sentence_size) | ||
print("Longest story length", max_story_size) | ||
print("Average story length", mean_story_size) | ||
|
||
# train/validation/test sets
S, Q, A = vectorize_data(train, word_idx, sentence_size, memory_size)
trainS, valS, trainQ, valQ, trainA, valA = train_test_split(S, Q, A, test_size=.1, random_state=FLAGS.random_state)
testS, testQ, testA = vectorize_data(test, word_idx, sentence_size, memory_size)

print(testS[0])

print("Training set shape", trainS.shape) | ||
|
||
# params | ||
n_train = trainS.shape[0] | ||
n_test = testS.shape[0] | ||
n_val = valS.shape[0] | ||
|
||
print("Training Size", n_train) | ||
print("Validation Size", n_val) | ||
print("Testing Size", n_test) | ||
|
||
train_labels = np.argmax(trainA, axis=1) | ||
test_labels = np.argmax(testA, axis=1) | ||
val_labels = np.argmax(valA, axis=1) | ||
|
||
tf.set_random_seed(FLAGS.random_state) | ||
batch_size = FLAGS.batch_size | ||
|
||
batches = zip(range(0, n_train-batch_size, batch_size), range(batch_size, n_train, batch_size)) | ||
batches = [(start, end) for start, end in batches] | ||
|
||
with tf.Session() as sess:
    model = MemN2N(batch_size, vocab_size, sentence_size, memory_size, FLAGS.embedding_size, session=sess,
                   hops=FLAGS.hops, max_grad_norm=FLAGS.max_grad_norm)
    for t in range(1, FLAGS.epochs + 1):
        # Stepped learning rate: halve every anneal_rate epochs, then hold
        # steady once anneal_stop_epoch is reached
        if t - 1 <= FLAGS.anneal_stop_epoch:
            anneal = 2.0 ** ((t - 1) // FLAGS.anneal_rate)
        else:
            anneal = 2.0 ** (FLAGS.anneal_stop_epoch // FLAGS.anneal_rate)
        lr = FLAGS.learning_rate / anneal

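        # With the defaults (learning_rate=0.01, anneal_rate=25,
        # anneal_stop_epoch=100) this gives lr = 0.01 for epochs 1-25,
        # 0.005 for 26-50, 0.0025 for 51-75, and 0.00125 for 76-100.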
        np.random.shuffle(batches)
        total_cost = 0.0
        for start, end in batches:
            s = trainS[start:end]
            q = trainQ[start:end]
            a = trainA[start:end]
            cost_t = model.batch_fit(s, q, a, lr)
            total_cost += cost_t

        if t % FLAGS.evaluation_interval == 0:
            # evaluate on the full training set in batches
            train_preds = []
            for start in range(0, n_train, batch_size):
                end = start + batch_size
                s = trainS[start:end]
                q = trainQ[start:end]
                pred = model.predict(s, q)
                train_preds += list(pred)

            val_preds = model.predict(valS, valQ)
            train_acc = metrics.accuracy_score(np.array(train_preds), train_labels)
            val_acc = metrics.accuracy_score(val_preds, val_labels)

            print('-----------------------')
            print('Epoch', t)
            print('Total Cost:', total_cost)
            print('Training Accuracy:', train_acc)
            print('Validation Accuracy:', val_acc)
            print('-----------------------')

    test_preds = model.predict(testS, testQ)
    test_acc = metrics.accuracy_score(test_preds, test_labels)
    print("Testing Accuracy:", test_acc)