From c1a8e073b3f8ee8482f676e8e95c274e90aae397 Mon Sep 17 00:00:00 2001
From: "szewczyk.franciszek02" <szewczyk.franciszek02@gmail.com>
Date: Thu, 9 Nov 2023 16:41:29 +0100
Subject: [PATCH 1/5] Simple mini-batching

---
 examples/xor_classification.cpp | 12 ++++--------
 examples/xor_regression.cpp     | 12 ++++--------
 include/core/Vector.hpp         |  1 +
 include/nn/Loss.hpp             | 17 ++++++++++++-----
 include/nn/Module.hpp           | 11 ++++++++++-
 5 files changed, 31 insertions(+), 22 deletions(-)

diff --git a/examples/xor_classification.cpp b/examples/xor_classification.cpp
index 49abd45..53053ce 100644
--- a/examples/xor_classification.cpp
+++ b/examples/xor_classification.cpp
@@ -30,18 +30,14 @@ int main() {
 
     // ------ TRAINING THE NETWORK ------- //
     for (size_t epoch = 0; epoch < 100; epoch++) {
-        auto epochLoss = Val32::create(0);
-
         optimizer.reset();
-        for (size_t sample = 0; sample < xs.size(); ++sample) {
-            Vec32 pred = mlp->forward(xs[sample]);
-            auto loss = lossFunction(pred, ys[sample]);
-            epochLoss = epochLoss + loss;
-        }
+        std::vector<Vec32> pred = mlp->forward(xs);
+        auto loss = Loss::compute32(lossFunction, pred, ys);
+
         optimizer.step();
 
-        std::cout << "Epoch: " << epoch + 1 << " Loss: " << epochLoss->getValue() / xs.size() << std::endl;
+        std::cout << "Epoch: " << epoch + 1 << " Loss: " << loss->getValue() / xs.size() << std::endl;
     }
 
     // ------ VERIFYING THAT IT WORKS ------//
diff --git a/examples/xor_regression.cpp b/examples/xor_regression.cpp
index 61d6630..cdaf186 100644
--- a/examples/xor_regression.cpp
+++ b/examples/xor_regression.cpp
@@ -28,18 +28,14 @@ int main() {
     auto lossFunction = Loss::MSE<T>;
 
     for (size_t epoch = 0; epoch < 100; epoch++) { // We train for 100 epochs
-        auto epochLoss = Val32::create(0);
+        optimizer.reset(); // Reset the gradients
 
-        optimizer.reset(); // Reset the gradients
-        for (size_t sample = 0; sample < xs.size(); ++sample) { // We go through each sample
-            Vec32 pred = network->forward(xs[sample]); // We get some prediction
-            auto loss = lossFunction(pred, ys[sample]); // And calculate its error
+        auto pred = network->forward(xs); // We get some prediction
+        auto loss = Loss::compute(lossFunction, pred, ys); // And calculate its error
 
-            epochLoss = epochLoss + loss; // Store the loss for feedback
-        }
         optimizer.step(); // Update the parameters
 
-        auto averageLoss = epochLoss / Val32::create(xs.size());
+        auto averageLoss = loss / Val32::create(xs.size());
         std::cout << "Epoch: " << epoch + 1 << " Loss: " << averageLoss->getValue() << std::endl;
     }
 
diff --git a/include/core/Vector.hpp b/include/core/Vector.hpp
index 9b0b7bf..493a4bc 100644
--- a/include/core/Vector.hpp
+++ b/include/core/Vector.hpp
@@ -18,6 +18,7 @@ template <typename T> class Vector;
 
 using Vec32 = Vector<Type::float32>;
 using Vec64 = Vector<Type::float64>;
+template <typename T> using Batch = std::vector<Vector<T>>;
 
 template <typename T> class Vector {
   private:
diff --git a/include/nn/Loss.hpp b/include/nn/Loss.hpp
index 9d598ae..e8bd8b3 100644
--- a/include/nn/Loss.hpp
+++ b/include/nn/Loss.hpp
@@ -8,6 +8,7 @@
 #pragma once
 
 #include "../core/Value.hpp"
+#include "../core/Vector.hpp"
 
 namespace shkyera::Loss {
 
@@ -31,8 +32,6 @@ Function<T> MSE = [](Vector<T> a, Vector<T> b) {
     if (a.size() > 0)
         loss = loss / Value<T>::create(a.size());
 
-    loss->backward();
-
     return loss;
 };
 
@@ -52,8 +51,6 @@ Function<T> MAE = [](Vector<T> a, Vector<T> b) {
     if (a.size() > 0)
         loss = loss / Value<T>::create(a.size());
 
-    loss->backward();
-
     return loss;
 };
 
@@ -80,9 +77,19 @@ Function<T> CrossEntropy = [](Vector<T> a, Vector<T> b) {
         loss = loss - (b[i] * (a[i]->log()));
     }
 
+    return loss;
+};
+
+template <typename T> ValuePtr<T> compute(Function<T> lossFunction, const Batch<T> prediction, const Batch<T> target) {
+    ValuePtr<T> loss = Value<T>::create(0);
+    for (size_t i = 0; i < prediction.size(); ++i) {
+        loss = loss + lossFunction(prediction[i], target[i]);
+    }
+    loss = loss / Value<T>::create(prediction.size());
+
     loss->backward();
 
     return loss;
-};
+}
 
 } // namespace shkyera::Loss
diff --git a/include/nn/Module.hpp b/include/nn/Module.hpp
index 23373cd..8aa1020 100644
--- a/include/nn/Module.hpp
+++ b/include/nn/Module.hpp
@@ -19,8 +19,17 @@ template <typename T> class Module {
     Module() = default;
 
   public:
-    Vector<T> forward(const Vector<T> &x) const { return (*this)(x); }
+    template <typename U> U forward(const U &x) const { return (*this)(x); }
+
     virtual Vector<T> operator()(const Vector<T> &x) const { return x; }
+    std::vector<Vector<T>> operator()(const std::vector<Vector<T>> &x) const {
+        std::vector<Vector<T>> out(x.size());
+        for (size_t i = 0; i < x.size(); ++i) {
+            out[i] = this->operator()(x[i]);
+        }
+        return out;
+    }
+
     virtual std::vector<ValuePtr<T>> parameters() const { return {}; }
 };
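Taken together, the two halves of this patch let one `forward` call and one `Loss::compute` call process an entire batch: the templated `Module::forward` dispatches a `std::vector<Vector<T>>` to the new batched `operator()`, and `compute` now owns the averaging and the `backward()` pass that the individual loss lambdas previously triggered. A minimal sketch of the resulting training loop, mirroring the updated examples above (`network`, `optimizer`, `lossFunction`, `xs`, and `ys` are the objects those examples already set up):

```cpp
for (size_t epoch = 0; epoch < 100; epoch++) {
    optimizer.reset();                                 // zero the accumulated gradients
    auto pred = network->forward(xs);                  // batched: std::vector<Vec32> in, std::vector<Vec32> out
    auto loss = Loss::compute(lossFunction, pred, ys); // sums per-sample losses, averages, calls backward()
    optimizer.step();                                  // apply the gradients
}
```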
From a40b76d3fb54354eed1e35147cb891e6a5a4a89b Mon Sep 17 00:00:00 2001
From: "szewczyk.franciszek02" <szewczyk.franciszek02@gmail.com>
Date: Thu, 9 Nov 2023 20:09:31 +0100
Subject: [PATCH 2/5] Simple Dataset

---
 examples/xor_classification.cpp | 48 ---------------------------------
 examples/xor_regression.cpp     | 30 ++++++++++-----------
 include/ShkyeraGrad.hpp         |  2 ++
 include/nn/Loss.hpp             |  7 +++++
 4 files changed, 24 insertions(+), 63 deletions(-)
 delete mode 100644 examples/xor_classification.cpp

diff --git a/examples/xor_classification.cpp b/examples/xor_classification.cpp
deleted file mode 100644
index 53053ce..0000000
--- a/examples/xor_classification.cpp
+++ /dev/null
@@ -1,48 +0,0 @@
-#include "../include/ShkyeraGrad.hpp"
-
-int main() {
-    using namespace shkyera;
-
-    // clang-format off
-    std::vector<Vec32> xs;
-    std::vector<Vec32> ys;
-
-    // ---------- INPUT ----------- | -------- OUTPUT --------- //
-    xs.push_back(Vec32::of(0, 0));    ys.push_back(Vec32::of(1, 0));
-    xs.push_back(Vec32::of(1, 0));    ys.push_back(Vec32::of(0, 1));
-    xs.push_back(Vec32::of(0, 1));    ys.push_back(Vec32::of(0, 1));
-    xs.push_back(Vec32::of(1, 1));    ys.push_back(Vec32::of(1, 0));
-
-    auto mlp = SequentialBuilder<Type::float32>::begin()
-                   .add(Linear32::create(2, 15))
-                   .add(ReLU32::create())
-                   .add(Linear32::create(15, 5))
-                   .add(Sigmoid32::create())
-                   .add(Linear32::create(5, 5))
-                   .add(Tanh32::create())
-                   .add(Linear32::create(5, 2))
-                   .add(Softmax32::create())
-                   .build();
-    // clang-format on
-
-    Adam32 optimizer = Adam<Type::float32>(mlp->parameters(), 0.1);
-    Loss::Function32 lossFunction = Loss::CrossEntropy<Type::float32>;
-
-    // ------ TRAINING THE NETWORK ------- //
-    for (size_t epoch = 0; epoch < 100; epoch++) {
-        optimizer.reset();
-
-        std::vector<Vec32> pred = mlp->forward(xs);
-        auto loss = Loss::compute32(lossFunction, pred, ys);
-
-        optimizer.step();
-
-        std::cout << "Epoch: " << epoch + 1 << " Loss: " << loss->getValue() / xs.size() << std::endl;
-    }
-
-    // ------ VERIFYING THAT IT WORKS ------//
-    for (size_t sample = 0; sample < xs.size(); ++sample) {
-        Vec32 pred = mlp->forward(xs[sample]);
-        std::cout << xs[sample] << " -> " << pred << "\t| True: " << ys[sample] << std::endl;
-    }
-}
diff --git a/examples/xor_regression.cpp b/examples/xor_regression.cpp
index cdaf186..d3e77d1 100644
--- a/examples/xor_regression.cpp
+++ b/examples/xor_regression.cpp
@@ -5,14 +5,12 @@ int main() {
     using T = Type::float32;
 
     // clang-format off
-    std::vector<Vec32> xs;
-    std::vector<Vec32> ys;
+    Dataset<Vec32, Vec32> data;
 
-    // ---------- INPUT ----------- | -------- OUTPUT --------- //
-    xs.push_back(Vec32::of(0, 0));    ys.push_back(Vec32::of(0));
-    xs.push_back(Vec32::of(1, 0));    ys.push_back(Vec32::of(1));
-    xs.push_back(Vec32::of(0, 1));    ys.push_back(Vec32::of(1));
-    xs.push_back(Vec32::of(1, 1));    ys.push_back(Vec32::of(0));
+    data.addSample(Vec32::of(0, 0), Vec32::of(0));
+    data.addSample(Vec32::of(0, 1), Vec32::of(1));
+    data.addSample(Vec32::of(1, 0), Vec32::of(1));
+    data.addSample(Vec32::of(1, 1), Vec32::of(0));
 
     auto network = SequentialBuilder<T>::begin()
         .add(Linear32::create(2, 15))
@@ -28,19 +26,21 @@ int main() {
     auto lossFunction = Loss::MSE<T>;
 
     for (size_t epoch = 0; epoch < 100; epoch++) { // We train for 100 epochs
-        optimizer.reset(); // Reset the gradients
+        auto epochLoss = Val32::create(0);
 
-        auto pred = network->forward(xs); // We get some prediction
-        auto loss = Loss::compute(lossFunction, pred, ys); // And calculate its error
+        optimizer.reset(); // Reset the gradients
+        for (auto &[x, y] : data) {
+            auto pred = network->forward(x); // We get some prediction
+            epochLoss = epochLoss + Loss::compute(lossFunction, pred, y); // And calculate its error
+        }
         optimizer.step(); // Update the parameters
 
-        auto averageLoss = loss / Val32::create(xs.size());
+        auto averageLoss = epochLoss / Val32::create(data.size());
         std::cout << "Epoch: " << epoch + 1 << " Loss: " << averageLoss->getValue() << std::endl;
     }
 
-    for (size_t sample = 0; sample < xs.size(); ++sample) { // Go through each example
-        Vec32 pred = network->forward(xs[sample]); // Predict result
-        std::cout << xs[sample] << " -> " << pred[0] << "\t| True: " << ys[sample][0] << std::endl;
+    for (auto &[x, y] : data) { // Go through each example
+        auto pred = network->forward(x); // We get some prediction
+        std::cout << x << " -> " << pred[0] << "\t| True: " << y[0] << std::endl;
     }
 }
diff --git a/include/ShkyeraGrad.hpp b/include/ShkyeraGrad.hpp
index cfa0138..8df151e 100644
--- a/include/ShkyeraGrad.hpp
+++ b/include/ShkyeraGrad.hpp
@@ -17,6 +17,8 @@
 #include "nn/Neuron.hpp"
 #include "nn/Sequential.hpp"
 
+#include "nn/data/Dataset.hpp"
+
 #include "nn/optimizers/AdaMax.hpp"
 #include "nn/optimizers/Adam.hpp"
 #include "nn/optimizers/NAG.hpp"
diff --git a/include/nn/Loss.hpp b/include/nn/Loss.hpp
index e8bd8b3..735773a 100644
--- a/include/nn/Loss.hpp
+++ b/include/nn/Loss.hpp
@@ -80,6 +80,13 @@ Function<T> CrossEntropy = [](Vector<T> a, Vector<T> b) {
     return loss;
 };
 
+template <typename T>
+ValuePtr<T> compute(Function<T> lossFunction, const Vector<T> prediction, const Vector<T> target) {
+    auto loss = lossFunction(prediction, target);
+    loss->backward();
+    return loss;
+}
+
 template <typename T> ValuePtr<T> compute(Function<T> lossFunction, const Batch<T> prediction, const Batch<T> target) {
     ValuePtr<T> loss = Value<T>::create(0);
     for (size_t i = 0; i < prediction.size(); ++i) {
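The `Dataset` referenced here arrives with its implementation in the final patch of the series; what matters for this commit is its interface: inputs and targets are stored side by side, and iteration yields `std::pair`s, which is what enables the structured bindings in the loop above. A minimal usage sketch, assuming the `Vec32` alias and the single-sample `Loss::compute` overload added in this patch:

```cpp
Dataset<Vec32, Vec32> data;                           // maps Vec32 inputs to Vec32 targets
data.addSample(Vec32::of(0, 0), Vec32::of(0));        // one (input, target) pair per call
data.addSample(Vec32::of(0, 1), Vec32::of(1));

for (const auto &[x, y] : data) {                     // iteration yields std::pair<Vec32, Vec32>
    auto pred = network->forward(x);                  // per-sample forward pass
    auto loss = Loss::compute(lossFunction, pred, y); // computes the loss and runs backward()
}
```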
From d4debfbbeaf9f0659751960782c58a0a2ad56c91 Mon Sep 17 00:00:00 2001
From: "szewczyk.franciszek02" <szewczyk.franciszek02@gmail.com>
Date: Thu, 9 Nov 2023 21:10:29 +0100
Subject: [PATCH 3/5] Updated example and README

---
 README.md                   | 40 ++++++++++++++++++-------------------
 examples/xor_regression.cpp | 15 +++++++++-----
 include/ShkyeraGrad.hpp     |  1 +
 include/core/Utils.hpp      |  2 ++
 4 files changed, 33 insertions(+), 25 deletions(-)

diff --git a/README.md b/README.md
index 6a8583a..b24e7ad 100644
--- a/README.md
+++ b/README.md
@@ -34,14 +34,17 @@ int main() {
     using namespace shkyera;
     using T = Type::float32;
 
-    std::vector<Vec32> xs;
-    std::vector<Vec32> ys;
-
-    // ---------- INPUT ----------- | -------- OUTPUT --------- //
-    xs.push_back(Vec32::of(0, 0));    ys.push_back(Vec32::of(0));
-    xs.push_back(Vec32::of(1, 0));    ys.push_back(Vec32::of(1));
-    xs.push_back(Vec32::of(0, 1));    ys.push_back(Vec32::of(1));
-    xs.push_back(Vec32::of(1, 1));    ys.push_back(Vec32::of(0));
+    // This is our XOR dataset. It maps from Vec32 to Vec32
+    Dataset<Vec32, Vec32> data;
+    data.addSample(Vec32::of(0, 0), Vec32::of(0));
+    data.addSample(Vec32::of(0, 1), Vec32::of(1));
+    data.addSample(Vec32::of(1, 0), Vec32::of(1));
+    data.addSample(Vec32::of(1, 1), Vec32::of(0));
+
+    // This is the data loader; it will take care of batching
+    size_t batchSize = 2;
+    bool shuffle = true;
+    DataLoader loader(data, batchSize, shuffle);
 
     auto network = SequentialBuilder<T>::begin()
         .add(Linear32::create(2, 15))
@@ -52,29 +55,26 @@ int main() {
         .add(Sigmoid32::create())
         .build();
 
-
-    auto optimizer = Adam32(network->parameters(), 0.05);
+    auto optimizer = Adam32(network->parameters(), 0.1);
     auto lossFunction = Loss::MSE<T>;
 
     for (size_t epoch = 0; epoch < 100; epoch++) { // We train for 100 epochs
         auto epochLoss = Val32::create(0);
 
-        optimizer.reset(); // Reset the gradients
-        for (size_t sample = 0; sample < xs.size(); ++sample) { // We go through each sample
-            Vec32 pred = network->forward(xs[sample]); // We get some prediction
-            auto loss = lossFunction(pred, ys[sample]); // And calculate its error
-
-            epochLoss = epochLoss + loss; // Store the loss for feedback
+        optimizer.reset(); // Reset the gradients
+        for (const auto &[x, y] : loader) { // For each batch
+            auto pred = network->forward(x); // We get some prediction
+            epochLoss = epochLoss + Loss::compute(lossFunction, pred, y); // And calculate its error
         }
         optimizer.step(); // Update the parameters
 
-        auto averageLoss = epochLoss / Val32::create(xs.size());
+        auto averageLoss = epochLoss / Val32::create(loader.getTotalBatches());
         std::cout << "Epoch: " << epoch + 1 << " Loss: " << averageLoss->getValue() << std::endl;
     }
 
-    for (size_t sample = 0; sample < xs.size(); ++sample) { // Go through each example
-        Vec32 pred = network->forward(xs[sample]); // Predict result
-        std::cout << xs[sample] << " -> " << pred[0] << "\t| True: " << ys[sample][0] << std::endl;
+    for (auto &[x, y] : data) { // Go through each example
+        auto pred = network->forward(x); // We get some prediction
+        std::cout << x << " -> " << pred[0] << "\t| True: " << y[0] << std::endl;
     }
 }
 ```
diff --git a/examples/xor_regression.cpp b/examples/xor_regression.cpp
index d3e77d1..7b1fe2b 100644
--- a/examples/xor_regression.cpp
+++ b/examples/xor_regression.cpp
@@ -5,13 +5,18 @@ int main() {
     using T = Type::float32;
 
     // clang-format off
+    // This is our XOR dataset. It maps from Vec32 to Vec32
     Dataset<Vec32, Vec32> data;
-
     data.addSample(Vec32::of(0, 0), Vec32::of(0));
     data.addSample(Vec32::of(0, 1), Vec32::of(1));
     data.addSample(Vec32::of(1, 0), Vec32::of(1));
     data.addSample(Vec32::of(1, 1), Vec32::of(0));
 
+    // This is the data loader; it will take care of batching
+    size_t batchSize = 2;
+    bool shuffle = true;
+    DataLoader loader(data, batchSize, shuffle);
+
     auto network = SequentialBuilder<T>::begin()
         .add(Linear32::create(2, 15))
         .add(ReLU32::create())
@@ -22,20 +27,20 @@ int main() {
         .build();
     // clang-format on
 
-    auto optimizer = Adam32(network->parameters(), 0.05);
+    auto optimizer = Adam32(network->parameters(), 0.1);
    auto lossFunction = Loss::MSE<T>;
 
     for (size_t epoch = 0; epoch < 100; epoch++) { // We train for 100 epochs
         auto epochLoss = Val32::create(0);
 
-        optimizer.reset(); // Reset the gradients
-        for (auto &[x, y] : data) {
+        optimizer.reset(); // Reset the gradients
+        for (const auto &[x, y] : loader) { // For each batch
             auto pred = network->forward(x); // We get some prediction
             epochLoss = epochLoss + Loss::compute(lossFunction, pred, y); // And calculate its error
         }
         optimizer.step(); // Update the parameters
 
-        auto averageLoss = epochLoss / Val32::create(data.size());
+        auto averageLoss = epochLoss / Val32::create(loader.getTotalBatches());
         std::cout << "Epoch: " << epoch + 1 << " Loss: " << averageLoss->getValue() << std::endl;
     }
 
diff --git a/include/ShkyeraGrad.hpp b/include/ShkyeraGrad.hpp
index 8df151e..2752afe 100644
--- a/include/ShkyeraGrad.hpp
+++ b/include/ShkyeraGrad.hpp
@@ -17,6 +17,7 @@
 #include "nn/Neuron.hpp"
 #include "nn/Sequential.hpp"
 
+#include "nn/data/DataLoader.hpp"
 #include "nn/data/Dataset.hpp"
 
 #include "nn/optimizers/AdaMax.hpp"
diff --git a/include/core/Utils.hpp b/include/core/Utils.hpp
index aeaa34d..d0ceaed 100644
--- a/include/core/Utils.hpp
+++ b/include/core/Utils.hpp
@@ -53,6 +53,8 @@ std::enable_if_t<std::is_integral_v<T>, std::vector<T>> sample(T from, T to, size_t size) {
 
     return sampled;
 }
 
+template <typename T> void shuffle(std::vector<T> &vec) { std::shuffle(vec.begin(), vec.end(), rand_dev); }
+
 template <typename Clock> auto startTimer() { return Clock::now(); }
 
 template <typename Clock>
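The averaging convention here is worth spelling out: `Loss::compute` already averages within each batch, so the epoch loss above is a sum of per-batch means and is therefore divided by `loader.getTotalBatches()` rather than by the number of samples. With 4 samples and `batchSize = 2` that denominator is 2; and since the batch count rounds up, 5 samples with `batchSize = 2` would give 3 batches (of sizes 2, 2, and 1), as implemented in the final patch below.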
From 0e641fff8f8f51e5f92e070a411c367950d89c38 Mon Sep 17 00:00:00 2001
From: "szewczyk.franciszek02" <szewczyk.franciszek02@gmail.com>
Date: Thu, 9 Nov 2023 21:13:55 +0100
Subject: [PATCH 4/5] Fixed workflows

---
 .github/workflows/linux.yml   | 2 --
 .github/workflows/macos.yml   | 2 --
 .github/workflows/windows.yml | 1 -
 3 files changed, 5 deletions(-)

diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml
index 10b4976..a80dc65 100644
--- a/.github/workflows/linux.yml
+++ b/.github/workflows/linux.yml
@@ -32,7 +32,5 @@ jobs:
         run: |
           g++ examples/scalars.cpp -O3 --std=c++17
           ./a.out
-          g++ examples/xor_classification.cpp -O3 --std=c++17
-          ./a.out
           g++ examples/xor_regression.cpp -O3 --std=c++17
           ./a.out
diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml
index 7d8c246..29110d0 100644
--- a/.github/workflows/macos.yml
+++ b/.github/workflows/macos.yml
@@ -32,7 +32,5 @@ jobs:
         run: |
          g++ examples/scalars.cpp -O3 --std=c++17
           ./a.out
-          g++ examples/xor_classification.cpp -O3 --std=c++17
-          ./a.out
           g++ examples/xor_regression.cpp -O3 --std=c++17
           ./a.out
diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml
index 260f5ac..f09efd6 100644
--- a/.github/workflows/windows.yml
+++ b/.github/workflows/windows.yml
@@ -30,5 +30,4 @@ jobs:
           CXX: ${{matrix.conf.compiler}}
         run: |
           g++ -o out examples/scalars.cpp -O3 --std=c++17
-          g++ -o out examples/xor_classification.cpp -O3 --std=c++17
           g++ -o out examples/xor_regression.cpp -O3 --std=c++17
From d8ed9ec31fd837b1b788326270852e4fad0a203c Mon Sep 17 00:00:00 2001
From: "szewczyk.franciszek02" <szewczyk.franciszek02@gmail.com>
Date: Thu, 9 Nov 2023 21:14:56 +0100
Subject: [PATCH 5/5] DataSet and DataLoader

---
 .gitignore                     |   1 -
 include/nn/data/DataLoader.hpp | 101 +++++++++++++++++++++++++++++++++
 include/nn/data/Dataset.hpp    |  97 +++++++++++++++++++++++++++++++
 3 files changed, 198 insertions(+), 1 deletion(-)
 create mode 100644 include/nn/data/DataLoader.hpp
 create mode 100644 include/nn/data/Dataset.hpp

diff --git a/.gitignore b/.gitignore
index 4f66872..517b87e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,4 @@
 build/
-data/
 
 docs/html
 docs/latex
diff --git a/include/nn/data/DataLoader.hpp b/include/nn/data/DataLoader.hpp
new file mode 100644
index 0000000..18a9ae3
--- /dev/null
+++ b/include/nn/data/DataLoader.hpp
@@ -0,0 +1,101 @@
+/**
+ * Copyright © 2023 Franciszek Szewczyk. None of the rights reserved.
+ * This code is released under the Beerware License. If you find this code useful or you appreciate the work, you are
+ * encouraged to buy the author a beer in return.
+ * Contact the author at szewczyk.franciszek02@gmail.com for inquiries and support.
+ */
+
+#pragma once
+
+#include <algorithm>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+#include "../../core/Utils.hpp"
+#include "Dataset.hpp"
+
+namespace shkyera {
+
+template <typename T, typename U> class DataLoader {
+  private:
+    const Dataset<T, U> &_dataset;
+    size_t _batchSize;
+    bool _shuffle;
+
+  public:
+    DataLoader(const Dataset<T, U> &dataset, size_t batchSize = 4, bool shuffle = false);
+
+    size_t getTotalBatches() const;
+
+    class ConstIterator {
+      private:
+        std::vector<size_t> _order;
+        size_t _index;
+
+        const DataLoader<T, U> &_dataLoader;
+
+      public:
+        ConstIterator(size_t index, const DataLoader<T, U> &dataLoader);
+
+        std::pair<std::vector<T>, std::vector<U>> operator*();
+        ConstIterator &operator++();
+        bool operator!=(const ConstIterator &other);
+    };
+
+    ConstIterator begin() const;
+    ConstIterator end() const;
+};
+
+template <typename T, typename U>
+DataLoader<T, U>::DataLoader(const Dataset<T, U> &dataset, size_t batchSize, bool shuffle)
+    : _dataset(dataset), _batchSize(batchSize), _shuffle(shuffle) {}
+
+template <typename T, typename U> size_t DataLoader<T, U>::getTotalBatches() const {
+    size_t batches = _dataset.size() / _batchSize;
+    if (_dataset.size() % _batchSize != 0)
+        batches++;
+    return batches;
+}
+
+template <typename T, typename U>
+DataLoader<T, U>::ConstIterator::ConstIterator(size_t index, const DataLoader<T, U> &dataLoader)
+    : _index(index), _dataLoader(dataLoader) {
+    _order.resize(_dataLoader._dataset.size(), 0);
+
+    std::iota(_order.begin(), _order.end(), 0);
+    if (_dataLoader._shuffle)
+        utils::shuffle(_order);
+}
+
+template <typename T, typename U>
+std::pair<std::vector<T>, std::vector<U>> DataLoader<T, U>::ConstIterator::operator*() {
+    size_t beginIndex = _index;
+    size_t endIndex = std::min(_index + _dataLoader._batchSize, _dataLoader._dataset.size());
+
+    std::vector<T> inputs(endIndex - beginIndex);
+    std::vector<U> outputs(endIndex - beginIndex);
+
+    for (size_t i = beginIndex; i < endIndex; ++i) {
+        auto [in, out] = _dataLoader._dataset[_order[i]];
+        inputs[i - beginIndex] = in;
+        outputs[i - beginIndex] = out;
+    }
+
+    return {inputs, outputs};
+}
+
+template <typename T, typename U>
+typename DataLoader<T, U>::ConstIterator &DataLoader<T, U>::ConstIterator::operator++() {
+    _index += _dataLoader._batchSize;
+    _index = std::min(_index, _dataLoader._dataset.size());
+    return *this;
+}
+
+template <typename T, typename U> bool DataLoader<T, U>::ConstIterator::operator!=(const ConstIterator &other) {
+    return _index != other._index;
+}
+
+template <typename T, typename U> typename DataLoader<T, U>::ConstIterator DataLoader<T, U>::begin() const {
+    return ConstIterator(0, *this);
+}
+
+template <typename T, typename U> typename DataLoader<T, U>::ConstIterator DataLoader<T, U>::end() const {
+    return ConstIterator(_dataset.size(), *this);
+}
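Two design details of the iterator above are easy to miss: `operator!=` compares only `_index`, so the `end()` sentinel works even though every `ConstIterator` builds its own `_order`; and because `begin()` constructs a fresh iterator, a shuffled loader draws a new permutation on every epoch. A short sketch of the resulting behaviour (the `loader` is the one from the README example; the orders shown are hypothetical, since they depend on the random state):

```cpp
// Each begin() re-runs std::iota + utils::shuffle on _order, so with
// shuffle enabled the two epochs below generally see different batches, e.g.:
//   epoch 1: {x2, x0}, {x3, x1}
//   epoch 2: {x1, x3}, {x0, x2}
for (size_t epoch = 0; epoch < 2; epoch++) {
    for (const auto &[inputs, outputs] : loader) {
        // inputs: std::vector<Vec32>, outputs: std::vector<Vec32>,
        // each holding at most batchSize elements
    }
}
```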
+
+} // namespace shkyera
diff --git a/include/nn/data/Dataset.hpp b/include/nn/data/Dataset.hpp
new file mode 100644
index 0000000..1739540
--- /dev/null
+++ b/include/nn/data/Dataset.hpp
@@ -0,0 +1,97 @@
+/**
+ * Copyright © 2023 Franciszek Szewczyk. None of the rights reserved.
+ * This code is released under the Beerware License. If you find this code useful or you appreciate the work, you are
+ * encouraged to buy the author a beer in return.
+ * Contact the author at szewczyk.franciszek02@gmail.com for inquiries and support.
+ */
+
+#pragma once
+
+#include <stdexcept>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace shkyera {
+
+template <typename T, typename U> class Dataset {
+  private:
+    std::vector<T> _inputs;
+    std::vector<U> _outputs;
+
+  public:
+    Dataset() = default;
+    Dataset(const std::vector<T> &inputs, const std::vector<U> &outputs);
+
+    void addSample(T input, U output);
+
+    size_t size() const;
+
+    std::pair<T, U> operator[](size_t index) const;
+
+    class ConstIterator {
+      private:
+        size_t _index;
+        const Dataset<T, U> &_dataset;
+
+      public:
+        ConstIterator(size_t index, const Dataset<T, U> &dataset);
+
+        const std::pair<T, U> operator*();
+        ConstIterator &operator++();
+        bool operator!=(const ConstIterator &other);
+    };
+
+    ConstIterator begin() const;
+    ConstIterator end() const;
+};
+
+template <typename T, typename U>
+Dataset<T, U>::Dataset(const std::vector<T> &inputs, const std::vector<U> &outputs) {
+    if (inputs.size() != outputs.size())
+        throw std::invalid_argument(
+            "To create a dataset, you have to pass the same amount of inputs and outputs. You passed " +
+            std::to_string(inputs.size()) + " inputs and " + std::to_string(outputs.size()) + " outputs.");
+
+    _inputs = inputs;
+    _outputs = outputs;
+}
+
+template <typename T, typename U> void Dataset<T, U>::addSample(T input, U output) {
+    _inputs.push_back(input);
+    _outputs.push_back(output);
+}
+
+template <typename T, typename U> size_t Dataset<T, U>::size() const { return _inputs.size(); }
+
+template <typename T, typename U> std::pair<T, U> Dataset<T, U>::operator[](size_t index) const {
+    if (index >= _inputs.size())
+        throw std::invalid_argument("While trying to access Dataset, the provided index " + std::to_string(index) +
+                                    " was too large for a dataset of size " + std::to_string(size()));
+    return {_inputs[index], _outputs[index]};
+}
+
+template <typename T, typename U>
+Dataset<T, U>::ConstIterator::ConstIterator(size_t index, const Dataset<T, U> &dataset)
+    : _index(index), _dataset(dataset) {}
+
+template <typename T, typename U> const std::pair<T, U> Dataset<T, U>::ConstIterator::operator*() {
+    return _dataset[_index];
+}
+
+template <typename T, typename U> typename Dataset<T, U>::ConstIterator &Dataset<T, U>::ConstIterator::operator++() {
+    ++_index;
+    return *this;
+}
+
+template <typename T, typename U> bool Dataset<T, U>::ConstIterator::operator!=(const ConstIterator &other) {
+    return _index != other._index;
+}
+
+template <typename T, typename U> typename Dataset<T, U>::ConstIterator Dataset<T, U>::begin() const {
+    return ConstIterator(0, *this);
+}
+
+template <typename T, typename U> typename Dataset<T, U>::ConstIterator Dataset<T, U>::end() const {
+    return ConstIterator(_inputs.size(), *this);
+}
+
+} // namespace shkyera
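For reference, a compact end-to-end sketch of the two new classes working together. This is illustrative only: the include path is an assumption to be adjusted to your setup, and the `DataLoader`'s template arguments are deduced from the `Dataset` argument via C++17 class template argument deduction:

```cpp
#include "include/ShkyeraGrad.hpp" // assumed path, relative to the repository root

int main() {
    using namespace shkyera;

    Dataset<Vec32, Vec32> data;
    data.addSample(Vec32::of(0, 0), Vec32::of(0));
    data.addSample(Vec32::of(0, 1), Vec32::of(1));
    data.addSample(Vec32::of(1, 0), Vec32::of(1));
    data.addSample(Vec32::of(1, 1), Vec32::of(0));

    DataLoader loader(data, /*batchSize=*/3, /*shuffle=*/true);

    // 4 samples with batchSize 3 -> getTotalBatches() == 2 (one batch of 3, one of 1).
    for (const auto &[inputs, outputs] : loader) {
        // inputs.size() == outputs.size(); train on one batch here
    }

    // Bounds are checked: data[4] on this 4-sample dataset throws std::invalid_argument.
}
```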