diff --git a/examples/xor_classification.cpp b/examples/xor_classification.cpp
index d4961c5..30ff5fa 100644
--- a/examples/xor_classification.cpp
+++ b/examples/xor_classification.cpp
@@ -27,7 +27,7 @@ int main() {
     Loss::Function32 lossFunction = Loss::CrossEntropy;
 
     // ------ TRAINING THE NETWORK ------- //
-    for (size_t epoch = 0; epoch < 200; epoch++) {
+    for (size_t epoch = 0; epoch < 100; epoch++) {
         auto epochLoss = Val32::create(0);
 
         optimizer.reset();
diff --git a/include/core/Vector.hpp b/include/core/Vector.hpp
index b13783a..e730cb9 100644
--- a/include/core/Vector.hpp
+++ b/include/core/Vector.hpp
@@ -15,6 +15,7 @@ namespace shkyera {
 
 template <typename T> class Vector;
+
 using Vec32 = Vector<Type::float32>;
 using Vec64 = Vector<Type::float64>;
 
 
@@ -43,6 +44,21 @@
     Vector<T> &operator*=(ValuePtr<T> val);
 
     ValuePtr<T> operator[](size_t index) const;
+
+    class ConstIterator {
+      private:
+        size_t _index;
+        const Vector<T> &_vector;
+
+      public:
+        ConstIterator(size_t index, const Vector<T> &vector);
+        const ValuePtr<T> operator*();
+        ConstIterator &operator++();
+        bool operator!=(const ConstIterator &other);
+    };
+
+    ConstIterator begin() const;
+    ConstIterator end() const;
 };
 
 template <typename T> Vector<T>::Vector(std::vector<ValuePtr<T>> values) { _values = values; }
@@ -134,11 +150,34 @@ template <typename T> ValuePtr<T> Vector<T>::operator[](size_t index) const { re
 
 template <typename T> std::ostream &operator<<(std::ostream &os, const Vector<T> &vector) {
     os << "Vector(size=" << vector.size() << ", data={";
-    for (const ValuePtr<T> val : vector._values)
+    for (auto val : vector)
         os << val << ' ';
     os << "})";
     return os;
 }
 
+template <typename T> typename Vector<T>::ConstIterator Vector<T>::begin() const { return ConstIterator(0, *this); }
+template <typename T> typename Vector<T>::ConstIterator Vector<T>::end() const { return ConstIterator(size(), *this); }
+
+template <typename T>
+Vector<T>::ConstIterator::ConstIterator(size_t index, const Vector<T> &vector) : _index(index), _vector(vector) {}
+
+template <typename T> const ValuePtr<T> Vector<T>::ConstIterator::operator*() {
+    if (_index < _vector.size()) {
+        return _vector[_index];
+    }
+    throw std::out_of_range("Vector iterator out of range. Tried to access index " + std::to_string(_index) +
+                            " in a Vector of size " + std::to_string(_vector.size()) + ".");
+}
+
+template <typename T> typename Vector<T>::ConstIterator &Vector<T>::ConstIterator::operator++() {
+    ++_index;
+    return *this;
+}
+
+template <typename T> bool Vector<T>::ConstIterator::operator!=(const ConstIterator &other) {
+    return _index != other._index;
+}
+
 } // namespace shkyera
 
diff --git a/include/nn/Neuron.hpp b/include/nn/Neuron.hpp
index a4cefa5..6d148de 100644
--- a/include/nn/Neuron.hpp
+++ b/include/nn/Neuron.hpp
@@ -46,8 +46,8 @@ template <typename T> std::vector<ValuePtr<T>> Neuron<T>::parameters() const {
     std::vector<ValuePtr<T>> params;
     params.reserve(_weights.size() + 1);
 
-    for (size_t i = 0; i < _weights.size(); ++i)
-        params.push_back(_weights[i]);
+    for (auto &w : _weights)
+        params.push_back(w);
 
     params.push_back(_bias);
 
diff --git a/include/nn/activation/Exp.hpp b/include/nn/activation/Exp.hpp
index 8ec7a43..b97bd1d 100644
--- a/include/nn/activation/Exp.hpp
+++ b/include/nn/activation/Exp.hpp
@@ -28,9 +28,8 @@ template <typename T> Vector<T> Exp<T>::operator()(const Vector<T> &x) const {
     std::vector<ValuePtr<T>> out;
     out.reserve(x.size());
 
-    for (size_t i = 0; i < x.size(); ++i) {
-        out.emplace_back(x[i]->exp());
-    }
+    for (auto &entry : x)
+        out.emplace_back(entry->exp());
 
     return Vector<T>(out);
 }
diff --git a/include/nn/activation/ReLU.hpp b/include/nn/activation/ReLU.hpp
index b1757c7..7c0f32b 100644
--- a/include/nn/activation/ReLU.hpp
+++ b/include/nn/activation/ReLU.hpp
@@ -28,9 +28,8 @@ template <typename T> Vector<T> ReLU<T>::operator()(const Vector<T> &x) const {
     std::vector<ValuePtr<T>> out;
     out.reserve(x.size());
 
-    for (size_t i = 0; i < x.size(); ++i) {
-        out.emplace_back(x[i]->relu());
-    }
+    for (auto &entry : x)
+        out.emplace_back(entry->relu());
 
     return Vector<T>(out);
 }
diff --git a/include/nn/activation/Sigmoid.hpp b/include/nn/activation/Sigmoid.hpp
index 8648cfe..cbc4ed7 100644
--- a/include/nn/activation/Sigmoid.hpp
+++ b/include/nn/activation/Sigmoid.hpp
@@ -30,8 +30,8 @@ template <typename T> Vector<T> Sigmoid<T>::operator()(const Vector<T> &x) const
     std::vector<ValuePtr<T>> out;
     out.reserve(x.size());
 
-    for (size_t i = 0; i < x.size(); ++i) {
-        out.emplace_back(x[i]->sigmoid());
+    for (auto &entry : x) {
+        out.emplace_back(entry->sigmoid());
     }
 
     return Vector<T>(out);
diff --git a/include/nn/activation/Softmax.hpp b/include/nn/activation/Softmax.hpp
index 6f80c09..004a99c 100644
--- a/include/nn/activation/Softmax.hpp
+++ b/include/nn/activation/Softmax.hpp
@@ -31,13 +31,13 @@ template <typename T> Vector<T> Softmax<T>::operator()(const Vector<T> &x) const
     out.reserve(x.size());
 
     auto maxValue = Value<T>::create(x[0]->getValue());
-    for (size_t i = 1; i < x.size(); ++i)
-        if (x[i] > maxValue)
-            maxValue = x[i];
+    for (auto &entry : x)
+        if (entry > maxValue)
+            maxValue = entry;
 
     auto sumExponentiated = Value<T>::create(0);
-    for (size_t i = 0; i < x.size(); ++i) {
-        auto exponentiated = (x[i] - maxValue)->exp();
+    for (auto &entry : x) {
+        auto exponentiated = (entry - maxValue)->exp();
         out.emplace_back(exponentiated);
         sumExponentiated = sumExponentiated + exponentiated;
     }
diff --git a/include/nn/activation/Tanh.hpp b/include/nn/activation/Tanh.hpp
index 1bed7ed..5066526 100644
--- a/include/nn/activation/Tanh.hpp
+++ b/include/nn/activation/Tanh.hpp
@@ -28,9 +28,8 @@ template <typename T> Vector<T> Tanh<T>::operator()(const Vector<T> &x) const {
     std::vector<ValuePtr<T>> out;
     out.reserve(x.size());
 
-    for (size_t i = 0; i < x.size(); ++i) {
-        out.emplace_back(x[i]->tanh());
-    }
+    for (auto &entry : x)
+        out.emplace_back(entry->tanh());
 
     return Vector<T>(out);
 }
diff --git a/include/nn/layers/Dropout.hpp b/include/nn/layers/Dropout.hpp
index c3a6063..be3966a 100644
--- a/include/nn/layers/Dropout.hpp
+++ b/include/nn/layers/Dropout.hpp
@@ -45,8 +45,8 @@ template <typename T> DropoutPtr<T> Dropout<T>::create(size_t input, size_t size
 template <typename T> Vector<T> Dropout<T>::operator()(const Vector<T> &x) const {
     std::vector<ValuePtr<T>> alteredInput;
     alteredInput.reserve(x.size());
-    for (size_t i = 0; i < x.size(); ++i)
-        alteredInput.push_back(x[i]);
+    for (const ValuePtr<T> &val : x)
+        alteredInput.push_back(val);
 
     std::vector<size_t> indicesToRemove = utils::sample<size_t>(0, x.size() - 1, _dropout * x.size(), false);
     for (size_t idxToRemove : indicesToRemove)