diff --git a/examples/xor_classification.cpp b/examples/xor_classification.cpp
index 30ff5fa..bb20385 100644
--- a/examples/xor_classification.cpp
+++ b/examples/xor_classification.cpp
@@ -16,14 +16,16 @@ int main() {
     auto mlp = SequentialBuilder<Type::float32>::begin()
                    .add(Linear32::create(2, 15))
                    .add(ReLU32::create())
-                   .add(Dropout32::create(15, 5, 0.2))
+                   .add(Linear32::create(15, 5))
+                   .add(Sigmoid32::create())
+                   .add(Linear32::create(5, 5))
                    .add(Tanh32::create())
                    .add(Linear32::create(5, 2))
                    .add(Softmax32::create())
                    .build();
     // clang-format on
 
-    Optimizer32 optimizer = Optimizer<Type::float32>(mlp->parameters(), 0.1);
+    Adam32 optimizer = Adam<Type::float32>(mlp->parameters(), 0.01);
     Loss::Function32 lossFunction = Loss::CrossEntropy<Type::float32>;
 
     // ------ TRAINING THE NETWORK ------- //
diff --git a/include/ShkyeraGrad.hpp b/include/ShkyeraGrad.hpp
index 275e810..0d8483b 100644
--- a/include/ShkyeraGrad.hpp
+++ b/include/ShkyeraGrad.hpp
@@ -14,9 +14,11 @@
 #include "nn/Loss.hpp"
 #include "nn/Module.hpp"
 #include "nn/Neuron.hpp"
-#include "nn/Optimizer.hpp"
 #include "nn/Sequential.hpp"
 
+#include "nn/optimizers/Adam.hpp"
+#include "nn/optimizers/Optimizer.hpp"
+
 #include "nn/activation/Activation.hpp"
 #include "nn/activation/Exp.hpp"
 #include "nn/activation/ReLU.hpp"
diff --git a/include/core/Value.hpp b/include/core/Value.hpp
index c7f9007..c534419 100644
--- a/include/core/Value.hpp
+++ b/include/core/Value.hpp
@@ -19,6 +19,8 @@ namespace shkyera {
 
 template <typename T> class Optimizer;
+template <typename T> class Adam;
+
 template <typename T> class Value;
 template <typename T> using ValuePtr = std::shared_ptr<Value<T>>;
 
@@ -39,6 +41,7 @@ template <typename T> class Value : public std::enable_shared_from_this<Value<T>> {
   public:
     friend class Optimizer<T>;
+    friend class Adam<T>;
 
     static ValuePtr<T> create(T data);
 
diff --git a/include/nn/optimizers/Adam.hpp b/include/nn/optimizers/Adam.hpp
new file mode 100644
index 0000000..c79986c
--- /dev/null
+++ b/include/nn/optimizers/Adam.hpp
@@ -0,0 +1,88 @@
+/**
+ * Copyright © 2023 Franciszek Szewczyk. None of the rights reserved.
+ * This code is released under the Beerware License. If you find this code useful or you appreciate the work, you are
+ * encouraged to buy the author a beer in return.
+ * Contact the author at szewczyk.franciszek02@gmail.com for inquiries and support.
+ */
+
+#pragma once
+
+#include <cmath>
+#include <unordered_map>
+
+#include "../../core/Type.hpp"
+#include "../../core/Value.hpp"
+#include "../Module.hpp"
+#include "Optimizer.hpp"
+
+namespace shkyera {
+
+template <typename T> class Adam;
+using Adam32 = Adam<Type::float32>;
+using Adam64 = Adam<Type::float64>;
+
+template <typename T> class Adam : public Optimizer<T> {
+  private:
+    T _b1;
+    T _b2;
+    T _eps;
+    size_t _timestep;
+
+    std::unordered_map<Value<T> *, T> _firstMoment;
+    std::unordered_map<Value<T> *, T> _secondMoment;
+
+    T getFirstMoment(const ValuePtr<T> &v);
+    T getSecondMoment(const ValuePtr<T> &v);
+
+  public:
+    Adam(std::vector<ValuePtr<T>> params, T learningRate, T b1 = 0.9, T b2 = 0.999, T eps = 1e-8);
+
+    void step() override;
+};
+
+template <typename T>
+Adam<T>::Adam(std::vector<ValuePtr<T>> params, T learningRate, T b1, T b2, T eps) : Optimizer<T>(params, learningRate) {
+    _timestep = 0;
+    _b1 = b1;
+    _b2 = b2;
+    _eps = eps;
+}
+
+template <typename T> void Adam<T>::step() {
+    _timestep++;
+
+    for (const ValuePtr<T> &param : this->_parameters) {
+        T gradient = param->getGradient();
+
+        T firstMoment = _b1 * getFirstMoment(param) + (1 - _b1) * gradient;
+        T secondMoment = _b2 * getSecondMoment(param) + (1 - _b2) * gradient * gradient;
+
+        _firstMoment[param.get()] = firstMoment;
+        _secondMoment[param.get()] = secondMoment;
+
+        T firstMomentHat = firstMoment / (1 - std::pow(_b1, _timestep));
+        T secondMomentHat = secondMoment / (1 - std::pow(_b2, _timestep));
+
+        param->_data -= (this->_learningRate * firstMomentHat) / (std::sqrt(secondMomentHat) + _eps);
+    }
+}
+
+template <typename T> T Adam<T>::getFirstMoment(const ValuePtr<T> &v) {
+    auto moment = _firstMoment.find(v.get());
+    if (moment == _firstMoment.end()) {
+        _firstMoment.insert({v.get(), 0});
+        return 0;
+    }
+    return moment->second;
+}
+
+template <typename T> T Adam<T>::getSecondMoment(const ValuePtr<T> &v) {
+    auto moment = _secondMoment.find(v.get());
+    if (moment == _secondMoment.end()) {
+        _secondMoment.insert({v.get(), 0});
+        return 0;
+    }
+    return moment->second;
+}
+
+} // namespace shkyera
diff --git a/include/nn/Optimizer.hpp b/include/nn/optimizers/Optimizer.hpp
similarity index 88%
rename from include/nn/Optimizer.hpp
rename to include/nn/optimizers/Optimizer.hpp
index 46eca35..83cb994 100644
--- a/include/nn/Optimizer.hpp
+++ b/include/nn/optimizers/Optimizer.hpp
@@ -9,9 +9,9 @@
 
 #include <vector>
 
-#include "../core/Type.hpp"
-#include "../core/Value.hpp"
-#include "Module.hpp"
+#include "../../core/Type.hpp"
+#include "../../core/Value.hpp"
+#include "../Module.hpp"
 
 namespace shkyera {
@@ -19,15 +19,15 @@
 using Optimizer32 = Optimizer<Type::float32>;
 using Optimizer64 = Optimizer<Type::float64>;
 
 template <typename T> class Optimizer {
-  private:
+  protected:
     std::vector<ValuePtr<T>> _parameters;
     T _learningRate;
 
   public:
     Optimizer(std::vector<ValuePtr<T>> params, T learningRate);
 
-    void reset();
-    void step();
+    virtual void reset();
+    virtual void step();
 };
 
 template