Skip to content

Commit

Permalink
Merge pull request #9 from fszewczyk/refactoring
Browse files Browse the repository at this point in the history
Vector Iterator
  • Loading branch information
fszewczyk authored Nov 8, 2023
2 parents cf2e100 + e30b15e commit 43f5fa6
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 22 deletions.
2 changes: 1 addition & 1 deletion examples/xor_classification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ int main() {
Loss::Function32 lossFunction = Loss::CrossEntropy<Type::float32>;

// ------ TRAINING THE NETWORK ------- //
for (size_t epoch = 0; epoch < 200; epoch++) {
for (size_t epoch = 0; epoch < 100; epoch++) {
auto epochLoss = Val32::create(0);

optimizer.reset();
Expand Down
41 changes: 40 additions & 1 deletion include/core/Vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
namespace shkyera {

template <typename T> class Vector;

using Vec32 = Vector<Type::float32>;
using Vec64 = Vector<Type::float64>;

Expand Down Expand Up @@ -43,6 +44,21 @@ template <typename T> class Vector {
Vector<T> &operator*=(ValuePtr<T> val);

ValuePtr<T> operator[](size_t index) const;

class ConstIterator {
private:
size_t _index;
const Vector<T> &_vector;

public:
ConstIterator(size_t index, const Vector<T> &vector);
const ValuePtr<T> operator*();
ConstIterator &operator++();
bool operator!=(const ConstIterator &other);
};

ConstIterator begin() const;
ConstIterator end() const;
};

template <typename T> Vector<T>::Vector(std::vector<ValuePtr<T>> values) { _values = values; }
Expand Down Expand Up @@ -134,11 +150,34 @@ template <typename T> ValuePtr<T> Vector<T>::operator[](size_t index) const { re
template <typename T> std::ostream &operator<<(std::ostream &os, const Vector<T> &vector) {
os << "Vector(size=" << vector.size() << ", data={";

for (const ValuePtr<T> val : vector._values)
for (auto val : vector)
os << val << ' ';

os << "})";
return os;
}

template <typename T> typename Vector<T>::ConstIterator Vector<T>::begin() const { return ConstIterator(0, *this); }
template <typename T> typename Vector<T>::ConstIterator Vector<T>::end() const { return ConstIterator(size(), *this); }

template <typename T>
Vector<T>::ConstIterator::ConstIterator(size_t index, const Vector<T> &vector) : _index(index), _vector(vector) {}

template <typename T> const ValuePtr<T> Vector<T>::ConstIterator::operator*() {
if (_index < _vector.size()) {
return _vector[_index];
}
throw std::out_of_range("Vector iterator out of range. Tried to access index " + std::to_string(_index) +
" in a Vector of size " + std::to_string(_vector.size()) + ".");
}

template <typename T> typename Vector<T>::ConstIterator &Vector<T>::ConstIterator::operator++() {
++_index;
return *this;
}

template <typename T> bool Vector<T>::ConstIterator::operator!=(const ConstIterator &other) {
return _index != other._index;
}

} // namespace shkyera
4 changes: 2 additions & 2 deletions include/nn/Neuron.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ template <typename T> std::vector<ValuePtr<T>> Neuron<T>::parameters() const {
std::vector<ValuePtr<T>> params;
params.reserve(_weights.size() + 1);

for (size_t i = 0; i < _weights.size(); ++i)
params.push_back(_weights[i]);
for (auto &w : _weights)
params.push_back(w);

params.push_back(_bias);

Expand Down
5 changes: 2 additions & 3 deletions include/nn/activation/Exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ template <typename T> Vector<T> Exp<T>::operator()(const Vector<T> &x) const {
std::vector<ValuePtr<T>> out;
out.reserve(x.size());

for (size_t i = 0; i < x.size(); ++i) {
out.emplace_back(x[i]->exp());
}
for (auto &entry : x)
out.emplace_back(entry->exp());

return Vector<T>(out);
}
Expand Down
5 changes: 2 additions & 3 deletions include/nn/activation/ReLU.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ template <typename T> Vector<T> ReLU<T>::operator()(const Vector<T> &x) const {
std::vector<ValuePtr<T>> out;
out.reserve(x.size());

for (size_t i = 0; i < x.size(); ++i) {
out.emplace_back(x[i]->relu());
}
for (auto &entry : x)
out.emplace_back(entry->relu());

return Vector<T>(out);
}
Expand Down
4 changes: 2 additions & 2 deletions include/nn/activation/Sigmoid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ template <typename T> Vector<T> Sigmoid<T>::operator()(const Vector<T> &x) const
std::vector<ValuePtr<T>> out;
out.reserve(x.size());

for (size_t i = 0; i < x.size(); ++i) {
out.emplace_back(x[i]->sigmoid());
for (auto &entry : x) {
out.emplace_back(entry->sigmoid());
}

return Vector<T>(out);
Expand Down
10 changes: 5 additions & 5 deletions include/nn/activation/Softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ template <typename T> Vector<T> Softmax<T>::operator()(const Vector<T> &x) const
out.reserve(x.size());

auto maxValue = Value<T>::create(x[0]->getValue());
for (size_t i = 1; i < x.size(); ++i)
if (x[i] > maxValue)
maxValue = x[i];
for (auto &entry : x)
if (entry > maxValue)
maxValue = entry;

auto sumExponentiated = Value<T>::create(0);
for (size_t i = 0; i < x.size(); ++i) {
auto exponentiated = (x[i] - maxValue)->exp();
for (auto &entry : x) {
auto exponentiated = (entry - maxValue)->exp();
out.emplace_back(exponentiated);
sumExponentiated = sumExponentiated + exponentiated;
}
Expand Down
5 changes: 2 additions & 3 deletions include/nn/activation/Tanh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ template <typename T> Vector<T> Tanh<T>::operator()(const Vector<T> &x) const {
std::vector<ValuePtr<T>> out;
out.reserve(x.size());

for (size_t i = 0; i < x.size(); ++i) {
out.emplace_back(x[i]->tanh());
}
for (auto &entry : x)
out.emplace_back(entry->tanh());

return Vector<T>(out);
}
Expand Down
4 changes: 2 additions & 2 deletions include/nn/layers/Dropout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ template <typename T> DropoutPtr<T> Dropout<T>::create(size_t input, size_t size
template <typename T> Vector<T> Dropout<T>::operator()(const Vector<T> &x) const {
std::vector<ValuePtr<T>> alteredInput;
alteredInput.reserve(x.size());
for (size_t i = 0; i < x.size(); ++i)
alteredInput.push_back(x[i]);
for (const ValuePtr<T> &val : x)
alteredInput.push_back(val);

std::vector<size_t> indicesToRemove = utils::sample<size_t>(0, x.size() - 1, _dropout * x.size(), false);
for (size_t idxToRemove : indicesToRemove)
Expand Down

0 comments on commit 43f5fa6

Please sign in to comment.