-
Notifications
You must be signed in to change notification settings - Fork 418
/
bnn_vi.py
142 lines (117 loc) · 4.96 KB
/
bnn_vi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
import tensorflow as tf
from six.moves import range, zip
import numpy as np
import zhusuan as zs
from examples import conf
from examples.utils import dataset
@zs.meta_bayesian_net(scope="bnn", reuse_variables=True)
def build_bnn(x, layer_sizes, n_particles):
    """Build the Bayesian neural-network regression model.

    For every consecutive pair ``(n_in, n_out)`` in ``layer_sizes`` a weight
    node ``"w<i>"`` of shape ``[n_out, n_in + 1]`` is drawn ``n_particles``
    times from a standard normal prior. The extra column multiplies a
    constant-1 feature, so it plays the role of a bias. Hidden layers apply
    ReLU; the last layer is linear. Its output axis is squeezed to form
    ``"y_mean"``, so the final layer width must be 1. The observation
    ``"y"`` is normal around ``"y_mean"`` with one learned scalar log-std
    (TF variable ``y_logstd``, initialised to 0).

    Args:
        x: Input tensor; the tiling below implies rank 2, presumably
            ``[batch, layer_sizes[0]]`` (TODO confirm against callers).
        layer_sizes: Layer widths, input width first, output width last.
        n_particles: Number of weight samples to draw per layer.

    Returns:
        The ``zs.BayesianNet`` containing the ``w*``, ``y_mean`` and ``y``
        nodes.
    """
    net = zs.BayesianNet()
    # Replicate the inputs once per weight particle.
    act = tf.tile(x[None, ...], [n_particles, 1, 1])
    n_layers = len(layer_sizes) - 1
    for idx in range(n_layers):
        fan_in = layer_sizes[idx]
        fan_out = layer_sizes[idx + 1]
        weight = net.normal("w" + str(idx), tf.zeros([fan_out, fan_in + 1]),
                            std=1., group_ndims=2, n_samples=n_particles)
        # Append a constant-1 feature so the last weight column acts as bias.
        bias_feature = tf.ones(tf.shape(act)[:-1])[..., None]
        act = tf.concat([act, bias_feature], -1)
        # Per-particle matmul, then scale by 1/sqrt(fan_in + 1).
        # NOTE: op creation order (einsum before the sqrt) is kept identical
        # to the original because TF1 derives unseeded op seeds from it.
        product = tf.einsum("imk,ijk->ijm", weight, act)
        act = product / tf.sqrt(tf.cast(tf.shape(act)[2], tf.float32))
        is_hidden_layer = idx < n_layers - 1
        if is_hidden_layer:
            act = tf.nn.relu(act)
    y_mean = net.deterministic("y_mean", tf.squeeze(act, 2))
    y_logstd = tf.get_variable("y_logstd", shape=[],
                               initializer=tf.constant_initializer(0.))
    net.normal("y", y_mean, logstd=y_logstd)
    return net
@zs.reuse_variables(scope="variational")
def build_mean_field_variational(layer_sizes, n_particles):
    """Build the fully factorised (mean-field) Gaussian posterior q(W).

    For each layer ``i`` two trainable variables of shape
    ``[layer_sizes[i + 1], layer_sizes[i] + 1]`` are created:
    ``w_mean_<i>`` and ``w_logstd_<i>``, both initialised to zero (i.e. unit
    standard deviation at the start). They parameterise the normal node
    ``"w<i>"``, sampled ``n_particles`` times. ``group_ndims=2`` makes the
    log-probability of each weight matrix be summed over both matrix axes.
    The shapes and node names mirror the weight nodes of ``build_bnn``.

    Args:
        layer_sizes: Layer widths, input width first, output width last.
        n_particles: Number of weight samples to draw per layer.

    Returns:
        The ``zs.BayesianNet`` holding the variational ``w*`` nodes.
    """
    net = zs.BayesianNet()
    for idx in range(len(layer_sizes) - 1):
        suffix = str(idx)
        # +1 input column for the bias feature appended in the model.
        param_shape = [layer_sizes[idx + 1], layer_sizes[idx] + 1]
        loc = tf.get_variable(
            "w_mean_" + suffix, shape=param_shape,
            initializer=tf.constant_initializer(0.))
        log_scale = tf.get_variable(
            "w_logstd_" + suffix, shape=param_shape,
            initializer=tf.constant_initializer(0.))
        net.normal("w" + suffix, loc, logstd=log_scale,
                   n_samples=n_particles, group_ndims=2)
    return net
def main():
    """Fit the BNN to UCI Boston housing with mean-field variational inference.

    The train and validation splits are merged into a single training set,
    and inputs and targets are standardised with ``dataset.standardize``.
    The model from ``build_bnn`` is trained by minimising the SGVB cost of
    the ELBO with Adam (learning rate 0.01) on minibatches of 10, for 500
    epochs. Every 10 epochs the test RMSE and mean predictive
    log-likelihood are printed. Both are rescaled with ``std_y_train`` so
    they refer to the original target scale.
    """
    tf.set_random_seed(1237)
    np.random.seed(2345)

    # -- Data ----------------------------------------------------------------
    # Load UCI Boston housing and fold the validation split into training.
    data_path = os.path.join(conf.data_dir, "housing.data")
    (x_train, y_train, x_valid, y_valid,
     x_test, y_test) = dataset.load_uci_boston_housing(data_path)
    x_train = np.vstack([x_train, x_valid])
    y_train = np.hstack([y_train, y_valid])
    n_train, x_dim = x_train.shape

    # Standardise; keep the target std to report metrics on the raw scale.
    x_train, x_test, _, _ = dataset.standardize(x_train, x_test)
    y_train, y_test, _, std_y_train = dataset.standardize(y_train, y_test)

    # -- Graph ---------------------------------------------------------------
    hidden_widths = [50]
    n_particles = tf.placeholder(tf.int32, shape=[], name="n_particles")
    x = tf.placeholder(tf.float32, shape=[None, x_dim])
    y = tf.placeholder(tf.float32, shape=[None])
    layer_sizes = [x_dim] + hidden_widths + [1]
    w_names = ["w" + str(k) for k in range(len(layer_sizes) - 1)]

    model = build_bnn(x, layer_sizes, n_particles)
    variational = build_mean_field_variational(layer_sizes, n_particles)

    def log_joint(bn):
        # Sum of the per-layer weight log-priors, plus the minibatch
        # log-likelihood averaged over axis 1 (the batch axis; axis 0 indexes
        # particles) and rescaled to the full training-set size.
        prior_terms = bn.cond_log_prob(w_names)
        lik_terms = bn.cond_log_prob('y')
        return tf.add_n(prior_terms) + tf.reduce_mean(lik_terms, 1) * n_train

    model.log_joint = log_joint

    lower_bound = zs.variational.elbo(
        model, {'y': y}, variational=variational, axis=0)
    cost = lower_bound.sgvb()
    infer_op = tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)

    # Predictions: particle-averaged mean for RMSE; log-mean-exp over
    # particles for the predictive log-likelihood. Both are mapped back to
    # the original target scale.
    pred_mean = tf.reduce_mean(lower_bound.bn["y_mean"], 0)
    rmse = tf.sqrt(tf.reduce_mean((pred_mean - y) ** 2)) * std_y_train
    log_py_xw = lower_bound.bn.cond_log_prob("y")
    log_likelihood = (tf.reduce_mean(zs.log_mean_exp(log_py_xw, 0))
                      - tf.log(std_y_train))

    # -- Training / evaluation settings --------------------------------------
    lb_samples, ll_samples = 10, 5000
    epochs, batch_size, test_freq = 500, 10, 10

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(1, epochs + 1):
            order = np.random.permutation(x_train.shape[0])
            x_train, y_train = x_train[order, :], y_train[order]
            epoch_lbs = []
            # The final minibatch may be shorter than batch_size.
            for start in range(0, n_train, batch_size):
                stop = start + batch_size
                _, lb = sess.run(
                    [infer_op, lower_bound],
                    feed_dict={n_particles: lb_samples,
                               x: x_train[start:stop],
                               y: y_train[start:stop]})
                epoch_lbs.append(lb)
            print('Epoch {}: Lower bound = {}'.format(
                epoch, np.mean(epoch_lbs)))

            if epoch % test_freq != 0:
                continue
            test_rmse, test_ll = sess.run(
                [rmse, log_likelihood],
                feed_dict={n_particles: ll_samples, x: x_test, y: y_test})
            print('>> TEST')
            print('>> Test rmse = {}, log_likelihood = {}'
                  .format(test_rmse, test_ll))
# Run the training experiment only when executed as a script, not on import.
if __name__ == "__main__":
    main()