diff --git a/include/ShkyeraGrad.hpp b/include/ShkyeraGrad.hpp index 1c3a480..45541a4 100644 --- a/include/ShkyeraGrad.hpp +++ b/include/ShkyeraGrad.hpp @@ -8,6 +8,7 @@ #pragma once #include "core/Type.hpp" +#include "core/Utils.hpp" #include "core/Value.hpp" #include "core/Vector.hpp" diff --git a/include/core/Utils.hpp b/include/core/Utils.hpp index 9d31731..aeaa34d 100644 --- a/include/core/Utils.hpp +++ b/include/core/Utils.hpp @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include @@ -52,4 +53,13 @@ std::enable_if_t, std::vector> sample(T from, T to, siz return sampled; } +template auto startTimer() { return Clock::now(); } + +template +double stopTimer(const typename Clock::time_point &start) { + auto end = Clock::now(); + auto duration = std::chrono::duration_cast(end - start); + return static_cast(duration.count()) / 1e6; +} + } // namespace shkyera::utils diff --git a/include/core/Value.hpp b/include/core/Value.hpp index 08fcbcd..ea3b0c1 100644 --- a/include/core/Value.hpp +++ b/include/core/Value.hpp @@ -11,10 +11,12 @@ #include #include #include +#include #include #include #include "Type.hpp" +#include "Utils.hpp" namespace shkyera { @@ -38,7 +40,8 @@ template class Value : public std::enable_shared_from_this Value(T data); std::vector> topologicalSort(); - std::vector> topologicalSort(std::vector> &sorted, std::unordered_set *> &visited); + + inline static double topoSortTime = 0; public: friend class Optimizer; @@ -51,6 +54,8 @@ template class Value : public std::enable_shared_from_this T getValue(); T getGradient(); + static double getTopoTime() { return topoSortTime; } + ValuePtr tanh(); ValuePtr relu(); ValuePtr sigmoid(); @@ -188,23 +193,37 @@ template ValuePtr Value::pow(ValuePtr exponent) { } template std::vector> Value::topologicalSort() { + auto timer = utils::startTimer(); + std::vector> sorted; std::unordered_set *> visited; - return topologicalSort(sorted, visited); -} - -template -std::vector> Value::topologicalSort(std::vector> &sorted, - std::unordered_set *> &visited) { - if (visited.find(this) == visited.end()) { - visited.insert(this); - for (ValuePtr val : _children) { - val->topologicalSort(sorted, visited); + std::stack *> stack; + stack.push(this); + + while (!stack.empty()) { + auto cur = stack.top(); + if (visited.find(cur) == visited.end()) { + bool hasUnvisitedChildren = false; + for (auto v : cur->_children) { + if (visited.find(v.get()) == visited.end()) { + stack.push(v.get()); + hasUnvisitedChildren = true; + } + } + + if (!hasUnvisitedChildren) { + stack.pop(); + sorted.push_back(cur->shared_from_this()); + visited.insert(cur); + } + } else { + stack.pop(); } - sorted.push_back(this->shared_from_this()); } + topoSortTime += utils::stopTimer(timer); + return sorted; }