Skip to content

Commit

Permalink
DataSet and DataLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
fszewczyk committed Nov 9, 2023
1 parent 0e641ff commit d8ed9ec
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 1 deletion.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
build/
data/

docs/html
docs/latex
Expand Down
101 changes: 101 additions & 0 deletions include/nn/data/DataLoader.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/**
* Copyright © 2023 Franciszek Szewczyk. None of the rights reserved.
* This code is released under the Beerware License. If you find this code useful or you appreciate the work, you are
* encouraged to buy the author a beer in return.
* Contact the author at [email protected] for inquiries and support.
*/

#pragma once

#include "Dataset.hpp"

namespace shkyera {

template <typename T, typename U> class DataLoader {
private:
const Dataset<T, U> &_dataset;
size_t _batchSize;
bool _shuffle;

public:
DataLoader(const Dataset<T, U> &dataset, size_t batchSize = 4, bool shuffle = false);

size_t getTotalBatches() const;

class ConstIterator {
private:
std::vector<size_t> _order;
size_t _index;

const DataLoader<T, U> &_dataLoader;

public:
ConstIterator(size_t index, const DataLoader<T, U> &dataLoader);

std::pair<std::vector<T>, std::vector<U>> operator*();
ConstIterator &operator++();
bool operator!=(const ConstIterator &other);
};

ConstIterator begin() const;
ConstIterator end() const;
};

template <typename T, typename U>
DataLoader<T, U>::DataLoader(const Dataset<T, U> &dataset, size_t batchSize, bool shuffle)
: _dataset(dataset), _batchSize(batchSize), _shuffle(shuffle) {}

template <typename T, typename U> size_t DataLoader<T, U>::getTotalBatches() const {
size_t batches = _dataset.size() / _batchSize;
if (_dataset.size() % _batchSize != 0)
batches++;
return batches;
}

template <typename T, typename U>
DataLoader<T, U>::ConstIterator::ConstIterator(size_t index, const DataLoader<T, U> &dataLoader)
: _index(index), _dataLoader(dataLoader) {
_order.resize(_dataLoader._dataset.size(), 0);

std::iota(_order.begin(), _order.end(), 0);
if (_dataLoader._shuffle)
utils::shuffle(_order);
}

template <typename T, typename U>
std::pair<std::vector<T>, std::vector<U>> DataLoader<T, U>::ConstIterator::operator*() {
size_t beginIndex = _index;
size_t endIndex = std::min(_index + _dataLoader._batchSize, _dataLoader._dataset.size());

std::vector<T> inputs(endIndex - beginIndex);
std::vector<U> outputs(endIndex - beginIndex);

for (size_t i = beginIndex; i < endIndex; ++i) {
auto [in, out] = _dataLoader._dataset[_order[i]];
inputs[i - beginIndex] = in;
outputs[i - beginIndex] = out;
}

return {inputs, outputs};
}

template <typename T, typename U>
typename DataLoader<T, U>::ConstIterator &DataLoader<T, U>::ConstIterator::operator++() {
_index += _dataLoader._batchSize;
_index = std::min(_index, _dataLoader._dataset.size());
return *this;
}

template <typename T, typename U> bool DataLoader<T, U>::ConstIterator::operator!=(const ConstIterator &other) {
return _index != other._index;
}

template <typename T, typename U> typename DataLoader<T, U>::ConstIterator DataLoader<T, U>::begin() const {
return ConstIterator(0, *this);
}

template <typename T, typename U> typename DataLoader<T, U>::ConstIterator DataLoader<T, U>::end() const {
return ConstIterator(_dataset.size(), *this);
}

} // namespace shkyera
97 changes: 97 additions & 0 deletions include/nn/data/Dataset.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/**
* Copyright © 2023 Franciszek Szewczyk. None of the rights reserved.
* This code is released under the Beerware License. If you find this code useful or you appreciate the work, you are
* encouraged to buy the author a beer in return.
* Contact the author at [email protected] for inquiries and support.
*/

#pragma once

#include <exception>
#include <vector>

namespace shkyera {

template <typename T, typename U> class Dataset {
private:
std::vector<T> _inputs;
std::vector<U> _outputs;

public:
Dataset() = default;
Dataset(const std::vector<T> &inputs, const std::vector<T> &outputs);

void addSample(T input, U output);

size_t size() const;

std::pair<T, U> operator[](size_t index) const;

class ConstIterator {
private:
size_t _index;
const Dataset<T, U> &_dataset;

public:
ConstIterator(size_t index, const Dataset<T, U> &dataset);

const std::pair<T, U> operator*();
ConstIterator &operator++();
bool operator!=(const ConstIterator &other);
};

ConstIterator begin() const;
ConstIterator end() const;
};

template <typename T, typename U> Dataset<T, U>::Dataset(const std::vector<T> &inputs, const std::vector<T> &outputs) {
if (inputs.size() != outputs.size())
throw std::invalid_argument(
"To create a dataset, you have to pass the same amount of inputs and outputs. You passed " +
std::to_string(inputs.size()) + " inputs and " + std::to_string(outputs.size()) + " outputs.");

_inputs = inputs;
_outputs = outputs;
}

template <typename T, typename U> void Dataset<T, U>::addSample(T input, U output) {
_inputs.push_back(input);
_outputs.push_back(output);
}

template <typename T, typename U> size_t Dataset<T, U>::size() const { return _inputs.size(); }

template <typename T, typename U> std::pair<T, U> Dataset<T, U>::operator[](size_t index) const {
if (index > _inputs.size())
throw std::invalid_argument("While trying to access Dataset, the provided index " +
std::to_string(_inputs.size()) + " was too large for a dataset of size " +
std::to_string(size()));
return {_inputs[index], _outputs[index]};
}

template <typename T, typename U>
Dataset<T, U>::ConstIterator::ConstIterator(size_t index, const Dataset<T, U> &dataset)
: _index(index), _dataset(dataset) {}

template <typename T, typename U> const std::pair<T, U> Dataset<T, U>::ConstIterator::operator*() {
return _dataset[_index];
}

template <typename T, typename U> typename Dataset<T, U>::ConstIterator &Dataset<T, U>::ConstIterator::operator++() {
++_index;
return *this;
}

template <typename T, typename U> bool Dataset<T, U>::ConstIterator::operator!=(const ConstIterator &other) {
return _index != other._index;
}

template <typename T, typename U> typename Dataset<T, U>::ConstIterator Dataset<T, U>::begin() const {
return ConstIterator(0, *this);
}

template <typename T, typename U> typename Dataset<T, U>::ConstIterator Dataset<T, U>::end() const {
return ConstIterator(_inputs.size(), *this);
}

} // namespace shkyera

0 comments on commit d8ed9ec

Please sign in to comment.