forked from emdodds/DictLearner
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LCALearner.py
206 lines (181 loc) · 9.71 KB
/
LCALearner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 20 15:48:46 2015
@author: Eric Dodds
(Inference method adapted from code by Jesse Livezey)
Dictionary learner that uses LCA for inference and gradient descent for learning.
(Intended for static inputs)
"""
import numpy as np
import matplotlib.pyplot as plt
from DictLearner import DictLearner
import pickle
try:
import LCAonGPU
except ImportError:
print("Unable to load GPU implementation. Only CPU inference available.")
class LCALearner(DictLearner):
    """Dictionary learner that uses a Locally Competitive Algorithm (LCA) for inference.

    Inference runs on the CPU (``infer_cpu``) or, when ``gpu`` is set, is
    delegated to the optional ``LCAonGPU`` module. Dictionary storage
    (``self.Q``), the stimulus set (``self.stims``), plotting and the learning
    step are provided by the ``DictLearner`` base class, which is not defined
    in this file.
    """

    def __init__(self, data, nunits, learnrate=None, theta=0.022,
                 batch_size=100, infrate=.01,
                 niter=300, min_thresh=0.4, adapt=0.95, tolerance=.01, max_iter=4,
                 softthresh=False, datatype="image", moving_avg_rate=.001,
                 pca=None, stimshape=None, paramfile=None, gpu=False):
        """
        An LCALearner is a dictionary learner (DictLearner) that uses a Locally Competitive Algorithm (LCA) for inference.
        By default the LCALearner optimizes for sparsity as measured by the L0 pseudo-norm of the activities of the units
        (i.e. the usages of the dictionary elements).

        Args:
            data: data presented to LCALearner for estimating with LCA
            nunits: number of units in thresholding circuit = number dictionary elements
            learnrate: rate for mean-squared error part of learning rule
                (defaults to 1/batch_size when None)
            theta: rate for orthogonality constraint part of learning rule
            batch_size: number of data presented for inference per learning step
            infrate: rate for evolving the dynamical equation in inference (size of each step)
            niter: number of steps in inference (if tolerance is small, chunks of this many
                iterations are repeated until tolerance is satisfied)
            min_thresh: thresholds are reduced during inference no lower than this value.
                sometimes called lambda, multiplies sparsity constraint in objective function
            adapt: factor by which thresholds are multiplied after each inference step
            tolerance: inference ceases after mean-squared error falls below tolerance
            max_iter: maximum number of chunks of inference (each chunk having niter iterations)
            softthresh: if True, optimize for L1-sparsity
            datatype: image or spectro
            moving_avg_rate: forwarded unchanged to DictLearner.__init__
            pca: pca object for inverse-transforming data if used in PC representation
            stimshape: original shape of data (e.g., before unrolling and PCA)
            paramfile: a pickle file with dictionary and error history is stored here
            gpu: whether or not to use the GPU implementation of inference (LCAonGPU)
        """
        # Only substitute the default when no rate was given, so that an
        # explicit learnrate of 0 is passed through instead of being replaced.
        if learnrate is None:
            learnrate = 1./batch_size
        # Inference settings are stored before the base constructor runs, in
        # case the base constructor needs them (its body is not visible here).
        self.infrate = infrate
        self.niter = niter
        self.min_thresh = min_thresh
        self.adapt = adapt
        self.softthresh = softthresh
        self.tolerance = tolerance
        self.max_iter = max_iter
        self.gpu = gpu
        self.meanacts = np.zeros(nunits)
        super().__init__(data, learnrate, nunits, paramfile=paramfile, theta=theta,
                         moving_avg_rate=moving_avg_rate, stimshape=stimshape,
                         datatype=datatype, batch_size=batch_size, pca=pca)

    def show_oriented_dict(self, batch_size=None, *args, **kwargs):
        """Display tiled dictionary as in DictLearn.show_dict(), but with elements inverted
        if their activities tend to be negative.

        Args:
            batch_size: None to use the stored ``self.meanacts``; 'all' to run
                inference on the whole stimulus set; otherwise the number of
                random stimuli on which to run inference to estimate the mean
                activity of each unit.
            *args, **kwargs: forwarded to ``self.show_dict``.

        Returns:
            Whatever ``self.show_dict`` returns.

        The learner's dictionary is left unchanged by this call.
        """
        if batch_size is None:
            means = self.meanacts
        else:
            if batch_size == 'all':
                X = self.stims.data.T
            else:
                X = self.stims.rand_stim(batch_size)
            means = np.mean(self.infer(X)[0], axis=1)
        toflip = means < 0

        realQ = self.Q
        # Flip the signs on a copy: negating rows of self.Q in place would
        # also change realQ (same array), making the restore below a no-op.
        flipped = realQ.copy()
        flipped[toflip] = -flipped[toflip]
        self.Q = flipped
        try:
            return self.show_dict(*args, **kwargs)
        finally:
            # Restore the true dictionary even if plotting fails.
            self.Q = realQ

    def infer_cpu(self, X, infplot=False, tolerance=None, max_iter=None):
        """Infer sparse approximation to given data X using this LCALearner's
        current dictionary.

        The instance variable niter determines for how many iterations to evaluate
        the dynamical equations. Chunks of niter iterations are repeated until the
        mean-squared error is no greater than the given tolerance or until
        max_iter chunks have run.

        Args:
            X: data array whose last axis indexes stimuli; it is multiplied as
                ``self.Q.dot(X)``.
            infplot: if True, record the reconstruction error at every
                iteration and plot it in matplotlib figure 3.
            tolerance: mean-squared error at which to stop; None means use
                ``self.tolerance``.
            max_iter: maximum number of chunks; None means use
                ``self.max_iter`` (if that is also None, there is no limit).

        Returns:
            ``(s.T, u.T, thresh)`` -- coefficients, internal variables and the
            final per-stimulus thresholds -- when infplot is False.
            ``(s.T, errors)`` when infplot is True, where ``errors`` holds the
            per-iteration errors of the last chunk only.

        Raises:
            ValueError: if the internal variable u becomes NaN.
        """
        # Explicit None checks so that a caller-supplied 0 is honored rather
        # than silently replaced by the instance default.
        if tolerance is None:
            tolerance = self.tolerance
        if max_iter is None:
            max_iter = self.max_iter

        ndict = self.Q.shape[0]
        nstim = X.shape[-1]
        u = np.zeros((nstim, ndict))
        s = np.zeros_like(u)
        ci = np.zeros_like(u)

        # c is the overlap of dictionary elements with each other, minus identity (i.e., ignore self-overlap)
        c = self.Q.dot(self.Q.T)
        np.fill_diagonal(c, 0)

        # b[i,j] is overlap of stimulus i with dictionary element j
        b = (self.Q.dot(X)).T

        # initialize threshold values, one for each stimulus, based on average
        # response magnitude, but never below min_thresh
        thresh = np.maximum(np.absolute(b).mean(1), self.min_thresh)

        if infplot:
            errors = np.zeros(self.niter)
            allerrors = np.array([])

        # Start above tolerance so that the first chunk always runs.
        error = tolerance + 1
        outer_k = 0
        while error > tolerance and (max_iter is None or outer_k < max_iter):
            for kk in range(self.niter):
                # ci is the competition term in the dynamical equation
                ci[:] = s.dot(c)
                u[:] = self.infrate*(b - ci) + (1. - self.infrate)*u
                if np.isnan(u).any():
                    raise ValueError("Internal variable blew up at iteration " + str(kk))
                if self.softthresh:
                    # soft threshold: shrink magnitudes toward zero by thresh
                    s[:] = np.sign(u)*np.maximum(0., np.absolute(u) - thresh[:, np.newaxis])
                else:
                    # hard threshold: zero everything below thresh in magnitude
                    s[:] = u
                    s[np.absolute(s) < thresh[:, np.newaxis]] = 0

                if infplot:
                    errors[kk] = np.mean(self.compute_errors(s.T, X))

                # decay the thresholds after every step, clamped at min_thresh
                thresh = np.maximum(self.adapt*thresh, self.min_thresh)

            error = np.mean((X.T - s.dot(self.Q))**2)
            outer_k += 1
            if infplot:
                allerrors = np.concatenate((allerrors, errors))

        if infplot:
            plt.figure(3)
            plt.clf()
            plt.plot(allerrors)
            return s.T, errors
        return s.T, u.T, thresh

    def infer(self, X, infplot=False, tolerance=None, max_iter=None):
        """Run inference on X with the GPU implementation if ``self.gpu`` is
        set, otherwise with ``infer_cpu``.

        On the GPU path ``infplot``, ``tolerance`` and ``max_iter`` are ignored.

        Raises:
            ImportError: if ``self.gpu`` is set but the LCAonGPU module could
                not be imported when this module was loaded.
        """
        if self.gpu:
            # The module-level import of LCAonGPU is optional; when it failed
            # the name does not exist, so look it up defensively and report
            # the real problem instead of a bare NameError.
            gpu_module = globals().get("LCAonGPU")
            if gpu_module is None:
                raise ImportError("gpu=True was requested but the LCAonGPU module "
                                  "is not available; only CPU inference can be used.")
            # right now there is no support for multiple blocks of iterations,
            # stopping after error crosses threshold, or plots monitoring inference
            return gpu_module.infer(self, X.T)
        return self.infer_cpu(X, infplot, tolerance, max_iter)

    def test_inference(self, niter=None):
        """Run inference on a random batch of stimuli with error plotting
        enabled, print the final SNR and return the inferred coefficients.

        Args:
            niter: if given (and nonzero), temporarily used in place of
                ``self.niter`` for this call only.
        """
        temp = self.niter
        self.niter = niter or self.niter
        try:
            X = self.stims.rand_stim()
            s = self.infer(X, infplot=True)[0]
        finally:
            # Always restore the configured iteration count, even when
            # inference raises.
            self.niter = temp
        print("Final SNR: " + str(self.snr(X, s)))
        return s

    def adjust_rates(self, factor):
        """Multiply the learning rate by the given factor."""
        # NOTE: the inference rate is deliberately left unscaled; the original
        # author's remark was that scaling it too "is bad, but NC seems to
        # have done it".
        self.learnrate = factor*self.learnrate

    def set_params(self, params):
        """Set the learner's parameters from an 8-item sequence ordered as in
        ``get_param_list``."""
        (self.learnrate, self.theta, self.min_thresh, self.infrate,
         self.niter, self.adapt, self.max_iter, self.tolerance) = params

    def get_param_list(self):
        """Return the tuple (learnrate, theta, min_thresh, infrate, niter,
        adapt, max_iter, tolerance)."""
        return (self.learnrate, self.theta, self.min_thresh, self.infrate,
                self.niter, self.adapt, self.max_iter, self.tolerance)

    def load(self, filename=None):
        """Loads the parameters that were saved. For older files when I saved less, loads what I saved then.

        Args:
            filename: path of the pickle file to load. When None, the path
                already stored in ``self.paramfile`` is used.

        The base-class loader is tried first; if it fails for any reason the
        file is re-read assuming one of the older, shorter layouts.
        """
        # Fall back to the stored path instead of overwriting it with None.
        if filename is None:
            filename = getattr(self, "paramfile", None)
        self.paramfile = filename
        # NOTE(security): unpickling executes arbitrary code; only load
        # parameter files from trusted sources.
        try:
            super().load(filename)
            return
        except Exception:
            # This is all for backwards-compatibility with files I saved before I started saving as many statistics.
            # (Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit are not swallowed.)
            try:
                with open(filename, 'rb') as f:
                    self.Q, params, histories = pickle.load(f)
                (self.learnrate, self.theta, self.min_thresh, self.infrate,
                 self.niter, self.adapt, self.max_iter, self.tolerance) = params
                # Try progressively older (shorter) history layouts; a length
                # mismatch in the unpacking raises ValueError.
                try:
                    (self.errorhist, self.L0acts, self.L0hist, self.L1acts,
                     self.L1hist, self.corrmatrix_ave) = histories
                except ValueError:
                    print('Loading old file. Correlation matrix not available.')
                    try:
                        (self.errorhist, self.L0acts, self.L0hist,
                         self.L1acts, self.L1hist) = histories
                    except ValueError:
                        print('Loading old file. Activity histories not available.')
                        try:
                            self.errorhist, self.L0acts, self.L1acts = histories
                        except ValueError:
                            print("Loading old file. Moving average activities not available.")
                            self.errorhist = histories
            except ValueError:
                print("Loading very old file. Only dictionary and error history available.")
                with open(filename, 'rb') as f:
                    self.Q, self.errorhist = pickle.load(f)