Softmax and Cross Entropy Loss
fszewczyk committed Nov 8, 2023
1 parent 21c6661 commit 1f49913
Showing 8 changed files with 215 additions and 7 deletions.
50 changes: 50 additions & 0 deletions examples/xor_classification.cpp
@@ -0,0 +1,50 @@
#include "../include/ShkyeraGrad.hpp"

int main() {
    using namespace shkyera;

    // clang-format off
    std::vector<Vec32> xs;
    std::vector<Vec32> ys;

    // ---------- INPUT ----------- | -------- OUTPUT --------- //
    xs.push_back(Vec32::of({0, 0}));    ys.push_back(Vec32::of({1, 0}));
    xs.push_back(Vec32::of({1, 0}));    ys.push_back(Vec32::of({0, 1}));
    xs.push_back(Vec32::of({0, 1}));    ys.push_back(Vec32::of({0, 1}));
    xs.push_back(Vec32::of({1, 1}));    ys.push_back(Vec32::of({1, 0}));

    auto mlp = SequentialBuilder<Type::float32>::begin()
                   .add(Linear32::create(2, 15))
                   .add(ReLU32::create())
                   .add(Dropout32::create(15, 5, 0.2))
                   .add(Tanh32::create())
                   .add(Linear32::create(5, 2))
                   .add(Softmax32::create())
                   .build();
    // clang-format on

    Optimizer32 optimizer = Optimizer<Type::float32>(mlp->parameters(), 0.1);
    Loss::Function32 lossFunction = Loss::CrossEntropy<Type::float32>;

    // ------ TRAINING THE NETWORK ------- //
    for (size_t epoch = 0; epoch < 200; epoch++) {
        auto epochLoss = Val32::create(0);

        optimizer.reset();
        for (size_t sample = 0; sample < xs.size(); ++sample) {
            Vec32 pred = mlp->forward(xs[sample]);
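            // Note: CrossEntropy (see include/nn/Loss.hpp below) calls backward() internally,
            // so gradients accumulate here without an explicit backward() call.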
            auto loss = lossFunction(pred, ys[sample]);

            epochLoss = epochLoss + loss;
        }
        optimizer.step();

        std::cout << "Epoch: " << epoch + 1 << " Loss: " << epochLoss->getValue() / xs.size() << std::endl;
    }

    // ------ VERIFYING THAT IT WORKS ------//
    for (size_t sample = 0; sample < xs.size(); ++sample) {
        Vec32 pred = mlp->forward(xs[sample]);
        std::cout << xs[sample] << " -> " << pred << "\t| True: " << ys[sample] << std::endl;
    }
}
File renamed without changes.
1 change: 1 addition & 0 deletions include/ShkyeraGrad.hpp
@@ -21,6 +21,7 @@
#include "nn/activation/Exp.hpp"
#include "nn/activation/ReLU.hpp"
#include "nn/activation/Sigmoid.hpp"
#include "nn/activation/Softmax.hpp"
#include "nn/activation/Tanh.hpp"

#include "nn/layers/Dropout.hpp"
13 changes: 12 additions & 1 deletion include/core/Value.hpp
@@ -50,6 +50,7 @@ template <typename T> class Value : public std::enable_shared_from_this<Value<T>
    ValuePtr<T> relu();
    ValuePtr<T> sigmoid();
    ValuePtr<T> exp();
    ValuePtr<T> log();
    ValuePtr<T> pow(ValuePtr<T> exponent);

    template <typename U> friend ValuePtr<U> operator+(ValuePtr<U> a, ValuePtr<U> b);
@@ -157,6 +158,16 @@ template <typename T> ValuePtr<T> Value<T>::exp() {
    return result;
}

template <typename T> ValuePtr<T> Value<T>::log() {
    auto thisValue = this->shared_from_this();

    ValuePtr<T> result = Value<T>::create(std::log(_data));
    result->_children = {thisValue};
    result->_backward = [thisValue, result]() { thisValue->_gradient += (1 / thisValue->_data) * result->_gradient; };

    return result;
}

template <typename T> ValuePtr<T> Value<T>::pow(ValuePtr<T> exponent) {
    auto thisValue = this->shared_from_this();

@@ -165,7 +176,7 @@ template <typename T> ValuePtr<T> Value<T>::pow(ValuePtr<T> exponent) {
    result->_backward = [thisValue, exponent, result]() {
        thisValue->_gradient += (exponent->_data * std::pow(thisValue->_data, exponent->_data - 1)) * result->_gradient;
        exponent->_gradient +=
            (std::pow(thisValue->_data, exponent->_data) * log(thisValue->_data)) * result->_gradient;
            (std::pow(thisValue->_data, exponent->_data) * std::log(thisValue->_data)) * result->_gradient;
    };

    return result;
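For reference, a minimal stand-alone sketch (not part of this commit) of what the new Value::log() node does during backpropagation; it assumes the same include path as the example above:

#include "../include/ShkyeraGrad.hpp"

int main() {
    using namespace shkyera;

    // y = ln(x); calling backward() should deposit dy/dx = 1/x = 0.5 into x's gradient,
    // exactly the rule implemented in the _backward lambda above.
    auto x = Val32::create(2.0f);
    auto y = x->log();
    y->backward();

    std::cout << y->getValue() << std::endl; // prints ln(2) ≈ 0.693
}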
73 changes: 70 additions & 3 deletions include/core/Vector.hpp
@@ -25,14 +25,24 @@ template <typename T> class Vector {
  public:
    Vector() = default;
    Vector(std::vector<ValuePtr<T>> values);
    static Vector<T> of(const std::vector<T> &values);

    static Vector<T> of(const std::vector<T> &values);
    ValuePtr<T> dot(const Vector<T> &other) const;
    ValuePtr<T> operator[](size_t index) const;

    ValuePtr<T> sum() const;
    size_t size() const;

    template <typename U> friend std::ostream &operator<<(std::ostream &os, const Vector<U> &vector);

    template <typename U> friend Vector<U> operator/(Vector<U> x, U val);
    template <typename U> friend Vector<U> operator*(Vector<U> x, U val);
    template <typename U> friend Vector<U> operator/(Vector<U> x, ValuePtr<U> val);
    template <typename U> friend Vector<U> operator*(Vector<U> x, ValuePtr<U> val);
    Vector<T> &operator/=(T val);
    Vector<T> &operator*=(T val);
    Vector<T> &operator/=(ValuePtr<T> val);
    Vector<T> &operator*=(ValuePtr<T> val);

    ValuePtr<T> operator[](size_t index) const;
};

template <typename T> Vector<T>::Vector(std::vector<ValuePtr<T>> values) { _values = values; }
@@ -62,6 +72,63 @@ template <typename T> ValuePtr<T> Vector<T>::dot(const Vector<T> &other) const {
    return result;
}

template <typename T> ValuePtr<T> Vector<T>::sum() const {
    auto sum = Value<T>::create(0);
    for (const auto &entry : _values)
        sum = sum + entry;
    return sum;
}

template <typename T> Vector<T> operator/(Vector<T> x, T val) {
    x /= val;
    return x;
}

template <typename T> Vector<T> operator*(Vector<T> x, T val) {
    x *= val;
    return x;
}

template <typename T> Vector<T> operator/(Vector<T> x, ValuePtr<T> val) {
    auto out = x;
    for (size_t i = 0; i < out._values.size(); ++i)
        out._values[i] = out._values[i] / val;
    return out;
}

template <typename T> Vector<T> operator*(Vector<T> x, ValuePtr<T> val) {
    auto out = x;
    for (size_t i = 0; i < out._values.size(); ++i)
        out._values[i] = out._values[i] * val;
    return out;
}

template <typename T> Vector<T> &Vector<T>::operator/=(T val) {
    auto divisor = Value<T>::create(val);
    for (size_t i = 0; i < _values.size(); ++i)
        _values[i] = _values[i] / divisor;
    return *this;
}

template <typename T> Vector<T> &Vector<T>::operator*=(T val) {
    auto divisor = Value<T>::create(val);
    for (size_t i = 0; i < _values.size(); ++i)
        _values[i] = _values[i] * divisor;
    return *this;
}

template <typename T> Vector<T> &Vector<T>::operator/=(ValuePtr<T> val) {
    for (size_t i = 0; i < _values.size(); ++i)
        _values[i] = _values[i] / val;
    return *this;
}

template <typename T> Vector<T> &Vector<T>::operator*=(ValuePtr<T> val) {
    for (size_t i = 0; i < _values.size(); ++i)
        _values[i] = _values[i] * val;
    return *this;
}

template <typename T> ValuePtr<T> Vector<T>::operator[](size_t index) const { return _values[index]; }

template <typename T> std::ostream &operator<<(std::ostream &os, const Vector<T> &vector) {
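A quick sketch (again, not part of the commit) of how the new sum() and the ValuePtr overloads compose; dividing a Vector by its own sum is exactly the pattern the Softmax activation below relies on:

#include "../include/ShkyeraGrad.hpp"

int main() {
    using namespace shkyera;

    Vec32 v = Vec32::of({1.0f, 2.0f, 3.0f});

    auto total = v.sum();          // ValuePtr holding 6, still part of the autograd graph
    Vec32 normalized = v / total;  // element-wise division through the new ValuePtr overload
    Vec32 halved = v / 2.0f;       // the scalar overload routes through operator/=(T)

    std::cout << normalized << std::endl; // expected: roughly 1/6, 1/3, 1/2
    std::cout << halved << std::endl;     // expected: 0.5, 1, 1.5
}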
28 changes: 28 additions & 0 deletions include/nn/Loss.hpp
@@ -57,4 +57,32 @@ Function<T> MAE = [](Vector<T> a, Vector<T> b) {
    return loss;
};

template <typename T>
Function<T> CrossEntropy = [](Vector<T> a, Vector<T> b) {
    if (a.size() != b.size()) {
        throw std::invalid_argument(
            "Vectors need to be of the same size to compute the Cross Entropy loss. Sizes are " +
            std::to_string(a.size()) + " and " + std::to_string(b.size()) + ".");
    }

    auto aSum = a.sum();
    auto bSum = b.sum();

    if (aSum->getValue() < 0.99 || aSum->getValue() > 1.01 || bSum->getValue() < 0.99 || bSum->getValue() > 1.01) {
        throw std::invalid_argument("To compute Cross Entropy Loss, the elements of each vector need to sum to 1 (+/- "
                                    "0.01). Currently, they sum to: " +
                                    std::to_string(aSum->getValue()) + " and " + std::to_string(bSum->getValue()) +
                                    ".");
    }

    auto loss = Value<T>::create(0);
    for (size_t i = 0; i < a.size(); ++i) {
        loss = loss - (b[i] * (a[i]->log()));
    }

    loss->backward();

    return loss;
};

} // namespace shkyera::Loss
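A minimal sketch of calling the new loss directly (not part of the commit): the prediction has to come from something like the Softmax layer so it sums to roughly 1, and the target is a one-hot vector, which sums to 1 by construction.

#include "../include/ShkyeraGrad.hpp"

int main() {
    using namespace shkyera;

    Vec32 prediction = Vec32::of({0.7f, 0.3f});
    Vec32 target = Vec32::of({1.0f, 0.0f});

    // loss = -sum_i target[i] * log(prediction[i]) = -log(0.7) ≈ 0.357;
    // note that the lambda above already calls backward() on the result.
    auto loss = Loss::CrossEntropy<Type::float32>(prediction, target);

    std::cout << loss->getValue() << std::endl;
}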
50 changes: 50 additions & 0 deletions include/nn/activation/Softmax.hpp
@@ -0,0 +1,50 @@
/**
 * Copyright © 2023 Franciszek Szewczyk. None of the rights reserved.
 * This code is released under the Beerware License. If you find this code useful or you appreciate the work, you are
 * encouraged to buy the author a beer in return.
 * Contact the author at [email protected] for inquiries and support.
 */

#pragma once

#include "Activation.hpp"

namespace shkyera {

template <typename T> class Softmax;
using Softmax32 = Softmax<Type::float32>;
using Softmax64 = Softmax<Type::float64>;

template <typename T> class Softmax : public Activation<T> {
  public:
    static std::shared_ptr<Softmax<T>> create();

    virtual Vector<T> operator()(const Vector<T> &x) const override;
};

template <typename T> std::shared_ptr<Softmax<T>> Softmax<T>::create() {
    return std::shared_ptr<Softmax<T>>(new Softmax<T>());
}

template <typename T> Vector<T> Softmax<T>::operator()(const Vector<T> &x) const {
    std::vector<ValuePtr<T>> out;
    out.reserve(x.size());

    auto maxValue = Value<T>::create(x[0]->getValue());
    for (size_t i = 1; i < x.size(); ++i)
        if (x[i] > maxValue)
            maxValue = x[i];

    auto sumExponentiated = Value<T>::create(0);
    for (size_t i = 0; i < x.size(); ++i) {
        auto exponentiated = (x[i] - maxValue)->exp();
        out.emplace_back(exponentiated);
        sumExponentiated = sumExponentiated + exponentiated;
    }

    auto vectorizedOut = Vector<T>(out) / sumExponentiated;

    return vectorizedOut;
}

} // namespace shkyera
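And a corresponding sketch of the activation in isolation (not part of the commit); subtracting the running maximum before exp(), as the implementation above does, keeps the exponentials from overflowing without changing the result:

#include "../include/ShkyeraGrad.hpp"

int main() {
    using namespace shkyera;

    auto softmax = Softmax32::create();

    Vec32 logits = Vec32::of({1.0f, 2.0f, 3.0f});
    Vec32 probs = (*softmax)(logits);

    // Each entry is exp(x_i - max(x)) / sum_j exp(x_j - max(x)),
    // so the output is roughly 0.09, 0.24, 0.67 and sums to 1.
    std::cout << probs << std::endl;
}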
7 changes: 4 additions & 3 deletions include/nn/layers/Dropout.hpp
@@ -45,15 +45,16 @@ template <typename T> DropoutPtr<T> Dropout<T>::create(size_t input, size_t size
template <typename T> Vector<T> Dropout<T>::operator()(const Vector<T> &x) const {
    std::vector<ValuePtr<T>> alteredInput;
    alteredInput.reserve(x.size());
    auto scaling = Value<T>::create(1.0 / (1 - _dropout));
    for (size_t i = 0; i < x.size(); ++i)
        alteredInput.push_back(x[i] * scaling);
        alteredInput.push_back(x[i]);

    std::vector<size_t> indicesToRemove = utils::sample<size_t>(0, x.size() - 1, _dropout * x.size(), false);
    for (size_t idxToRemove : indicesToRemove)
        alteredInput[idxToRemove] = Value<T>::create(0);

    return Linear<T>::operator()(Vector<T>(alteredInput));
    auto transformedInput = Vector<T>(alteredInput) * static_cast<T>(1.0 / (1 - _dropout));

    return Linear<T>::operator()(transformedInput);
}

} // namespace shkyera
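The behaviour is still inverted dropout: a fraction p of the input entries is zeroed (sampled without replacement), and the survivors are scaled by 1 / (1 - p) so the expected input to the Linear pass is unchanged; the only difference is that the scaling is now applied once to the whole vector through the new Vector-times-scalar operator. A rough sketch (not part of the commit, sizes chosen arbitrarily):

#include "../include/ShkyeraGrad.hpp"

int main() {
    using namespace shkyera;

    // 5 inputs, 3 outputs, dropout probability 0.2: roughly one input is zeroed
    // and the rest are scaled by 1 / (1 - 0.2) = 1.25 before the linear transform.
    auto dropout = Dropout32::create(5, 3, 0.2);

    Vec32 x = Vec32::of({1, 1, 1, 1, 1});
    std::cout << (*dropout)(x) << std::endl; // output differs run to run: dropped indices are random
}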
