NN_001.java
package perceptron_concept;

import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;

public class NN_001 {
    public static void main(String[] args) throws IOException {
        System.out.println("running...\n");

        //--- create neural network
        int[] u = { 784, 25, 25, 25, 10 };   // layer sizes: 784 inputs, three hidden layers, 10 outputs
        float learningRate = 0.0067f;
        float bounceResRate = 50.0f;         // cap on the squared delta; larger updates are skipped
        float weightInitRange = 0.35f;
        int runs = 10000;
        int miniBatch = 8;
        int networkInfoCheck = 10;           // number of progress reports to print

        int dnn = u.length - 1, nns = 0, wnn = 0, inputs = u[0], output = u[dnn], correct = 0;
        float ce = 0, ce2 = 0;
        for (int n = 0; n < dnn + 1; n++) nns += u[n];            // number of neurons
        for (int n = 1; n < dnn + 1; n++) wnn += u[n - 1] * u[n]; // number of weights

        float[] neuron = new float[nns];
        float[] gradient = new float[nns - inputs];
        float[] weight = new float[wnn];
        float[] delta = new float[wnn];
        float[] target = new float[output];
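        // Layout note: all layers share one flat neuron array - the inputs first,
        // then each hidden layer, then the outputs. Weights for the transition
        // from layer i to layer i+1 are stored as base + from * u[i + 1] + to.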
        //--- testdata "t10k-images.idx3-ubyte" - "t10k-labels.idx1-ubyte"
        RandomAccessFile img = new RandomAccessFile(new File("C:\\mnist\\train-images.idx3-ubyte"), "r");
        RandomAccessFile lbl = new RandomAccessFile(new File("C:\\mnist\\train-labels.idx1-ubyte"), "r");
        img.seek(16); // skip the IDX3 header: magic, image count, rows, cols (4 bytes each)
        lbl.seek(8);  // skip the IDX1 header: magic, label count
        //--- get pseudo random init weights
        for (int n = 0, p = 314; n < wnn; n++)
            weight[n] = (float) ((p = p * 2718 % 2718281) / (2718281.0 * Math.E * Math.PI * weightInitRange));
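        // The loop above is a small multiplicative congruential generator
        // (p = p * 2718 mod 2718281), so initialization is deterministic and the
        // weights land in a fixed positive range derived from weightInitRange.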
        //--- start training
        for (int x = 1; x < runs + 1; x++) {

            //+----------- 1. MNIST as Inputs --------------------------------------+
            for (int n = 0; n < inputs; ++n)
                neuron[n] = img.read() / 255.0f; // scale each pixel byte to [0, 1]
            int targetNum = lbl.read();

            //+----------- 2. Feed Forward -----------------------------------------+
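            // j walks the neuron array past the inputs, t marks the first neuron
            // of the current layer, and w marks the first weight of the current
            // layer transition; hidden layers use ReLU, the output layer stays
            // linear so that softmax can be applied afterwards.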
            for (int i = 0, j = inputs, t = 0, w = 0; i < dnn; i++, t += u[i - 1], w += u[i] * u[i - 1])
                for (int k = 0; k < u[i + 1]; k++, j++) {
                    float net = gradient[j - inputs] = 0;
                    for (int n = t, m = w + k; n < t + u[i]; n++, m += u[i + 1])
                        net += neuron[n] * weight[m];
                    neuron[j] = i == dnn - 1 ? net : net > 0 ? net : 0;
                } //--- k ends
            //+----------- 3. NN prediction ----------------------------------------+
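            // argmax over the raw output activations; outMaxVal is reused below
            // as the max term of the numerically stable softmax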
            int outMaxPos = nns - output;
            float outMaxVal = neuron[nns - output], scale = 0;
            for (int i = nns - output + 1; i < nns; i++)
                if (neuron[i] > outMaxVal) {
                    outMaxPos = i; outMaxVal = neuron[i];
                }
            if (targetNum + nns - output == outMaxPos) correct++;
            //+----------- 4. Loss / Error with Softmax and Cross Entropy ----------+
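            // stable softmax: p_n = exp(net_n - max) / sum_m exp(net_m - max);
            // subtracting the max keeps exp() from overflowing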
            for (int n = nns - output; n != nns; n++)
                scale += (float) Math.exp(neuron[n] - outMaxVal);
            for (int n = nns - output; n != nns; n++)
                neuron[n] = (float) Math.exp(neuron[n] - outMaxVal) / scale;
            // cross entropy is -log of the probability assigned to the true
            // class, not the predicted one
            ce2 = (ce -= (float) Math.log(neuron[targetNum + nns - output])) / x;
            //+----------- 5. Backpropagation --------------------------------------+
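            // one backward sweep, output layer first: the output gradient is
            // (target - p) from softmax + cross entropy; hidden gradients flow
            // only through neurons whose ReLU was active (neuron[j] > 0), while
            // ws, wd, us and gs track weight, previous-layer and gradient offsets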
            target[targetNum] = 1.0f; // one-hot target, reset again below
            for (int i = dnn, j = nns - 1, ls = output, wd = wnn - 1, ws = wd, us = nns - output - 1, gs = nns - inputs - 1;
                    i != 0; i--, wd -= u[i + 1] * u[i], us -= u[i], gs -= u[i + 1])
                for (int k = 0; k != u[i]; k++, j--) {
                    float gra = 0;
                    //--- first check if output or hidden, calc delta for both
                    if (i == dnn)
                        gra = target[--ls] - neuron[j];
                    else if (neuron[j] > 0)
                        for (int n = gs + u[i + 1]; n > gs; n--, ws--)
                            gra += weight[ws] * gradient[n];
                    else ws -= u[i + 1]; // inactive ReLU: gradient stays 0, but keep ws aligned
                    //--- accumulate weight deltas against the previous layer's activations
                    for (int n = us, w = wd - k; n > us - u[i - 1]; w -= u[i], n--)
                        delta[w] += gra * neuron[n];
                    gradient[j - inputs] = gra;
                }
            target[targetNum] = 0;
            //+----------- 6. update Weights ---------------------------------------+
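            // deltas are accumulated over miniBatch samples, applied here, then
            // decayed by 0.67 rather than cleared, so a momentum-like fraction of
            // the previous batch carries into the next update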
            if (x % miniBatch == 0 || x == runs) { // flush any leftover deltas on the last sample
                for (int m = 0; m < wnn; m++) {
                    //--- bounce restriction: skip updates whose squared delta is too large
                    if (delta[m] * delta[m] > bounceResRate) continue;
                    //--- update weights
                    weight[m] += learningRate * delta[m];
                    delta[m] *= 0.67f;
                }
            } //--- batch end

            if (x % (runs / networkInfoCheck) == 0)
                System.out.println("runs: " + x + " accuracy: " + (correct * 100.0f / x));
        } //--- runs end
System.out.println("\nneurons: " + nns + " weights: " + wnn + " batch: " + miniBatch);
System.out.println("accuracy: " + (correct * 100.0 / (runs * 1.0f)) + " cross entropy: " + ce2);
System.out.println("correct: "+(correct) + " incorrect: " + (runs - correct));
img.close(); lbl.close();
}
}