-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
127 lines (108 loc) · 4.81 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# -*- coding: utf-8 -*-
import tensorflow as tf
import vis_lstm_model
import data_loader
import argparse
import numpy as np
def _parse_bool(value):
    """Interpret a command-line string as a boolean flag value.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw text,
    so every non-empty string -- including ``"False"`` and ``"0"`` -- becomes
    ``True``. This converter maps the usual negative spellings to ``False``
    and keeps every other non-empty value ``True``, so invocations that
    previously enabled the flag keep doing so.

    Args:
        value: The raw option text (or an actual ``bool``).

    Returns:
        ``False`` for ``''``, ``'0'``, ``'false'``, ``'f'``, ``'no'``, ``'n'``
        and ``'off'`` (case-insensitive, surrounding whitespace ignored);
        ``True`` otherwise.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')


def main():
    """Parse command-line options, build the model and run the training loop.

    Loads the question/answer data and the 'train' fc7 image features from
    ``--data_dir``, builds ``vis_lstm_model.Vis_lstm_model``, optionally
    restores a checkpoint given by ``--resume_model``, then trains with Adam
    for ``--epochs`` epochs, printing loss/accuracy after every batch and
    saving a checkpoint after every epoch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_lstm_layers', type=int, default=2,
                        help='num_lstm_layers')
    parser.add_argument('--fc7_feature_length', type=int, default=4096,
                        help='fc7_feature_length')
    parser.add_argument('--rnn_size', type=int, default=512,
                        help='rnn_size')
    parser.add_argument('--embedding_size', type=int, default=512,
                        help='embedding_size')
    parser.add_argument('--word_emb_dropout', type=float, default=0.5,
                        help='word_emb_dropout')
    parser.add_argument('--image_dropout', type=float, default=0.5,
                        help='image_dropout')
    parser.add_argument('--data_dir', type=str, default='Data',
                        help='Data directory')
    parser.add_argument('--batch_size', type=int, default=200,
                        help='Batch Size')
    parser.add_argument('--learning_rate', type=float, default=0.001,
                        help='Learning Rate')
    parser.add_argument('--epochs', type=int, default=200,
                        help='Epochs')
    # type=bool would turn "--debug False" into True; use a real converter.
    parser.add_argument('--debug', type=_parse_bool, default=False,
                        help='Debug')
    parser.add_argument('--resume_model', type=str, default=None,
                        help='Trained Model Path')
    args = parser.parse_args()

    # Single-argument print(...) with %-formatting is valid on both
    # Python 2 and Python 3 and emits the same text as the old statements.
    print("Reading QA DATA")
    qa_data = data_loader.load_questions_answers(args.data_dir)

    print("Reading fc7 features")
    fc7_features, image_id_list = data_loader.load_fc7_features(args.data_dir, 'train')
    print("FC7 features %s" % (fc7_features.shape,))
    print("image_id_list %s" % (image_id_list.shape,))

    # image_id_list holds the image ids; map each id to its position in the
    # list (a repeated id keeps its last position).
    image_id_map = {image_id: row for row, image_id in enumerate(image_id_list)}

    # Inverse of the answer vocabulary: answer index -> answer text.
    ans_map = {index: ans for ans, index in qa_data['answer_vocab'].items()}

    model_options = {
        'num_lstm_layers': args.num_lstm_layers,
        'rnn_size': args.rnn_size,
        'embedding_size': args.embedding_size,
        'word_emb_dropout': args.word_emb_dropout,
        'image_dropout': args.image_dropout,
        'fc7_feature_length': args.fc7_feature_length,
        # One step more than the longest question -- presumably the extra
        # step carries the image feature; TODO confirm in vis_lstm_model.
        'lstm_steps': qa_data['max_question_length'] + 1,
        'q_vocab_size': len(qa_data['question_vocab']),
        'ans_vocab_size': len(qa_data['answer_vocab']),
    }

    model = vis_lstm_model.Vis_lstm_model(model_options)
    input_tensors, t_loss, t_accuracy, t_p = model.build_model()
    train_op = tf.train.AdamOptimizer(args.learning_rate).minimize(t_loss)

    sess = tf.InteractiveSession()
    try:
        # NOTE(review): initialize_all_variables is deprecated in later
        # TensorFlow 1.x releases (tf.global_variables_initializer); kept so
        # the script still runs on the TensorFlow version it was written for.
        tf.initialize_all_variables().run()

        saver = tf.train.Saver()
        if args.resume_model:
            saver.restore(sess, args.resume_model)

        num_training = len(qa_data['training'])
        for i in range(args.epochs):
            batch_no = 0
            while (batch_no * args.batch_size) < num_training:
                sentence, answer, fc7 = get_training_batch(
                    batch_no, args.batch_size, fc7_features,
                    image_id_map, qa_data, 'train')
                _, loss_value, accuracy, pred = sess.run(
                    [train_op, t_loss, t_accuracy, t_p],
                    feed_dict={
                        input_tensors['fc7']: fc7,
                        input_tensors['sentence']: sentence,
                        input_tensors['answer']: answer,
                    })
                batch_no += 1

                if args.debug:
                    # Show predicted vs. ground-truth answer for each sample.
                    for idx, p in enumerate(pred):
                        print("%s %s" % (ans_map[p], ans_map[np.argmax(answer[idx])]))
                    print("Loss %s %s %s" % (loss_value, batch_no, i))
                    print("Accuracy %s" % (accuracy,))
                    print("---------------")
                else:
                    print("Loss %s %s %s" % (loss_value, batch_no, i))
                    print("Training Accuracy %s" % (accuracy,))

            # One checkpoint per epoch.
            # NOTE(review): the directory is fixed and does not follow --data_dir.
            saver.save(sess, "Data/Models/model{}.ckpt".format(i))
    finally:
        # Release the session's resources even if training is interrupted.
        sess.close()
# The number of images is smaller than the number of questions.
def get_training_batch(batch_no, batch_size, fc7_features, image_id_map, qa_data, split):
    """Assemble one mini-batch of questions, one-hot answers and image features.

    Args:
        batch_no: Zero-based batch index; the start offset is
            ``batch_no * batch_size`` wrapped modulo the split size.
        batch_size: Maximum number of samples in the batch (the final batch
            of a split may be shorter).
        fc7_features: Indexable collection of per-image feature rows.
        image_id_map: Mapping from integer image id to the row of that image
            in ``fc7_features``.
        qa_data: Dict providing ``'training'`` / ``'validation'`` sample lists
            (each sample has ``'question'``, ``'answer'`` and ``'image_id'``),
            plus ``'max_question_length'`` and ``'answer_vocab'``.
        split: ``'train'`` selects ``qa_data['training']``; any other value
            selects ``qa_data['validation']``.

    Returns:
        A ``(sentence, answer, fc7)`` tuple of NumPy arrays with ``n`` rows:
        ``sentence`` is int32 ``(n, max_question_length)``, ``answer`` is a
        one-hot float ``(n, len(answer_vocab))`` and ``fc7`` is float
        ``(n, feature_length)``.

    Raises:
        KeyError: if a sample's image id is missing from ``image_id_map``.
    """
    qa = qa_data['training'] if split == 'train' else qa_data['validation']

    total = len(qa)
    # Guard the modulo so an empty split yields an empty batch instead of
    # a ZeroDivisionError.
    si = (batch_no * batch_size) % total if total else 0
    ei = min(total, si + batch_size)
    n = ei - si

    # Take the feature width from the data rather than hard-coding 4096, so
    # features of a different length work; fall back to 4096 when the width
    # cannot be determined (e.g. an empty feature collection).
    feature_shape = np.shape(fc7_features)
    feature_length = feature_shape[1] if len(feature_shape) > 1 else 4096

    # Every row of `sentence` and `fc7` is overwritten below, so they can be
    # left uninitialised; `answer` must start at zero for the one-hot rows.
    sentence = np.empty((n, qa_data['max_question_length']), dtype='int32')
    answer = np.zeros((n, len(qa_data['answer_vocab'])))
    fc7 = np.empty((n, feature_length))

    for count, i in enumerate(range(si, ei)):
        sample = qa[i]
        sentence[count, :] = sample['question'][:]
        answer[count, sample['answer']] = 1.0
        fc7_index = image_id_map[int(sample['image_id'])]
        fc7[count, :] = fc7_features[fc7_index][:]

    return sentence, answer, fc7
# Start training only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()