Adam implementation
fszewczyk committed Nov 8, 2023
1 parent e30b15e commit 8b9bf5f
Showing 5 changed files with 101 additions and 9 deletions.
6 changes: 4 additions & 2 deletions examples/xor_classification.cpp
@@ -16,14 +16,16 @@ int main() {
auto mlp = SequentialBuilder<Type::float32>::begin()
.add(Linear32::create(2, 15))
.add(ReLU32::create())
- .add(Linear32::create(15, 5))
+ .add(Dropout32::create(15, 5, 0.2))
.add(Sigmoid32::create())
.add(Linear32::create(5, 5))
.add(Tanh32::create())
.add(Linear32::create(5, 2))
.add(Softmax32::create())
.build();
// clang-format on

- Optimizer32 optimizer = Optimizer<Type::float32>(mlp->parameters(), 0.1);
+ Adam32 optimizer = Adam<Type::float32>(mlp->parameters(), 0.05);
Loss::Function32 lossFunction = Loss::CrossEntropy<Type::float32>;

// ------ TRAINING THE NETWORK ------- //
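For context, a minimal sketch of how the new optimizer plugs into the example's training loop. This is an assumed usage, not part of this diff: the containers xs/ys, the epoch count, and the exact forward/backward calls are placeholders based on the library's Value/Module API.

// Hedged sketch (assumed API): trains the network defined above with Adam.
for (size_t epoch = 0; epoch < 100; ++epoch) {
    for (size_t i = 0; i < xs.size(); ++i) {
        optimizer.reset();                           // clear gradients from the previous step
        auto prediction = mlp->forward(xs[i]);       // forward pass
        auto loss = lossFunction(prediction, ys[i]); // cross-entropy between prediction and target
        loss->backward();                            // backpropagate into mlp->parameters()
        optimizer.step();                            // Adam parameter update
    }
}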
4 changes: 3 additions & 1 deletion include/ShkyeraGrad.hpp
@@ -14,9 +14,11 @@
#include "nn/Loss.hpp"
#include "nn/Module.hpp"
#include "nn/Neuron.hpp"
- #include "nn/Optimizer.hpp"
#include "nn/Sequential.hpp"

+ #include "nn/optimizers/Adam.hpp"
+ #include "nn/optimizers/Optimizer.hpp"

#include "nn/activation/Activation.hpp"
#include "nn/activation/Exp.hpp"
#include "nn/activation/ReLU.hpp"
3 changes: 3 additions & 0 deletions include/core/Value.hpp
@@ -19,6 +19,8 @@
namespace shkyera {

template <typename T> class Optimizer;
+ template <typename T> class Adam;

template <typename T> class Value;
template <typename T> using ValuePtr = std::shared_ptr<Value<T>>;

@@ -39,6 +41,7 @@ template <typename T> class Value : public std::enable_shared_from_this<Value<T>

public:
friend class Optimizer<T>;
+ friend class Adam<T>;

static ValuePtr<T> create(T data);

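The added friend declaration is what makes the new optimizer work: _data is a private member of Value<T>, and Adam<T>::step() (see Adam.hpp below) updates parameters by writing it directly, as in param->_data -= update, rather than going through the public Value API.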
85 changes: 85 additions & 0 deletions include/nn/optimizers/Adam.hpp
@@ -0,0 +1,85 @@
/**
* Copyright © 2023 Franciszek Szewczyk. None of the rights reserved.
* This code is released under the Beerware License. If you find this code useful or you appreciate the work, you are
* encouraged to buy the author a beer in return.
* Contact the author at [email protected] for inquiries and support.
*/

#pragma once

#include <cmath>
#include <unordered_map>
#include <vector>

#include "../../core/Type.hpp"
#include "../../core/Value.hpp"
#include "../Module.hpp"
#include "Optimizer.hpp"

namespace shkyera {

template <typename T> class Adam;
using Adam32 = Adam<Type::float32>;
using Adam64 = Adam<Type::float64>;

template <typename T> class Adam : public Optimizer<T> {
private:
T _b1;
T _b2;
T _eps;
size_t _timestep;

std::unordered_map<Value<T> *, T> _firstMoment;
std::unordered_map<Value<T> *, T> _secondMoment;

T getFirstMoment(const ValuePtr<T> &v);
T getSecondMoment(const ValuePtr<T> &v);

public:
Adam(std::vector<ValuePtr<T>> params, T learningRate, T b1 = 0.9, T b2 = 0.999, T eps = 1e-8);

void step() override;
};

template <typename T>
Adam<T>::Adam(std::vector<ValuePtr<T>> params, T learningRate, T b1, T b2, T eps) : Optimizer<T>(params, learningRate) {
_timestep = 0;
_b1 = b1;
_b2 = b2;
_eps = eps;
}

template <typename T> void Adam<T>::step() {
_timestep++;

for (const ValuePtr<T> &param : this->_parameters) {
T gradient = param->getGradient();

T firstMoment = _b1 * getFirstMoment(param) + (1 - _b1) * gradient;
T secondMoment = _b2 * getSecondMoment(param) + (1 - _b2) * gradient * gradient;

// Persist the updated moments; otherwise every step would restart from zero.
_firstMoment[param.get()] = firstMoment;
_secondMoment[param.get()] = secondMoment;

// Bias correction compensates for the zero-initialized moment estimates.
T firstMomentHat = firstMoment / (1 - std::pow(_b1, _timestep));
T secondMomentHat = secondMoment / (1 - std::pow(_b2, _timestep));

param->_data -= (this->_learningRate * firstMomentHat) / (std::sqrt(secondMomentHat) + _eps);
}
}

// Returns the stored first-moment estimate, zero-initializing it on first access.
template <typename T> T Adam<T>::getFirstMoment(const ValuePtr<T> &v) {
auto moment = _firstMoment.find(v.get());
if (moment == _firstMoment.end()) {
_firstMoment.insert({v.get(), 0});
return 0;
}
return moment->second;
}

// Returns the stored second-moment estimate, zero-initializing it on first access.
template <typename T> T Adam<T>::getSecondMoment(const ValuePtr<T> &v) {
auto moment = _secondMoment.find(v.get());
if (moment == _secondMoment.end()) {
_secondMoment.insert({v.get(), 0});
return 0;
}
return moment->second;
}

} // namespace shkyera
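For reference, step() above implements the standard Adam update (Kingma & Ba, 2015). With gradient g_t = param->getGradient(), the members _b1, _b2, _eps, _learningRate and _timestep play the roles of \beta_1, \beta_2, \epsilon, \alpha and t:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
\hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t)
\theta_t = \theta_{t-1} - \alpha \, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)

The moments m and v are stored per parameter in the two unordered_maps, keyed by the raw Value<T>* pointer; getFirstMoment and getSecondMoment zero-initialize an entry on first access, which is exactly the start the (1 - \beta^t) bias-correction terms compensate for.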
12 changes: 6 additions & 6 deletions include/nn/Optimizer.hpp → include/nn/optimizers/Optimizer.hpp
@@ -9,25 +9,25 @@

#include <vector>

- #include "../core/Type.hpp"
- #include "../core/Value.hpp"
- #include "Module.hpp"
+ #include "../../core/Type.hpp"
+ #include "../../core/Value.hpp"
+ #include "../Module.hpp"

namespace shkyera {

using Optimizer32 = Optimizer<Type::float32>;
using Optimizer64 = Optimizer<Type::float64>;

template <typename T> class Optimizer {
- private:
+ protected:
std::vector<ValuePtr<T>> _parameters;
T _learningRate;

public:
Optimizer(std::vector<ValuePtr<T>> params, T learningRate);

- void reset();
- void step();
+ virtual void reset();
+ virtual void step();
};

template <typename T>
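With reset() and step() now virtual and the members protected, a concrete optimizer such as Adam can be driven through the base interface. A minimal sketch, illustrative rather than from this commit; because Optimizer<T> declares no virtual destructor, the sketch borrows the optimizer through a reference instead of owning it through a base-class pointer:

Adam32 adam = Adam<Type::float32>(mlp->parameters(), 0.05);
Optimizer<Type::float32> &optimizer = adam;
optimizer.reset(); // inherited Optimizer<T>::reset()
optimizer.step();  // virtual dispatch to Adam<T>::step()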
