AdaMax implementation
fszewczyk committed Nov 9, 2023
1 parent e0373a0 commit 9d9dc64
Showing 3 changed files with 72 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/ShkyeraGrad.hpp
@@ -17,6 +17,7 @@
#include "nn/Neuron.hpp"
#include "nn/Sequential.hpp"

#include "nn/optimizers/AdaMax.hpp"
#include "nn/optimizers/Adam.hpp"
#include "nn/optimizers/NAG.hpp"
#include "nn/optimizers/Optimizer.hpp"
2 changes: 2 additions & 0 deletions include/core/Value.hpp
@@ -22,6 +22,7 @@ namespace shkyera {

template <typename T> class Optimizer;
template <typename T> class Adam;
template <typename T> class AdaMax;
template <typename T> class SGD;
template <typename T> class NAG;

@@ -47,6 +48,7 @@ template <typename T> class Value : public std::enable_shared_from_this<Value<T>
public:
friend class Optimizer<T>;
friend class Adam<T>;
friend class AdaMax<T>;
friend class SGD<T>;
friend class NAG<T>;

69 changes: 69 additions & 0 deletions include/nn/optimizers/AdaMax.hpp
@@ -0,0 +1,69 @@
/**
* Copyright © 2023 Franciszek Szewczyk. None of the rights reserved.
* This code is released under the Beerware License. If you find this code useful or you appreciate the work, you are
* encouraged to buy the author a beer in return.
* Contact the author at [email protected] for inquiries and support.
*/

#pragma once

#include <algorithm>
#include <cmath>
#include <vector>

#include "../../core/Type.hpp"
#include "../../core/Value.hpp"
#include "../Module.hpp"
#include "Optimizer.hpp"

namespace shkyera {

template <typename T> class AdaMax;
using AdaMax32 = AdaMax<Type::float32>;
using AdaMax64 = AdaMax<Type::float64>;

template <typename T> class AdaMax : public Optimizer<T> {
  private:
    size_t _timestep;              // number of optimization steps taken so far
    T _b1;                         // exponential decay rate for the first moment estimate
    T _b2;                         // exponential decay rate for the infinity norm
    T _eps;                        // small constant guarding against division by zero

    std::vector<T> _moments;       // per-parameter first moment estimates
    std::vector<T> _infinityNorms; // per-parameter exponentially weighted infinity norms

  public:
    AdaMax(std::vector<ValuePtr<T>> params, T learningRate, T b1 = 0.9, T b2 = 0.999, T eps = 1e-8);

    void step() override;
};

template <typename T>
AdaMax<T>::AdaMax(std::vector<ValuePtr<T>> params, T learningRate, T b1, T b2, T eps)
    : Optimizer<T>(params, learningRate), _timestep(0), _b1(b1), _b2(b2), _eps(eps) {
    _moments.resize(params.size(), 0);
    _infinityNorms.resize(params.size(), 0);
}

template <typename T> void AdaMax<T>::step() {
    ++_timestep;

    for (size_t i = 0; i < this->_parameters.size(); ++i) {
        const ValuePtr<T> &param = this->_parameters[i];

        T gradient = param->getGradient();

        // Exponentially decaying average of past gradients (first moment).
        T moment = _b1 * _moments[i] + (1 - _b1) * gradient;

        // Exponentially weighted infinity norm; _eps keeps the divisor nonzero.
        T infinityNorm = std::max(_b2 * _infinityNorms[i], std::abs(gradient) + _eps);

        // Only the first moment needs bias correction; the max-based infinity
        // norm is not biased towards zero.
        param->_data -= (this->_learningRate / (1 - std::pow(_b1, _timestep))) * (moment / infinityNorm);

        _infinityNorms[i] = infinityNorm;
        _moments[i] = moment;
    }
}

} // namespace shkyera
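For reference, step() implements the AdaMax update rule from Kingma & Ba (2015), the paper that introduced Adam; the only deviation here is that epsilon is folded into the norm update rather than added to the divisor:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
u_t = \max(\beta_2 u_{t-1}, |g_t| + \epsilon)
\theta_t = \theta_{t-1} - \frac{\alpha}{1 - \beta_1^t} \cdot \frac{m_t}{u_t}

where g_t is the gradient of the parameter, \alpha the learning rate, and \beta_1, \beta_2 the decay rates.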

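A minimal usage sketch, not part of this commit: the parameter-gathering step is hypothetical, since the diff only confirms the constructor and step() signatures.

#include <vector>

#include "ShkyeraGrad.hpp"

using namespace shkyera;

int main() {
    // Hypothetical: gather trainable parameters, e.g. from a Sequential model.
    std::vector<ValuePtr<Type::float32>> params; // e.g. model->parameters()

    // 0.002 is the learning rate suggested for AdaMax in the original paper.
    AdaMax32 optimizer(params, 0.002f);

    // After backpropagation has filled in the gradients, one call to step()
    // applies the AdaMax update to every parameter.
    optimizer.step();
}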