Skip to content

Commit

Permalink
Iterative Topological Sorting
Browse files Browse the repository at this point in the history
  • Loading branch information
fszewczyk committed Nov 9, 2023
1 parent 6600d5a commit f83d4e8
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 12 deletions.
1 change: 1 addition & 0 deletions include/ShkyeraGrad.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#pragma once

#include "core/Type.hpp"
#include "core/Utils.hpp"
#include "core/Value.hpp"
#include "core/Vector.hpp"

Expand Down
10 changes: 10 additions & 0 deletions include/core/Utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#pragma once

#include <algorithm>
#include <chrono>
#include <random>
#include <vector>

Expand Down Expand Up @@ -52,4 +53,13 @@ std::enable_if_t<std::is_integral_v<T>, std::vector<T>> sample(T from, T to, siz
return sampled;
}

template <typename Clock = std::chrono::high_resolution_clock> auto startTimer() { return Clock::now(); }

template <typename Clock = std::chrono::high_resolution_clock>
double stopTimer(const typename Clock::time_point &start) {
auto end = Clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
return static_cast<double>(duration.count()) / 1e6;
}

} // namespace shkyera::utils
43 changes: 31 additions & 12 deletions include/core/Value.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
#include <functional>
#include <iostream>
#include <memory>
#include <stack>
#include <unordered_set>
#include <vector>

#include "Type.hpp"
#include "Utils.hpp"

namespace shkyera {

Expand All @@ -38,7 +40,8 @@ template <typename T> class Value : public std::enable_shared_from_this<Value<T>
Value(T data);

std::vector<ValuePtr<T>> topologicalSort();
std::vector<ValuePtr<T>> topologicalSort(std::vector<ValuePtr<T>> &sorted, std::unordered_set<Value<T> *> &visited);

inline static double topoSortTime = 0;

public:
friend class Optimizer<T>;
Expand All @@ -51,6 +54,8 @@ template <typename T> class Value : public std::enable_shared_from_this<Value<T>
T getValue();
T getGradient();

static double getTopoTime() { return topoSortTime; }

ValuePtr<T> tanh();
ValuePtr<T> relu();
ValuePtr<T> sigmoid();
Expand Down Expand Up @@ -188,23 +193,37 @@ template <typename T> ValuePtr<T> Value<T>::pow(ValuePtr<T> exponent) {
}

template <typename T> std::vector<ValuePtr<T>> Value<T>::topologicalSort() {
auto timer = utils::startTimer();

std::vector<ValuePtr<T>> sorted;
std::unordered_set<Value<T> *> visited;

return topologicalSort(sorted, visited);
}

template <typename T>
std::vector<ValuePtr<T>> Value<T>::topologicalSort(std::vector<ValuePtr<T>> &sorted,
std::unordered_set<Value<T> *> &visited) {
if (visited.find(this) == visited.end()) {
visited.insert(this);
for (ValuePtr<T> val : _children) {
val->topologicalSort(sorted, visited);
std::stack<Value<T> *> stack;
stack.push(this);

while (!stack.empty()) {
auto cur = stack.top();
if (visited.find(cur) == visited.end()) {
bool hasUnvisitedChildren = false;
for (auto v : cur->_children) {
if (visited.find(v.get()) == visited.end()) {
stack.push(v.get());
hasUnvisitedChildren = true;
}
}

if (!hasUnvisitedChildren) {
stack.pop();
sorted.push_back(cur->shared_from_this());
visited.insert(cur);
}
} else {
stack.pop();
}
sorted.push_back(this->shared_from_this());
}

topoSortTime += utils::stopTimer(timer);

return sorted;
}

Expand Down

0 comments on commit f83d4e8

Please sign in to comment.