-
Notifications
You must be signed in to change notification settings - Fork 0
/
lrSolver_Demo_tc.m
80 lines (63 loc) · 1.86 KB
/
lrSolver_Demo_tc.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
function lrSolver_Demo_tc(nsamples)
% LRSOLVER_DEMO_TC  Exercise -- Logistic Regression Solver demo.
%   Generates a training set and an independent test set with
%   tcdataGenerator, fits a logistic-regression model with lrSGD and with
%   lrLBFGS, plots each learned decision boundary over the training data
%   (figure 1: left panel SGD, right panel L-BFGS) and prints the SGD
%   model's train and test accuracy.
%
%   lrSolver_Demo_tc()          uses 200 samples per data set.
%   lrSolver_Demo_tc(NSAMPLES)  uses NSAMPLES samples per data set.
%
%   Other available solvers: FastDescent ConjugateGradient Newton
%   FixedNewton DFP BFGS SGD.
if nargin < 1 || isempty(nsamples)
    nsamples = 200;
end
% NOTE: no 'clear all' here -- inside a function it would also wipe the
% input argument and flush every loaded function from memory.
close all; clc

%% generate data
% training data; labels of -1 are remapped to 0 before fitting
[x, y] = tcdataGenerator(nsamples, 0.5, 'normal');
y(y == -1) = 0;
% testing data, drawn independently with the same generator settings
[xt, yt] = tcdataGenerator(nsamples, 0.5, 'normal');
yt(yt == -1) = 0;

%% Logistic Regression Solver: SGD
% All SGD settings go into the struct that is actually passed to lrSGD
% (previously epochs/minibatch/alpha/momentum were written to a misspelled
% 'options' struct and silently ignored).
sgdOption.C = 1;
sgdOption.debug = 1;
sgdOption.epochs = 3;
% presumably lrSGD expects minibatch <= number of samples -- TODO confirm
sgdOption.minibatch = min(200, nsamples);
sgdOption.alpha = 1e-1;
sgdOption.momentum = .95;
[theta, cost] = lrSGD(x, y, sgdOption)

%% Visualize SGD result
figure(1)
subplot(121)
plotDecisionBoundary(x, y, theta);

%% Logistic Regression Solver: L-BFGS
% Same struct contents lrLBFGS received before (C and debug only).
lbfgsOption.C = 1;
lbfgsOption.debug = 1;
[thetaLBFGS, cost] = lrLBFGS(x, y, lbfgsOption)

%% Visualize L-BFGS result (uses the L-BFGS parameters, not the SGD ones)
figure(1)
subplot(122)
plotDecisionBoundary(x, y, thetaLBFGS);

%% predict (SGD model)
acc = predictAccuracy(x, y, theta);
disp(['train acc: ', num2str(acc)]);
% evaluate on the held-out test inputs xt, not the training inputs
acc = predictAccuracy(xt, yt, theta);
disp(['test acc: ', num2str(acc)]);

function plotDecisionBoundary(x, y, theta)
% PLOTDECISIONBOUNDARY  Scatter the two classes of X and overlay a line.
%   Rows with Y==1 are drawn as blue '+', rows with Y==0 as green 'x'
%   (columns 1 and 2 of X as the axes). The red line is the set where
%   theta(1) + theta(2)*x1 + theta(3)*x2 == 0, drawn over the range of
%   column 1 of X padded by 1 on each side.
xmin = min(x(:, 1)) - 1;
xmax = max(x(:, 1)) + 1;
data_pos = x(y == 1, :);
data_neg = x(y == 0, :);
scatter(data_pos(:, 1), data_pos(:, 2), 'b+', 'SizeData', 200, 'LineWidth', 2);
hold on
scatter(data_neg(:, 1), data_neg(:, 2), 'gx', 'SizeData', 200, 'LineWidth', 2);
axis tight
x1 = xmin:0.1:xmax;
% solve theta(1) + theta(2)*x1 + theta(3)*x2 = 0 for x2
plot(x1, (-theta(1) - x1*theta(2))/theta(3), 'r-', 'LineWidth', 2);
hold off

function acc = predictAccuracy(x, y, theta)
% PREDICTACCURACY  Fraction of rows of X whose 0/1 prediction equals Y.
%   Prepends a bias column of ones to X, evaluates sigmoid(xx, theta) and
%   predicts 0 where the result is below 0.5, otherwise 1.
xx = [ones(size(x, 1), 1), x];
h = sigmoid(xx, theta);
p = ones(size(h));
p(h < 0.5) = 0;
acc = sum(p == y) / length(p);