-
Notifications
You must be signed in to change notification settings - Fork 80
/
mnist.py
41 lines (33 loc) · 1.1 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from dense import Dense
from activations import Tanh
from losses import mse, mse_prime
from network import train, predict
def preprocess_data(x, y, limit, num_classes=10):
    """Prepare images and labels for the network, keeping the first *limit* samples.

    Each image is flattened into a column vector and scaled from [0, 255] to
    [0, 1]. Each integer label is one-hot encoded into a column vector, e.g.
    label 3 becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0] (as a (10, 1) column).

    Args:
        x: image array whose first axis indexes samples (MNIST: (n, 28, 28) uint8).
        y: integer class labels in ``[0, num_classes)``, one per sample.
        limit: number of leading samples to keep; ``None`` keeps all of them.
        num_classes: length of each one-hot label vector (10 for MNIST digits).

    Returns:
        ``(x, y)`` where ``x`` is float32 of shape (k, pixels, 1) and ``y`` is
        float32 of shape (k, num_classes, 1), with ``k`` the number of kept samples.
    """
    # Slice before converting so only the samples actually used are copied to
    # float32 (the full MNIST training set is far larger than typical limits).
    x = np.asarray(x)[:limit]
    y = np.asarray(y)[:limit]

    # Flatten every sample to a column vector; derive the pixel count from the
    # shape rather than hard-coding it (MNIST gives 28 * 28 = 784).
    pixels = int(np.prod(x.shape[1:]))
    x = x.reshape(x.shape[0], pixels, 1).astype("float32") / 255

    # One-hot encode with plain NumPy (same float32 result as Keras'
    # to_categorical). num_classes is fixed explicitly so a subset missing the
    # highest digit still produces full-length vectors.
    y = np.eye(num_classes, dtype="float32")[y.astype(int).ravel()]
    y = y.reshape(y.shape[0], num_classes, 1)
    return x, y
# Fetch the MNIST arrays (Keras loads them from a remote server), then keep a
# small prepared subset: 1000 samples for training and 20 for evaluation.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, y_train = preprocess_data(x_train, y_train, 1000)
x_test, y_test = preprocess_data(x_test, y_test, 20)

# Fully connected network 784 -> 40 -> 10, with a Tanh layer after each Dense
# layer (Dense arguments appear to be (input_size, output_size) — they chain).
network = [Dense(28 * 28, 40), Tanh(), Dense(40, 10), Tanh()]

# Fit using the local `losses` helpers `mse` / `mse_prime` (presumably mean
# squared error and its derivative — see the losses module).
train(
    network,
    mse,
    mse_prime,
    x_train,
    y_train,
    epochs=100,
    learning_rate=0.1,
)

# For each test sample, show the index of the largest network output next to
# the index of the 1 in its one-hot label.
for x, y in zip(x_test, y_test):
    output = predict(network, x)
    print('pred:', np.argmax(output), '\ttrue:', np.argmax(y))