-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from fszewczyk/cross-entropy-loss
Softmax and Cross Entropy Loss
- Loading branch information
Showing
12 changed files
with
235 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#include "../include/ShkyeraGrad.hpp" | ||
|
||
int main() { | ||
using namespace shkyera; | ||
|
||
// clang-format off | ||
std::vector<Vec32> xs; | ||
std::vector<Vec32> ys; | ||
|
||
// ---------- INPUT ----------- | -------- OUTPUT --------- // | ||
xs.push_back(Vec32::of({0, 0})); ys.push_back(Vec32::of({1, 0})); | ||
xs.push_back(Vec32::of({1, 0})); ys.push_back(Vec32::of({0, 1})); | ||
xs.push_back(Vec32::of({0, 1})); ys.push_back(Vec32::of({0, 1})); | ||
xs.push_back(Vec32::of({1, 1})); ys.push_back(Vec32::of({1, 0})); | ||
|
||
auto mlp = SequentialBuilder<Type::float32>::begin() | ||
.add(Linear32::create(2, 15)) | ||
.add(ReLU32::create()) | ||
.add(Dropout32::create(15, 5, 0.2)) | ||
.add(Tanh32::create()) | ||
.add(Linear32::create(5, 2)) | ||
.add(Softmax32::create()) | ||
.build(); | ||
// clang-format on | ||
|
||
Optimizer32 optimizer = Optimizer<Type::float32>(mlp->parameters(), 0.1); | ||
Loss::Function32 lossFunction = Loss::CrossEntropy<Type::float32>; | ||
|
||
// ------ TRAINING THE NETWORK ------- // | ||
for (size_t epoch = 0; epoch < 200; epoch++) { | ||
auto epochLoss = Val32::create(0); | ||
|
||
optimizer.reset(); | ||
for (size_t sample = 0; sample < xs.size(); ++sample) { | ||
Vec32 pred = mlp->forward(xs[sample]); | ||
auto loss = lossFunction(pred, ys[sample]); | ||
|
||
epochLoss = epochLoss + loss; | ||
} | ||
optimizer.step(); | ||
|
||
std::cout << "Epoch: " << epoch + 1 << " Loss: " << epochLoss->getValue() / xs.size() << std::endl; | ||
} | ||
|
||
// ------ VERIFYING THAT IT WORKS ------// | ||
for (size_t sample = 0; sample < xs.size(); ++sample) { | ||
Vec32 pred = mlp->forward(xs[sample]); | ||
std::cout << xs[sample] << " -> " << pred << "\t| True: " << ys[sample] << std::endl; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/** | ||
* Copyright © 2023 Franciszek Szewczyk. None of the rights reserved. | ||
* This code is released under the Beerware License. If you find this code useful or you appreciate the work, you are | ||
* encouraged to buy the author a beer in return. | ||
* Contact the author at [email protected] for inquiries and support. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "Activation.hpp" | ||
|
||
namespace shkyera { | ||
|
||
template <typename T> class Softmax; | ||
using Softmax32 = Softmax<Type::float32>; | ||
using Softmax64 = Softmax<Type::float64>; | ||
|
||
template <typename T> class Softmax : public Activation<T> { | ||
public: | ||
static std::shared_ptr<Softmax<T>> create(); | ||
|
||
virtual Vector<T> operator()(const Vector<T> &x) const override; | ||
}; | ||
|
||
template <typename T> std::shared_ptr<Softmax<T>> Softmax<T>::create() { | ||
return std::shared_ptr<Softmax<T>>(new Softmax<T>()); | ||
} | ||
|
||
template <typename T> Vector<T> Softmax<T>::operator()(const Vector<T> &x) const { | ||
std::vector<ValuePtr<T>> out; | ||
out.reserve(x.size()); | ||
|
||
auto maxValue = Value<T>::create(x[0]->getValue()); | ||
for (size_t i = 1; i < x.size(); ++i) | ||
if (x[i] > maxValue) | ||
maxValue = x[i]; | ||
|
||
auto sumExponentiated = Value<T>::create(0); | ||
for (size_t i = 0; i < x.size(); ++i) { | ||
auto exponentiated = (x[i] - maxValue)->exp(); | ||
out.emplace_back(exponentiated); | ||
sumExponentiated = sumExponentiated + exponentiated; | ||
} | ||
|
||
auto vectorizedOut = Vector<T>(out) / sumExponentiated; | ||
|
||
return vectorizedOut; | ||
} | ||
|
||
} // namespace shkyera |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters