-
Notifications
You must be signed in to change notification settings - Fork 31
/
keras_mdn.py
193 lines (152 loc) · 5.88 KB
/
keras_mdn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from keras.models import Model
from keras.layers import concatenate, Input
from keras.layers import Dense, Activation, Dropout, Flatten
from keras.layers import BatchNormalization
from keras import regularizers
from keras import backend as K
from keras.utils import np_utils
from keras.callbacks import ModelCheckpoint
from keras.callbacks import TensorBoard
from keras.callbacks import LearningRateScheduler
from keras.callbacks import EarlyStopping
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import normalized_mutual_info_score
from sklearn.model_selection import train_test_split
import math
import matplotlib.pyplot as plt
import matplotlib.cm as cm
tf.keras.utils.set_random_seed(42)
SAVE_PATH = "/content/drive/MyDrive/Colab Notebooks/data/"
def scheduler(epoch, lr):
if epoch < 4:
return lr
else:
return lr * tf.math.exp(-0.1)
def generate_data(N):
pi = np.array([0.2, 0.4, 0.3, 0.1])
mu = [[2,2], [-2,2], [-2,-2], [2,-2]]
std = [[0.5,0.5], [1.0,1.0], [0.5,0.5], [1.0,1.0]]
x = np.zeros((N,2), dtype=np.float32)
y = np.zeros((N,2), dtype=np.float32)
z = np.zeros((N,1), dtype=np.int32)
for n in range(N):
k = np.argmax(np.random.multinomial(1, pi))
x[n,:] = np.random.multivariate_normal(mu[k], np.diag(std[k]))
y[n,:] = mu[k]
z[n,:] = k
#end for
z = z.flatten()
return x, y, z, pi, mu, std
def tf_normal(y, mu, sigma):
y_tile = K.stack([y]*num_clusters, axis=1) #[batch_size, K, D]
result = y_tile - mu
sigma_tile = K.stack([sigma]*data_dim, axis=-1) #[batch_size, K, D]
result = result * 1.0/(sigma_tile+1e-8)
result = -K.square(result)/2.0
oneDivSqrtTwoPI = 1.0/math.sqrt(2*math.pi)
result = K.exp(result) * (1.0/(sigma_tile + 1e-8))*oneDivSqrtTwoPI
result = K.prod(result, axis=-1) #[batch_size, K] iid Gaussians
return result
def NLLLoss(y_true, y_pred):
out_mu = y_pred[:,:num_clusters*data_dim]
out_sigma = y_pred[:,num_clusters*data_dim : num_clusters*(data_dim+1)]
out_pi = y_pred[:,num_clusters*(data_dim+1):]
out_mu = K.reshape(out_mu, [-1, num_clusters, data_dim])
result = tf_normal(y_true, out_mu, out_sigma)
result = result * out_pi
result = K.sum(result, axis=1, keepdims=True)
result = -K.log(result + 1e-8)
result = K.mean(result)
return tf.maximum(result, 0)
#generate data
X_data, y_data, z_data, pi_true, mu_true, sigma_true = generate_data(4096)
data_dim = X_data.shape[1]
num_clusters = len(mu_true)
num_train = 3500
X_train, X_test, y_train, y_test = X_data[:num_train,:], X_data[num_train:,:], y_data[:num_train,:], y_data[num_train:,:]
z_train, z_test = z_data[:num_train], z_data[num_train:]
#visualize data
plt.figure()
plt.scatter(X_train[:,0], X_train[:,1], c=z_train, cmap=cm.bwr)
plt.title('training data')
plt.show()
#plt.savefig(SAVE_PATH + '/mdn_training_data.png')
#training params
batch_size = 128
num_epochs = 128
#model parameters
hidden_size = 32
weight_decay = 1e-4
#MDN architecture
input_data = Input(shape=(data_dim,))
x = Dense(32, activation='relu')(input_data)
x = Dropout(0.2)(x)
x = BatchNormalization()(x)
x = Dense(32, activation='relu')(x)
x = Dropout(0.2)(x)
x = BatchNormalization()(x)
mu = Dense(num_clusters * data_dim, activation='linear')(x) #cluster means
sigma = Dense(num_clusters, activation=K.exp)(x) #diagonal cov
pi = Dense(num_clusters, activation='softmax')(x) #mixture proportions
out = concatenate([mu, sigma, pi], axis=-1)
model = Model(input_data, out)
model.compile(
loss=NLLLoss,
optimizer=tf.keras.optimizers.Adam(),
metrics=["accuracy"]
)
model.summary()
#define callbacks
file_name = SAVE_PATH + 'mdn-weights-checkpoint.h5'
checkpoint = ModelCheckpoint(file_name, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
reduce_lr = LearningRateScheduler(scheduler, verbose=1)
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0.01, patience=16, verbose=1)
#tensor_board = TensorBoard(log_dir='./logs', write_graph=True)
callbacks_list = [checkpoint, reduce_lr, early_stopping]
hist = model.fit(X_train, y_train, batch_size=batch_size, epochs=num_epochs, callbacks=callbacks_list, validation_split=0.2, shuffle=True, verbose=2)
y_pred = model.predict(X_test)
mu_pred = y_pred[:,:num_clusters*data_dim]
mu_pred = np.reshape(mu_pred, [-1, num_clusters, data_dim])
sigma_pred = y_pred[:,num_clusters*data_dim : num_clusters*(data_dim+1)]
pi_pred = y_pred[:,num_clusters*(data_dim+1):]
z_pred = np.argmax(pi_pred, axis=-1)
rand_score = adjusted_rand_score(z_test, z_pred)
print("adjusted rand score: ", rand_score)
nmi_score = normalized_mutual_info_score(z_test, z_pred)
print("normalized MI score: ", nmi_score)
mu_pred_list = []
sigma_pred_list = []
for label in np.unique(z_pred):
z_idx = np.where(z_pred == label)[0]
mu_pred_lbl = np.mean(mu_pred[z_idx,label,:], axis=0)
mu_pred_list.append(mu_pred_lbl)
sigma_pred_lbl = np.mean(sigma_pred[z_idx,label], axis=0)
sigma_pred_list.append(sigma_pred_lbl)
#end for
print("true means:")
print(np.array(mu_true))
print("predicted means:")
print(np.array(mu_pred_list))
print("true sigmas:")
print(np.array(sigma_true))
print("predicted sigmas:")
print(np.array(sigma_pred_list))
#generate plots
plt.figure()
plt.scatter(X_test[:,0], X_test[:,1], c=z_pred, cmap=cm.bwr)
plt.scatter(np.array(mu_pred_list)[:,0], np.array(mu_pred_list)[:,1], s=100, marker='x', lw=4.0, color='k')
plt.title('test data')
#plt.savefig('./figures/mdn_test_data.png')
plt.figure()
plt.plot(hist.history['loss'], 'b', lw=2.0, label='train')
plt.plot(hist.history['val_loss'], '--r', lw=2.0, label='val')
plt.title('Mixture Density Network')
plt.xlabel('Epochs')
plt.ylabel('Negative Log Likelihood Loss')
plt.legend(loc='upper left')
#plt.savefig('./figures/mdn_loss.png')