-
Notifications
You must be signed in to change notification settings - Fork 8
/
fast_obs.m
273 lines (216 loc) · 9.22 KB
/
fast_obs.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
function [xf,xs,outpars,LL] = fast_obs(y,M,p,r,S,pars,control,equal,fixed,scale)
%--------------------------------------------------------------------------
% Title:    Parameter estimation and inference in state-space models with
%           regime switching (switching observations) assuming regimes known
%
% Function: Infer hidden state vectors by Kalman filtering/smoothing
%           conditional on the known regime sequence S and estimate model
%           parameters by maximum likelihood (EM algorithm).
%
% Usage:    [xf,xs,outpars,LL] = ...
%               fast_obs(y,M,p,r,S,pars,control,equal,fixed,scale)
%
% Inputs:
%   y       - Time series (dimension NxT)
%   M       - number of regimes
%   p       - order of VAR model for state vector
%   r       - dimension of state vector
%   S       - regime sequence (length T, values in 1:M)
%   pars    - optional structure with fields
%       A     - Initial estimates of VAR matrices A(l,j) in system equation
%               x(t,j) = sum(l=1:p) A(l,j) x(t-l,j) + v(t,j), j=1:M (dimension rxrxpxM)
%       C     - Initial estimates of observation matrices C(j) in equation
%               y(t) = C(j) x(t,j) + w(t), j=1:M (dimension NxrxM)
%       Q     - Initial estimates of state noise covariance Cov(v(t,j)) (dimension rxrxM)
%       R     - Pilot estimate of observation noise covariance Cov(w(t)) (dimension NxN)
%       mu    - Pilot estimate of mean state mu(j)=E(x(t,j)) for t=1:p (dimension rxM)
%       Sigma - Pilot estimate of covariance Sigma(j)=Cov(x(t,j)) for t=1:p (dimension rxrxM)
%       Fields left empty or missing are initialized with init_obs.
%       Fields 'Pi' and 'Z', if given, are ignored: they are always derived
%       from the regime sequence S.
%   control - optional struct variable with fields:
%       'eps': tolerance for EM termination; defaults to 1e-8
%       'ItrNo': number of EM iterations; defaults to 1000
%       'beta0': initial inverse temperature parameter for deterministic annealing; default 1
%       'betarate': decay rate for temperature; default 1
%       'safe': if true, regularizes variance matrices to be well-conditioned
%           before taking inverse. If false, no regularization (faster but less safe)
%       'abstol': absolute tolerance for eigenvalues in matrix inversion (only effective if safe = true)
%       'reltol': relative tolerance for eigenvalues in matrix inversion
%           = inverse condition number (only effective if safe = true)
%       'verbose': if true, prints the log-likelihood and Q-function values
%           at each EM iteration
%   equal   - optional struct variable with fields:
%       'A': if true, VAR transition matrices A(l,j) are equal across regimes j=1,...,M
%       'C': if true, observation matrices C(j) are equal across regimes
%       'Q': if true, VAR innovation matrices Q(j) are equal across regimes
%       'mu': if true, initial mean state vectors mu(j) are equal across regimes
%       'Sigma': if true, initial variance matrices Sigma(j) are equal across regimes
%   fixed   - optional struct variable with fields 'A','C','Q','R','mu','Sigma'.
%       If not empty, each field must contain a matrix with 2 columns, the first for
%       the location of fixed coefficients and the second for their values.
%   scale   - optional struct variable with fields:
%       'A': upper bound for norm of eigenvalues of A matrices. Must be in (0,1).
%       'C': value of the (euclidean) column norms of the matrices C(j). Must be positive.
%
% Outputs:
%   xf      - Filtered state vector
%   xs      - Smoothed state vector
%   outpars - structure with fields
%       A     - Estimated system matrix (reshaped to rxrxpxM)
%       C     - Estimated observation matrix
%       Q     - Estimated state noise cov
%       R     - Estimated observation noise cov
%       mu    - Estimated initial mean of state vector
%       Sigma - Estimated initial variance of state vector
%       Pi    - Initial regime probabilities (indicator of S(1), dimension Mx1)
%       Z     - Empirical regime transition probabilities from S (dimension MxM)
%   LL      - Log-likelihood at each EM iteration performed
%
% Variables:
%   T = length of signal
%   N = dimension of observation vector
%   r = dimension of state vector
%   M = number of regimes/states
%
% Author:   David Degras
%           University of Massachusetts Boston
%
% Contributors: Ting Chee Ming, [email protected]
%           Siti Balqis Samdin
%           Centre for Biomedical Engineering, Universiti Teknologi Malaysia.
%
% Version date: February 7, 2021
%--------------------------------------------------------------------------

%-------------------------------------------------------------------------%
%                           Initialization                                %
%-------------------------------------------------------------------------%

narginchk(5,10);

% Data dimensions
[N,T] = size(y);

% x(t,j): state vector for j-th process at time t (size r0)
% x(t) = x(t,1),...,x(t,M): state vector for all processes at time t (size M*r0)
% X(t,j) = x(t,j),...,x(t-p+1,j)): state vector for j-th process at times t,...,t-p+1 (size r=p*r0)
% X(t) = x(t,1),...,x(t,M): state vector for all processes at times t,...,t-p+1 (size M*p*r0)
% We assume that the initial vectors x(1),...,x(1-p+1) are iid ~ N(mu,Sigma)

% Check that time series has same length as regime sequence
assert(size(y,2) == numel(S));

% Check that all regime values S(t) are in 1:M
assert(all(ismember(S,1:M)));

% Data centering
y = y - mean(y,2);

%@@@@ Initialize optional arguments if not specified
if ~exist('fixed','var')
    fixed = struct();
end
if ~exist('equal','var')
    equal = struct();
end
if ~exist('control','var')
    control = struct();
end
if ~exist('scale','var')
    scale = struct();
end

%@@@@ Initialize estimators with init_obs if not specified @@@@%
pars0 = struct('A',[], 'C',[], 'Q',[], 'R',[], 'mu',[], 'Sigma',[], ...
    'Pi',[], 'Z',[]);
if exist('pars','var') && isstruct(pars)
    fname = fieldnames(pars0);
    for i = 1:numel(fname)
        name = fname{i};
        if isfield(pars,name)
            pars0.(name) = pars.(name);
        end
    end
end
pars = pars0;

% Only the model parameters below need initial values. Pi and Z are always
% derived from S further down, so they must not trigger re-initialization
% (otherwise user-supplied initial estimates would always be discarded).
parnames = {'A','C','Q','R','mu','Sigma'};
missing = cellfun(@(f) isempty(pars.(f)), parnames);
if any(missing)
    userpars = pars;
    pars = init_obs(y,M,p,r,[],control,equal,fixed,scale);
    % Keep the initial estimates explicitly supplied by the user
    for i = find(~missing)
        name = parnames{i};
        pars.(name) = userpars.(name);
    end
end

%@@@@ Regime-related quantities (regime sequence S is known)

% Initial regime probabilities: point mass at S(1)
Pi = zeros(M,1);
Pi(S(1)) = 1;
pars.Pi = Pi;
fixed.Pi = [];

% Regime indicators and transition counts
Ms = zeros(M,T); % P(S(t)=j), for use in Q-function
sum_Ms2 = zeros(M); % sum_Ms2(j,k) = #{t: S(t)=j, S(t+1)=k}
for j = 1:M
    Ms(j,S == j) = 1;
    for k = 1:M
        sum_Ms2(j,k) = sum(S(1:end-1) == j & S(2:end) == k);
    end
end

% Empirical transition probabilities. Built from the MxM count matrix so
% that Z is always MxM, even if some regime never occurs in S or only
% occurs at the first/last time point. Rows without any observed
% transition (0/0 = NaN) are set to the uniform distribution.
Z = sum_Ms2 ./ sum(sum_Ms2,2);
Z(isnan(Z)) = 1/M;
pars.Z = Z;
fixed.Z = [];

% Preprocess input arguments
[pars,control,equal,fixed,scale,skip] = ...
    preproc_obs(M,N,p,r,pars,control,equal,fixed,scale);
abstol = control.abstol;
reltol = control.reltol;
tol = control.eps; % local name avoids shadowing the builtin eps
ItrNo = control.ItrNo;
verbose = control.verbose;
safe = control.safe;

% Parameter sizes
% A: r x pr x M, C: N x r x M, Q: r x r x M, R: N x N,
% mu: r x M, Sigma: r x r x M
% (Only size of A is changed)

%@@@@ Initialize other quantities
LL = zeros(1,ItrNo); % Log-likelihood
LLbest = -Inf;
LLflag = 0; % counter for convergence of log-likelihood
sum_yy = y * y.'; % sum(t=1:T) y(t)*y(t)'

for i = 1:ItrNo

    %---------------------------------------------------------------------%
    %                 Filtering and smoothing + E-step                    %
    %---------------------------------------------------------------------%

    [xf,xs,x0,P0,L,sum_CP,sum_MP,sum_Mxy,sum_P,sum_Pb] = ...
        kfs_obs(y,M,p,r,S,pars,safe,abstol,reltol);

    % Log-likelihood
    LL(i) = L;
    if verbose
        fprintf('Iteration-%d Log-likelihood = %g\n',i,LL(i));
        Qval = Q_obs(pars,Ms,P0,sum_CP,sum_MP,sum_Ms2,sum_Mxy,...
            sum_P,sum_Pb,sum_yy,x0);
        fprintf('Iteration-%d Q-function = %g (before M-step)\n',i,Qval);
    end

    % Check if current solution is best to date
    if i == 1 || LL(i) > LLbest
        LLbest = LL(i);
        xfbest = xf;
        xsbest = xs;
        outpars = pars;
    end

    % Monitor convergence of log-likelihood
    if i>1 && (LL(i)-LL(i-1)) < (tol * abs(LL(i-1)))
        LLflag = LLflag + 1;
    else
        LLflag = 0;
    end

    % Terminate EM algorithm if no sufficient increase in log-likelihood
    % for 5 successive iterations
    if LLflag == 5
        break;
    end

    %---------------------------------------------------------------------%
    %                               M-step                                %
    %---------------------------------------------------------------------%

    pars = M_obs(pars,Ms,P0,sum_CP,sum_MP,sum_Ms2,sum_Mxy,...
        sum_P,sum_Pb,sum_yy,x0,control,equal,fixed,scale,skip);

    if verbose
        Qval = Q_obs(pars,Ms,P0,sum_CP,sum_MP,sum_Ms2,sum_Mxy,...
            sum_P,sum_Pb,sum_yy,x0);
        fprintf('Iteration-%d Q-function = %g (after M-step)\n',i,Qval);
    end

end % END MAIN LOOP

%-------------------------------------------------------------------------%
%                               Output                                    %
%-------------------------------------------------------------------------%

% Return best estimates (i.e. with highest log-likelihood)
% after reshaping them in compact form
outpars.A = reshape(outpars.A,r,r,p,M);
outpars.Pi = Pi;
outpars.Z = Z;
xf = xfbest;
xs = xsbest;
LL = LL(1:i);

end