-
Notifications
You must be signed in to change notification settings - Fork 31
/
mean_field_mrf.py
116 lines (93 loc) · 3.93 KB
/
mean_field_mrf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from scipy.special import expit as sigmoid
from scipy.stats import multivariate_normal
# Global plotting theme and RNG seed shared by everything below.
sns.set_style(style='whitegrid')  # gridded background for every figure
np.random.seed(seed=42)  # fixed seed so the generated observation noise is reproducible
class image_denoising:
    """Denoise a {-1, +1} image with mean-field variational inference on an Ising model.

    Model: latent spins x_i in {-1, +1} whose prior couples every pair of
    4-connected neighbours with strength ``J`` (energy J * x_i * x_j), and
    observations y_i = x_i + sigma * eps_i with eps_i ~ N(0, 1).  The
    variational distribution is fully factorised, q(x) = prod_i q_i(x_i), and
    is parameterised by its means mu_i = E_q[x_i].

    Parameters
    ----------
    img_binary : array_like, 2-D
        Clean image with entries in {-1, +1}; its shape sets the image size.
    sigma : float
        Standard deviation of the Gaussian observation noise.
    J : float
        Coupling strength w_ij shared by every pair of neighbouring pixels.
    rate : float
        Damping factor (typically in (0, 1]); each sweep moves mu this
        fraction of the way toward the mean-field update.
    max_iter : int
        Number of update sweeps performed by ``mean_field``.

    Attributes
    ----------
    y : ndarray
        Noisy observation, drawn once here from the global NumPy RNG.
    ELBO, Hx_mean : ndarray, shape (max_iter,)
        Per-sweep ELBO (up to the prior's log-partition constant) and mean
        per-pixel entropy of q; overwritten by every ``mean_field`` call.
    """

    def __init__(self, img_binary, sigma=2, J=1, rate=0.5, max_iter=15):
        img_binary = np.asarray(img_binary, dtype=float)
        self.sigma = sigma  # noise level
        # y_i ~ N(x_i, sigma^2); the shape comes from the image itself rather
        # than from script-level globals.
        self.y = img_binary + self.sigma * np.random.randn(*img_binary.shape)
        self.J = J  # coupling strength (w_ij)
        self.rate = rate  # update smoothing (damping) rate
        self.max_iter = max_iter
        self.ELBO = np.zeros(self.max_iter)
        self.Hx_mean = np.zeros(self.max_iter)

    @staticmethod
    def _neighbor_sum(field):
        """Return, per pixel, the sum of ``field`` over its 4-connected neighbours.

        Neighbours that would fall outside the image contribute nothing.
        """
        total = np.zeros(field.shape, dtype=float)
        total[1:, :] += field[:-1, :]  # pixel above
        total[:-1, :] += field[1:, :]  # pixel below
        total[:, 1:] += field[:, :-1]  # pixel to the left
        total[:, :-1] += field[:, 1:]  # pixel to the right
        return total

    def _log_likelihoods(self):
        """Return (log N(y; +1, sigma^2), log N(y; -1, sigma^2)), each shaped like ``self.y``."""
        flat = self.y.ravel()
        var = self.sigma ** 2
        # With a scalar mean, multivariate_normal treats each entry of ``flat``
        # as an independent 1-D sample and returns one log-density per entry.
        logp1 = np.reshape(multivariate_normal.logpdf(flat, mean=+1, cov=var), self.y.shape)
        logm1 = np.reshape(multivariate_normal.logpdf(flat, mean=-1, cov=var), self.y.shape)
        return logp1, logm1

    def mean_field(self):
        """Run damped, synchronous mean-field sweeps and return the posterior means.

        Every sweep updates all pixels at once from the previous sweep's means:
            mu_i <- (1 - rate) * mu_i + rate * tanh(J * sum_{j in N(i)} mu_j + 0.5 * L_i)
        where L_i = log N(y_i; +1, sigma^2) - log N(y_i; -1, sigma^2).
        ``self.ELBO`` and ``self.Hx_mean`` are reset, then filled one entry per sweep.

        Returns
        -------
        ndarray, same shape as ``self.y``
            mu_i = E_q[x_i], each in [-1, 1].
        """
        print("running mean-field variational inference...")
        logp1, logm1 = self._log_likelihoods()
        logodds = logp1 - logm1

        # Initialise from the local evidence alone: q_i(+1) = sigmoid(L_i).
        mu = 2 * sigmoid(logodds) - 1

        # Fresh histories so repeated calls do not accumulate onto old values.
        self.ELBO = np.zeros(self.max_iter)
        self.Hx_mean = np.zeros(self.max_iter)

        for i in range(self.max_iter):
            # Synchronous (Jacobi) update: every pixel reads the previous mu.
            s_bar = self.J * self._neighbor_sum(mu)
            mu = (1 - self.rate) * mu + self.rate * np.tanh(s_bar + 0.5 * logodds)

            # A distribution on {-1, +1} with mean mu puts (1 + mu)/2 on +1.
            qxp1 = 0.5 * (1 + mu)  # q_i(x_i=+1)
            qxm1 = 0.5 * (1 - mu)  # q_i(x_i=-1)
            Hx = -qxm1 * np.log(qxm1 + 1e-10) - qxp1 * np.log(qxp1 + 1e-10)  # entropy

            # E_q[sum over edges of J x_i x_j]; each edge is visited twice by
            # the per-pixel neighbour sum, hence the factor 0.5.
            pairwise = 0.5 * self.J * np.sum(mu * self._neighbor_sum(mu))
            self.ELBO[i] = pairwise + np.sum(qxp1 * logp1 + qxm1 * logm1) + np.sum(Hx)
            self.Hx_mean[i] = np.mean(Hx)
        return mu
if __name__ == "__main__":
    # load data
    print("loading data...")
    # The context manager releases the image file once its pixels are copied.
    with Image.open('./figures/bayes.bmp') as data:
        img = np.double(data)
    img_mean = np.mean(img)
    # Map every pixel to an Ising spin in {-1, +1}. A pixel exactly equal to
    # the mean is sent to -1 instead of the invalid value 0.
    img_binary = np.where(img > img_mean, 1, -1)
    M, N = img_binary.shape  # image height (M) and width (N)

    mrf = image_denoising(img_binary, sigma=2, J=1)
    mu = mrf.mean_field()

    # generate plots
    plt.figure()
    plt.imshow(mrf.y)
    plt.title("observed noisy image")
    #plt.savefig('./figures/ising_vi_observed_image.png')
    plt.show()

    plt.figure()
    plt.imshow(mu)
    plt.title(f"after {mrf.max_iter} mean-field iterations")
    #plt.savefig('./figures/ising_vi_denoised_image.png')
    plt.show()

    plt.figure()
    plt.plot(mrf.Hx_mean, color='b', lw=2.0, label='Avg Entropy')
    plt.title('Variational Inference for Ising Model')
    plt.xlabel('iterations')
    plt.ylabel('average entropy')
    plt.legend(loc='upper right')
    #plt.savefig('./figures/ising_vi_avg_entropy.png')
    plt.show()

    plt.figure()
    plt.plot(mrf.ELBO, color='b', lw=2.0, label='ELBO')
    plt.title('Variational Inference for Ising Model')
    plt.xlabel('iterations')
    plt.ylabel('ELBO objective')
    plt.legend(loc='upper left')
    #plt.savefig('./figures/ising_vi_elbo.png')
    plt.show()