forked from emdodds/DictLearner
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TopoSparsenet.py
156 lines (130 loc) · 5.97 KB
/
TopoSparsenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 7 14:23:14 2016
@author: Tep & Eric
"""
import DictLearner
import sparsenet
import numpy as np
import matplotlib.pyplot as plt
class TopoSparsenet(sparsenet.Sparsenet):
    """Sparsenet variant with a second, topographic pooling layer.

    The first-layer units are arranged on a ``dict_shape`` grid.  A weight
    matrix ``g`` (second-layer units x first-layer units) pools squared
    first-layer activities, and inference adds a penalty weighted by
    ``lamb_2`` on top of the first-layer sparseness penalty weighted by
    ``lamb``.
    """

    def __init__(self, data, dict_shape=None, lamb=0.15, lamb_2=0.01,
                 sigma=1, **kwargs):
        """Set up the topography and delegate the rest to the parent class.

        Args:
            data: forwarded unchanged to ``sparsenet.Sparsenet.__init__``.
            dict_shape: ``(rows, cols)`` layout of the dictionary; the number
                of units is the product of its entries.  Required despite the
                ``None`` default.
            lamb: weight of the first-layer sparseness penalty.
            lamb_2: weight of the second-layer (pooled) penalty.
            sigma: width of the Gaussian neighbourhood used to build ``g``.
            **kwargs: forwarded unchanged to the parent constructor.

        Raises:
            TypeError: if ``dict_shape`` is not given.
        """
        if dict_shape is None:
            # Without this check the failure surfaces as an opaque TypeError
            # from int(np.prod(None)); keep the type, improve the message.
            raise TypeError("dict_shape is required: pass the (rows, cols) "
                            "layout of the dictionary")
        self.lamb_2 = lamb_2
        self.dict_shape = dict_shape
        self.nunits = int(np.prod(self.dict_shape))
        self.sigma = sigma
        # g needs only dict_shape, nunits and sigma, all set above; it is
        # built before the parent constructor is called.
        self.g = self.layer_two_weights()
        super().__init__(data, self.nunits, lamb=lamb, **kwargs)

    def infer(self, X, infplot=False):
        """Infer first-layer activities for the stimuli in ``X``.

        Runs ``self.niter`` fixed-size steps (step size ``self.infrate``)
        starting from all-zero activities.

        Args:
            X: 2-D array with one stimulus per column.  The products below
                require ``self.Q`` to be (nunits, ninput) and ``X`` to be
                (ninput, nstim).
            infplot: if true, record the mean squared reconstruction error,
                the mean absolute activity and the mean second-layer measure
                at every iteration and plot them in a new figure.

        Returns:
            A tuple ``(acts, None, None)`` where ``acts`` has shape
            (nunits, nstim).
        """
        # Size the state from X itself so that inference works for any number
        # of stimuli, not only for exactly self.batch_size of them.
        n_stim = X.shape[1]
        acts = np.zeros((self.nunits, n_stim))
        if infplot:
            error_hist = np.zeros(self.niter)
            first_hist = np.zeros(self.niter)
            second_hist = np.zeros(self.niter)
        # Both products are constant across iterations, so compute them once.
        phi_sq = self.Q.dot(self.Q.T)
        QX = self.Q.dot(X)
        for k in range(self.niter):
            pooled = self.g @ acts**2
            # NOTE(review): for a penalty sum_i sqrt((g @ acts**2)_i) the chain
            # rule gives 2 * acts * (g.T @ layer_two_deriv(pooled)); the
            # expression below differs from that.  Left unchanged because
            # altering it changes the model's dynamics -- confirm intent.
            da2_da1 = 2 * self.layer_two_deriv(pooled) * (self.g @ acts)
            da_dt = (QX - phi_sq @ acts - self.lamb * self.dSda(acts)
                     - self.lamb_2 * da2_da1)
            acts = acts + self.infrate * da_dt
            if infplot:
                error_hist[k] = np.mean((X.T - np.dot(acts.T, self.Q))**2)
                first_hist[k] = np.mean(np.abs(acts))
                # `pooled` was computed from the activities before this
                # iteration's update.
                second_hist[k] = np.mean(self.layer_two_measure(pooled))
        if infplot:
            plt.figure()
            plt.plot(error_hist, 'b')
            plt.plot(first_hist, 'g')
            plt.plot(second_hist, 'r')
        return acts, None, None

    def layer_two_measure(self, pooled_acts):
        """Return the second-layer measure; for now, just the square root."""
        return np.sqrt(pooled_acts)

    def layer_two_deriv(self, pooled_acts):
        """Return an approximate derivative of the square-root measure.

        The 0.01 added to the denominator keeps the value finite when the
        pooled activity is zero.
        """
        return 1 / (2 * np.sqrt(pooled_acts) + 0.01)

    def distance(self, i, j):
        """Return the squared grid distance between elements i and j.

        The flat indices are mapped to (row, column) positions on the
        ``dict_shape`` grid in row-major order, and the grid is treated as a
        torus: along each axis the shorter way around is used.  The result is
        the *squared* Euclidean distance (no square root is taken), which is
        the quantity ``layer_two_weights`` feeds into its Gaussian.
        """
        rows, cols = self.dict_shape
        row_gap = abs(i // cols - j // cols)
        col_gap = abs(i % cols - j % cols)
        # On a torus the separation along an axis is the smaller of the
        # direct gap and the gap going around the other way.
        row_gap = min(row_gap, abs(rows - row_gap))
        col_gap = min(col_gap, abs(cols - col_gap))
        return row_gap**2 + col_gap**2

    def layer_two_weights(self):
        """Build the Gaussian second-layer weight matrix.

        This is currently only working for the case when (# of layer 2
        units) = (# of layer 1 units).

        Returns:
            Array of shape (nunits, nunits) with
            ``g[i, j] = exp(-distance(i, j) / (2 * sigma**2))``.
        """
        g = np.zeros((self.nunits, self.nunits))
        sigsquared = self.sigma**2
        for i in range(self.nunits):
            for j in range(self.nunits):
                g[i, j] = np.exp(-self.distance(i, j) / (2 * sigsquared))
        return g

    def block_membership(self, i, j, width=5):
        """Return 1 if element j lies in the block of element i, else 0.

        The block of element i is the ``width`` x ``width`` square of grid
        positions centred on i (for even ``width`` it extends one position
        further towards lower indices), wrapping around the edges of the
        grid like the torus used by ``distance``.  Every block therefore
        contains ``min(width, rows) * min(width, cols)`` elements.

        Args:
            i: flat index of the element that defines the block.
            j: flat index of the element being tested.
            width: side length of the block in grid positions.
        """
        rows, cols = self.dict_shape
        reach = width // 2
        # Shift the signed offset so the block maps onto 0..width-1, then
        # wrap; Python's % always returns a non-negative result here.
        row_offset = (j // cols - i // cols + reach) % rows
        col_offset = (j % cols - i % cols + reach) % cols
        if row_offset < width and col_offset < width:
            return 1
        return 0

    def set_blocks(self, width=5):
        """Change the topography by making each second layer unit respond to
        a square block of layer one with given width. g becomes binary."""
        nunits = int(np.prod(self.dict_shape))
        blocks = np.zeros_like(self.g)
        for i in range(nunits):
            for j in range(nunits):
                blocks[i, j] = self.block_membership(i, j, width)
        self.g = blocks

    def binarize_g(self, thresh=1/2, width=None):
        """Replace g by a 0/1 integer matrix by thresholding it.

        Args:
            thresh: entries of g greater than or equal to this become 1, the
                rest 0.  Ignored when ``width`` is given.
            width: if given, the threshold is set to the Gaussian value at
                distance ``width``, i.e. ``exp(-width**2 / (2 * sigma**2))``.
        """
        if width is not None:
            thresh = np.exp(-width**2 / (2 * self.sigma**2))
        self.g = np.array(self.g >= thresh, dtype=int)

    def show_dict(self, stimset=None, cmap='RdBu', subset=None, square=False,
                  savestr=None):
        """Plot an array of tiled dictionary elements.

        The elements are handed to ``stimset.stimarray`` in reverse order
        with ``layout=self.dict_shape`` and drawn with ``origin='lower'``.
        Per the authors' note the 0th element ends up in the top right; that
        depends on how ``stimarray`` (defined elsewhere) tiles its input.

        Args:
            stimset: object providing ``stimarray``; defaults to
                ``self.stims``.
            cmap: matplotlib colormap name.  For 'RdBu' the sign of the
                elements is flipped before plotting.
            subset: optional index into ``self.Q`` selecting the elements to
                show.
            square: accepted but not used by this method.
            savestr: if given, the figure is saved to this path.

        Returns:
            The image object returned by ``plt.imshow``.
        """
        # Test for None explicitly: a stimset object's truth value says
        # nothing about whether one was supplied.
        if stimset is None:
            stimset = self.stims
        if subset is not None:
            Qs = self.Q[subset]
        else:
            Qs = self.Q
        if cmap == 'RdBu':
            Qs = -Qs
        array = stimset.stimarray(Qs[::-1], layout=self.dict_shape)
        plt.figure()
        arrayplot = plt.imshow(array, interpolation='nearest', cmap=cmap,
                               aspect='auto', origin='lower')
        plt.axis('off')
        plt.colorbar()
        if savestr is not None:
            plt.savefig(savestr, bbox_inches='tight')
        return arrayplot

    def set_params(self, params):
        """Restore parameters from a sequence as made by ``get_param_list``.

        Accepts either the 11-item form that includes ``lamb_2`` or a 10-item
        form without it, in which case ``lamb_2`` is left unchanged.

        Raises:
            ValueError: if ``params`` has any other length.
        """
        # Materialise once so that one-shot iterables are not consumed by a
        # failed first unpacking attempt.
        params = tuple(params)
        if len(params) == 11:
            (self.learnrate, self.infrate, self.niter, self.lamb, self.lamb_2,
             self.measure, self.var_goal, self.gains, self.variances,
             self.var_eta, self.gain_rate) = params
        else:
            # 10-item form; unpacking raises ValueError for any other length.
            (self.learnrate, self.infrate, self.niter, self.lamb,
             self.measure, self.var_goal, self.gains, self.variances,
             self.var_eta, self.gain_rate) = params

    def get_param_list(self):
        """Return the current parameters as an 11-item tuple, in the order
        expected by ``set_params``."""
        return (self.learnrate, self.infrate, self.niter, self.lamb,
                self.lamb_2, self.measure, self.var_goal, self.gains,
                self.variances, self.var_eta, self.gain_rate)

    def sort(self, *args, **kwargs):
        """Refuse to sort: only prints a message, the dictionary is untouched."""
        print("The topographic order is meaningful, don't sort it away!")