diff --git a/examples/xor_classification.cpp b/examples/xor_classification.cpp
index 3e23cfd..9a0e238 100644
--- a/examples/xor_classification.cpp
+++ b/examples/xor_classification.cpp
@@ -25,7 +25,7 @@ int main() {
         .build();
     // clang-format on
 
-    SGD32 optimizer = SGD(mlp->parameters(), 0.1);
+    Adam32 optimizer = Adam(mlp->parameters(), 0.1);
     Loss::Function32 lossFunction = Loss::CrossEntropy;
 
     // ------ TRAINING THE NETWORK ------- //
diff --git a/include/nn/optimizers/Adam.hpp b/include/nn/optimizers/Adam.hpp
index c79986c..224e169 100644
--- a/include/nn/optimizers/Adam.hpp
+++ b/include/nn/optimizers/Adam.hpp
@@ -28,11 +28,8 @@ template <typename T> class Adam : public Optimizer<T> {
     T _eps;
     size_t _timestep;
 
-    std::unordered_map<Value<T> *, T> _firstMoment;
-    std::unordered_map<Value<T> *, T> _secondMoment;
-
-    T getFirstMoment(const ValuePtr<T> &v);
-    T getSecondMoment(const ValuePtr<T> &v);
+    std::vector<T> _firstMoments;
+    std::vector<T> _secondMoments;
 
   public:
     Adam(std::vector<ValuePtr<T>> params, T learningRate, T b1 = 0.9, T b2 = 0.999, T eps = 1e-8);
@@ -42,23 +39,28 @@
 template <typename T>
 Adam<T>::Adam(std::vector<ValuePtr<T>> params, T learningRate, T b1, T b2, T eps) : Optimizer<T>(params, learningRate) {
-    _timestep = 0;
     _b1 = b1;
     _b2 = b2;
     _eps = eps;
+
+    _timestep = 0;
+    _firstMoments.resize(params.size(), 0);
+    _secondMoments.resize(params.size(), 0);
 }
 
 template <typename T> void Adam<T>::step() {
     _timestep++;
-    for (const ValuePtr<T> &param : this->_parameters) {
+    for (size_t i = 0; i < this->_parameters.size(); ++i) {
+        const ValuePtr<T> &param = this->_parameters[i];
+
         T gradient = param->getGradient();
 
-        T firstMoment = _b1 * getFirstMoment(param) + (1 - _b1) * gradient;
-        T secondMoment = _b2 * getSecondMoment(param) + (1 - _b2) * gradient * gradient;
+        T firstMoment = _b1 * _firstMoments[i] + (1 - _b1) * gradient;
+        T secondMoment = _b2 * _secondMoments[i] + (1 - _b2) * gradient * gradient;
 
-        _firstMoment.insert({param.get(), firstMoment});
-        _secondMoment.insert({param.get(), secondMoment});
+        _firstMoments[i] = firstMoment;
+        _secondMoments[i] = secondMoment;
 
         T firstMomentHat = firstMoment / (1 - pow(_b1, _timestep));
         T secondMomentHat = secondMoment / (1 - pow(_b2, _timestep));
@@ -67,22 +69,4 @@ template <typename T> void Adam<T>::step() {
     }
 }
 
-template <typename T> T Adam<T>::getFirstMoment(const ValuePtr<T> &v) {
-    auto moment = _firstMoment.find(v.get());
-    if (moment == _firstMoment.end()) {
-        _firstMoment.insert({v.get(), 0});
-        return 0;
-    }
-    return moment->second;
-}
-
-template <typename T> T Adam<T>::getSecondMoment(const ValuePtr<T> &v) {
-    auto moment = _secondMoment.find(v.get());
-    if (moment == _secondMoment.end()) {
-        _secondMoment.insert({v.get(), 0});
-        return 0;
-    }
-    return moment->second;
-}
-
 } // namespace shkyera
diff --git a/include/nn/optimizers/SGD.hpp b/include/nn/optimizers/SGD.hpp
index 016852f..69ed23a 100644
--- a/include/nn/optimizers/SGD.hpp
+++ b/include/nn/optimizers/SGD.hpp
@@ -24,9 +24,7 @@ using SGD64 = SGD<double>;
 template <typename T> class SGD : public Optimizer<T> {
   private:
     T _momentum;
-    std::unordered_map<Value<T> *, T> _moment;
-
-    T getMoment(const ValuePtr<T> &v);
+    std::vector<T> _moments;
 
   public:
     SGD(std::vector<ValuePtr<T>> params, T learningRate, T momentum = 0.9);
@@ -37,27 +35,21 @@
 template <typename T>
 SGD<T>::SGD(std::vector<ValuePtr<T>> params, T learningRate, T momentum) : Optimizer<T>(params, learningRate) {
     _momentum = momentum;
+    _moments.resize(params.size(), 0);
 }
 
 template <typename T> void SGD<T>::step() {
     static bool initialized = false;
 
-    for (const ValuePtr<T> &param : this->_parameters) {
+    for (size_t i = 0; i < this->_parameters.size(); ++i) {
+        const ValuePtr<T> &param = this->_parameters[i];
+
         T gradient = param->getGradient();
 
-        T moment = initialized ? _momentum * getMoment(param) + (1 - _momentum) * gradient : gradient;
-        _moment.insert({param.get(), moment});
+        T moment = initialized ? _momentum * _moments[i] + (1 - _momentum) * gradient : gradient;
+        _moments[i] = moment;
 
         param->_data -= this->_learningRate * moment;
     }
 }
 
-template <typename T> T SGD<T>::getMoment(const ValuePtr<T> &v) {
-    auto moment = _moment.find(v.get());
-    if (moment == _moment.end()) {
-        _moment.insert({v.get(), 0});
-        return 0;
-    }
-    return moment->second;
-}
-
 } // namespace shkyera
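
For reference, below is a minimal, self-contained sketch of the bookkeeping pattern this diff adopts: optimizer state kept in std::vector<T> buffers indexed by parameter position, instead of std::unordered_map keyed by raw Value<T>* pointers. It is an illustration only, not the shkyera API: ToyAdam and all of its names are hypothetical, plain doubles stand in for Value<T>/ValuePtr<T>, and since the final parameter update falls outside the hunks shown above, the sketch assumes the standard Adam rule p -= lr * mHat / (sqrt(vHat) + eps).

    // Hypothetical standalone analogue of the refactored Adam bookkeeping.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    struct ToyAdam {
        double lr, b1 = 0.9, b2 = 0.999, eps = 1e-8;
        size_t timestep = 0;
        std::vector<double> firstMoments, secondMoments;

        ToyAdam(size_t numParams, double learningRate)
            : lr(learningRate), firstMoments(numParams, 0), secondMoments(numParams, 0) {}

        // One Adam step: the state for parameter i lives at index i, so the
        // update needs no hashing or map lookup, mirroring the diff above.
        void step(std::vector<double> &params, const std::vector<double> &grads) {
            timestep++;
            for (size_t i = 0; i < params.size(); ++i) {
                double g = grads[i];
                firstMoments[i] = b1 * firstMoments[i] + (1 - b1) * g;
                secondMoments[i] = b2 * secondMoments[i] + (1 - b2) * g * g;

                // Bias correction, as in Adam<T>::step() above.
                double mHat = firstMoments[i] / (1 - std::pow(b1, timestep));
                double vHat = secondMoments[i] / (1 - std::pow(b2, timestep));

                params[i] -= lr * mHat / (std::sqrt(vHat) + eps);
            }
        }
    };

    int main() {
        std::vector<double> params = {1.0, -2.0};
        ToyAdam opt(params.size(), 0.1);
        for (int iter = 0; iter < 200; ++iter) {
            // Gradient of f(p) = p0^2 + p1^2; both parameters should approach 0.
            std::vector<double> grads = {2 * params[0], 2 * params[1]};
            opt.step(params, grads);
        }
        std::printf("%f %f\n", params[0], params[1]);
        return 0;
    }

Two things make the vector form work here: the optimizer's parameter list is fixed after construction (which the resize calls in the constructors already rely on), so index i is a stable key, and the per-parameter cost drops from a hash lookup to an array access. It also sidesteps a pitfall of the removed code: std::unordered_map::insert does not overwrite an existing key, so once getFirstMoment/getSecondMoment had default-inserted 0 for a parameter, the insert calls in step() never updated the stored moments on later steps.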