-
Notifications
You must be signed in to change notification settings - Fork 8
/
autoencoder_encodedstate_sequential.py
75 lines (63 loc) · 2.44 KB
/
autoencoder_encodedstate_sequential.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
Visualize the encoded state of a simple autoencoder, built with the Keras
Sequential API, using Keract.
"""
import keras
from keras.layers import Dense
from keras.datasets import mnist
from keras.models import Sequential
from keract import get_activations, display_activations
import matplotlib.pyplot as plt
from keras import backend as K
# ---- Configuration: image geometry and training settings ----
# Images are img_width x img_height pixels; initial_dimension is the length
# of one image once flattened into a vector.
img_width = 28
img_height = 28
initial_dimension = img_height * img_width
# Number of units in the encoded (bottleneck) layer.
encoded_dim = 50
# Training settings.
no_epochs = 1
batch_size = 128
validation_split = 0.2
verbosity = 1
# Fetch the MNIST train/test splits (the targets are loaded but the
# autoencoder only needs the images).
(input_train, target_train), (input_test, target_test) = mnist.load_data()
# Shape of a single flattened sample, as expected by the first Dense layer.
input_shape = (initial_dimension, )
# Flatten every image to a vector of initial_dimension values, cast to
# float32, and divide by 255 to rescale the pixel values.
input_train = input_train.reshape(input_train.shape[0], initial_dimension).astype('float32') / 255
input_test = input_test.reshape(input_test.shape[0], initial_dimension).astype('float32') / 255
# Define the 'autoencoder' full model:
#   - encoder: one Dense layer that compresses the flattened image down to
#     encoded_dim ReLU units;
#   - decoder: one Dense layer that expands back to initial_dimension values,
#     with a sigmoid so every output lies in (0, 1).
autoencoder = Sequential()
autoencoder.add(Dense(encoded_dim, activation='relu', kernel_initializer='he_normal', input_shape=input_shape))
autoencoder.add(Dense(initial_dimension, activation='sigmoid'))
# Compile the autoencoder
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
# Give us some insights
autoencoder.summary()
# Fit data: the inputs double as the targets because the model is trained to
# reproduce its own input. The configured verbosity is passed through so the
# setting actually controls the training output.
autoencoder.fit(
    input_train,
    input_train,
    epochs=no_epochs,
    batch_size=batch_size,
    validation_split=validation_split,
    verbose=verbosity,
)
# =============================================
# Take a sample for visualization purposes
# =============================================
# Slice (rather than index) so the sample keeps its leading batch axis.
input_sample = input_test[:1]
# A single-input model takes the array directly; no list wrapper is needed.
reconstruction = autoencoder.predict(input_sample)
# =============================================
# Visualize input-->reconstruction
# =============================================
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(6, 3.5)
# Image arrays are laid out as (rows, columns), i.e. (height, width).
image_shape = (img_height, img_width)
input_sample_reshaped = input_sample.reshape(image_shape)
reconstruction_reshaped = reconstruction.reshape(image_shape)
# Left panel: the original test image.
axes[0].imshow(input_sample_reshaped)
axes[0].set_title('Original image')
# Right panel: the autoencoder's output for that image.
axes[1].imshow(reconstruction_reshaped)
axes[1].set_title('Reconstruction')
plt.show()
# =============================================
# Visualize encoded state with Keract
# =============================================
# Run the sample through the trained model and collect the activations that
# Keract returns, then plot them in grayscale without saving to disk.
activations = get_activations(autoencoder, input_sample)
display_activations(
    activations,
    cmap="gray",
    save=False,
)