-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
New network trainers structure implemented
Vec structure checking implemented Vec to_string improved
- Loading branch information
Showing
14 changed files
with
298 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion
2
...al_core/simple_network/simple_network.cpp → src/neural_core/simple_network.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#include <neural_core/trainer/adam_trainer.h> | ||
|
||
AdamTrainer::AdamTrainer(SimpleNetwork* network, I_SimpleData* train_data, | ||
double beta1, double beta2, double alpha, | ||
int t_stop1, int t_stop2) : NetworkTrainer(network, train_data) | ||
{ | ||
Beta1 = beta1; | ||
Beta2 = beta2; | ||
Alpha = alpha; | ||
T_stop1 = t_stop1; | ||
T_stop2 = t_stop2; | ||
} | ||
|
||
void AdamTrainer::TrainNetwork(int checks) | ||
{ | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#ifndef NEURALNETWORKTRY_ADAMTRAINER_H | ||
#define NEURALNETWORKTRY_ADAMTRAINER_H | ||
|
||
#include "network_trainer.h" | ||
|
||
/// <summary> | ||
/// https://habr.com/ru/post/318970/ | ||
/// </summary> | ||
class AdamTrainer : NetworkTrainer | ||
{ | ||
/// <summary> Really small number </summary> | ||
const double eta = 1e-8; | ||
|
||
public: | ||
/// <summary> Training speed coefficient </summary> | ||
double Alpha; | ||
/// <summary> How much prev gradient is used </summary> | ||
double Beta1; | ||
/// <summary> How much more we edit unique weights </summary> | ||
double Beta2; | ||
|
||
/// <summary> How many times we artificially increase m </summary> | ||
int T_stop1; | ||
/// <summary> How many times we artificially increase v </summary> | ||
int T_stop2; | ||
|
||
/// <summary> Gradient average </summary> | ||
vec<vec<vec<double>>> m; | ||
/// <summary> Usage average </summary> | ||
vec<vec<vec<double>>> v; | ||
|
||
AdamTrainer(SimpleNetwork* network, I_SimpleData* train_data, | ||
double beta1 = .9, double beta2 = .999, double alpha = .001, | ||
int t_stop1 = 10, int t_stop2 = 1000); | ||
|
||
double ComputeError(int test_index) override; | ||
double ComputeAverageError() override; | ||
|
||
double TrainNetwork() override; | ||
|
||
void TrainNetwork(int checks) override; | ||
|
||
void TrainNetwork(int checks, std::ostream* stream, int logging_rate = 10) override; | ||
}; | ||
|
||
#endif //NEURALNETWORKTRY_ADAMTRAINER_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#ifndef NEURALNETWORKTRY_NETWORKTRAINER_H | ||
#define NEURALNETWORKTRY_NETWORKTRAINER_H | ||
|
||
#include "math/random_element.h" | ||
#include "neural_core/simple_data.h" | ||
#include "neural_core/simple_network.h" | ||
|
||
class NetworkTrainer : protected RandomElement<> | ||
{ | ||
protected: | ||
SimpleNetwork* Network; | ||
I_SimpleData* TrainData; | ||
|
||
public: | ||
NetworkTrainer(SimpleNetwork* network, I_SimpleData* train_data) | ||
{ | ||
Network = network; | ||
TrainData = train_data; | ||
} | ||
|
||
virtual double ComputeError(int test_index) = 0; | ||
virtual double ComputeAverageError() = 0; | ||
|
||
/// <summary> | ||
/// Do one network check | ||
/// </summary> | ||
virtual double TrainNetwork() = 0; | ||
|
||
/// <summary> | ||
/// Does several network checks. | ||
/// </summary> | ||
/// <param name="checks">Number of network checks</param> | ||
virtual void TrainNetwork(int checks) = 0; | ||
|
||
/// <summary> | ||
/// Does several network checks. | ||
/// Outputs log to the stream. | ||
/// </summary> | ||
/// <param name="checks">Number of network checks</param> | ||
/// <param name="stream">The stream to write the logs</param> | ||
/// <param name="logging_rate">Number of checks per log</param> | ||
virtual void TrainNetwork(int checks, std::ostream* stream, int logging_rate = 10) = 0; | ||
}; | ||
|
||
#endif //NEURALNETWORKTRY_NETWORKTRAINER_H |
Oops, something went wrong.