Skip to content

Commit

Permalink
Plug thread pool inside sequence and pipeline.
Browse files Browse the repository at this point in the history
  • Loading branch information
kouchy committed Nov 9, 2024
1 parent 7da9afc commit 1218e05
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 69 deletions.
2 changes: 2 additions & 0 deletions include/Runtime/Pipeline/Pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "Module/Stateful/Adaptor/Adaptor.hpp"
#include "Runtime/Sequence/Sequence.hpp"
#include "Tools/Interface/Interface_get_set_n_frames.hpp"
#include "Tools/Thread/Thread_pool/Thread_pool.hpp"

namespace spu
{
Expand All @@ -24,6 +25,7 @@ class Pipeline : public tools::Interface_get_set_n_frames
protected:
Sequence original_sequence;
std::vector<std::shared_ptr<Sequence>> stages;
std::shared_ptr<tools::Thread_pool> thread_pool;
// clang-format off
std::vector<std::pair<std::vector<std::shared_ptr<module::Adaptor>>,
std::vector<std::shared_ptr<module::Adaptor>>>> adaptors;
Expand Down
5 changes: 4 additions & 1 deletion include/Runtime/Sequence/Sequence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "Tools/Interface/Interface_clone.hpp"
#include "Tools/Interface/Interface_get_set_n_frames.hpp"
#include "Tools/Interface/Interface_is_done.hpp"
#include "Tools/Thread/Thread_pool/Thread_pool.hpp"

namespace spu
{
Expand Down Expand Up @@ -77,7 +78,9 @@ class Sequence
friend sched::Scheduler;

protected:
size_t n_threads;
const size_t n_threads;
std::shared_ptr<tools::Thread_pool> thread_pool;

std::vector<tools::Digraph_node<Sub_sequence>*> sequences;
std::vector<size_t> firsts_tasks_id;
std::vector<size_t> lasts_tasks_id;
Expand Down
120 changes: 66 additions & 54 deletions src/Runtime/Pipeline/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "Tools/Exception/exception.hpp"
#include "Tools/Interface/Interface_waiting.hpp"
#include "Tools/Thread/Thread_pinning/Thread_pinning_utils.hpp"
#include "Tools/Thread/Thread_pool/Standard/Thread_pool_standard.hpp"

using namespace spu;
using namespace spu::runtime;
Expand Down Expand Up @@ -602,6 +603,9 @@ ::init(const std::vector<TA*> &firsts,

this->create_adaptors(synchro_buffer_sizes, synchro_active_waiting);
this->bind_adaptors();

this->thread_pool.reset(new tools::Thread_pool_standard(this->stages.size() - 1));
this->thread_pool->init(); // threads are spawned here
}

void
Expand Down Expand Up @@ -1240,41 +1244,46 @@ Pipeline::exec(const std::vector<std::function<bool(const std::vector<const int*
throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
}

// ----------------------------------------------------------------------------------------------------------------
auto& stages = this->stages;
std::vector<std::thread> threads;
for (size_t s = 0; s < stages.size() - 1; s++)
std::vector<const std::function<bool(const std::vector<const int*>&)>*> stop_condition_vec(stages.size() - 1,
nullptr);
if (stop_conditions.size() == stages.size())
for (size_t s = 0; s < stages.size() - 1; s++)
stop_condition_vec[s] = &stop_conditions[s];

std::function<void(const size_t)> func_exec = [&stages, &stop_condition_vec](const size_t tid)
{
const std::function<bool(const std::vector<const int*>&)>* stop_condition = nullptr;
if (stop_conditions.size() == this->stages.size()) stop_condition = &stop_conditions[s];

threads.push_back(std::thread(
[&stages, s, stop_condition]()
{
if (stop_condition)
stages[s]->exec(*stop_condition);
else
stages[s]->exec();

// send the signal to stop the next stage
const auto& tasks = stages[s + 1]->get_tasks_per_threads();
for (size_t th = 0; th < tasks.size(); th++)
for (size_t ta = 0; ta < tasks[th].size(); ta++)
{
auto m = dynamic_cast<module::Adaptor*>(&tasks[th][ta]->get_module());
if (m != nullptr)
if (tasks[th][ta]->get_name() == "pull_n" || tasks[th][ta]->get_name() == "pull_1")
m->cancel_waiting();
}
}));
}
size_t s = tid;
if (stop_condition_vec[s])
stages[s]->exec(*(stop_condition_vec[s]));
else
stages[s]->exec();

// send the signal to stop the next stage
const auto& tasks = stages[s + 1]->get_tasks_per_threads();
for (size_t th = 0; th < tasks.size(); th++)
for (size_t ta = 0; ta < tasks[th].size(); ta++)
{
auto m = dynamic_cast<module::Adaptor*>(&tasks[th][ta]->get_module());
if (m != nullptr)
if (tasks[th][ta]->get_name() == "pull_n" || tasks[th][ta]->get_name() == "pull_1")
m->cancel_waiting();
}
};

this->thread_pool->run(func_exec, true);

stages[stages.size() - 1]->exec(stop_conditions[stop_conditions.size() - 1]);

// stop all the stages before
for (size_t notify_s = 0; notify_s < stages.size() - 1; notify_s++)
for (auto& m : stages[notify_s]->get_modules<tools::Interface_waiting>())
m->cancel_waiting();

for (auto& t : threads)
t.join();
this->thread_pool->wait();
this->thread_pool->unset_func_exec();
// ----------------------------------------------------------------------------------------------------------------

// this is NOT made in the tools::Sequence::exec() to correctly flush the pipeline before restoring buffers
// initial configuration
Expand Down Expand Up @@ -1312,41 +1321,44 @@ Pipeline::exec(const std::vector<std::function<bool()>>& stop_conditions)
throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
}

// ----------------------------------------------------------------------------------------------------------------
auto& stages = this->stages;
std::vector<std::thread> threads;
for (size_t s = 0; s < stages.size() - 1; s++)
std::vector<const std::function<bool()>*> stop_condition_vec(stages.size() - 1, nullptr);
if (stop_conditions.size() == stages.size())
for (size_t s = 0; s < stages.size() - 1; s++)
stop_condition_vec[s] = &stop_conditions[s];

std::function<void(const size_t)> func_exec = [&stages, &stop_condition_vec](const size_t tid)
{
const std::function<bool()>* stop_condition = nullptr;
if (stop_conditions.size() == this->stages.size()) stop_condition = &stop_conditions[s];

threads.push_back(std::thread(
[&stages, s, stop_condition]()
{
if (stop_condition)
stages[s]->exec(*stop_condition);
else
stages[s]->exec();

// send the signal to stop the next stage
const auto& tasks = stages[s + 1]->get_tasks_per_threads();
for (size_t th = 0; th < tasks.size(); th++)
for (size_t ta = 0; ta < tasks[th].size(); ta++)
{
auto m = dynamic_cast<module::Adaptor*>(&tasks[th][ta]->get_module());
if (m != nullptr)
if (tasks[th][ta]->get_name() == "pull_n" || tasks[th][ta]->get_name() == "pull_1")
m->cancel_waiting();
}
}));
}
size_t s = tid;
if (stop_condition_vec[s])
stages[s]->exec(*(stop_condition_vec[s]));
else
stages[s]->exec();

// send the signal to stop the next stage
const auto& tasks = stages[s + 1]->get_tasks_per_threads();
for (size_t th = 0; th < tasks.size(); th++)
for (size_t ta = 0; ta < tasks[th].size(); ta++)
{
auto m = dynamic_cast<module::Adaptor*>(&tasks[th][ta]->get_module());
if (m != nullptr)
if (tasks[th][ta]->get_name() == "pull_n" || tasks[th][ta]->get_name() == "pull_1")
m->cancel_waiting();
}
};

this->thread_pool->run(func_exec, true);
stages[stages.size() - 1]->exec(stop_conditions[stop_conditions.size() - 1]);

// stop all the stages before
for (size_t notify_s = 0; notify_s < stages.size() - 1; notify_s++)
for (auto& m : stages[notify_s]->get_modules<tools::Interface_waiting>())
m->cancel_waiting();

for (auto& t : threads)
t.join();
this->thread_pool->wait();
this->thread_pool->unset_func_exec();
// ----------------------------------------------------------------------------------------------------------------

// this is NOT made in the tools::Sequence::exec() to correctly flush the pipeline before restoring buffers
// initial configuration
Expand Down
28 changes: 14 additions & 14 deletions src/Runtime/Sequence/Sequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "Tools/Signal_handler/Signal_handler.hpp"
#include "Tools/Thread/Thread_pinning/Thread_pinning.hpp"
#include "Tools/Thread/Thread_pinning/Thread_pinning_utils.hpp"
#include "Tools/Thread/Thread_pool/Standard/Thread_pool_standard.hpp"

using namespace spu;
using namespace spu::runtime;
Expand Down Expand Up @@ -509,6 +510,9 @@ Sequence::init(const std::vector<TA*>& firsts, const std::vector<TA*>& lasts, co

for (size_t tid = 0; tid < this->sequences.size(); tid++)
this->cur_ss[tid] = this->sequences[tid];

this->thread_pool.reset(new tools::Thread_pool_standard(this->n_threads - 1));
this->thread_pool->init(); // threads are spawned here
}

Sequence*
Expand Down Expand Up @@ -846,15 +850,14 @@ Sequence::exec(std::function<bool(const std::vector<const int*>&)> stop_conditio
else
real_stop_condition = stop_condition;

std::vector<std::thread> threads(n_threads);
for (size_t tid = 1; tid < n_threads; tid++)
threads[tid] =
std::thread(&Sequence::_exec, this, tid, std::ref(real_stop_condition), std::ref(this->sequences[tid]));
std::function<void(const size_t)> func_exec = [this, &real_stop_condition](const size_t tid)
{ this->Sequence::_exec(tid + 1, real_stop_condition, this->sequences[tid + 1]); };

this->thread_pool->run(func_exec, true);
this->_exec(0, real_stop_condition, this->sequences[0]);
this->thread_pool->wait();

for (size_t tid = 1; tid < n_threads; tid++)
threads[tid].join();
this->thread_pool->unset_func_exec();

if (this->is_no_copy_mode() && !this->is_part_of_pipeline)
{
Expand Down Expand Up @@ -884,17 +887,14 @@ Sequence::exec(std::function<bool()> stop_condition)
else
real_stop_condition = stop_condition;

std::vector<std::thread> threads(n_threads);
for (size_t tid = 1; tid < n_threads; tid++)
{
threads[tid] = std::thread(
&Sequence::_exec_without_statuses, this, tid, std::ref(real_stop_condition), std::ref(this->sequences[tid]));
}
std::function<void(const size_t)> func_exec = [this, &real_stop_condition](const size_t tid)
{ this->Sequence::_exec_without_statuses(tid + 1, real_stop_condition, this->sequences[tid + 1]); };

this->thread_pool->run(func_exec, true);
this->_exec_without_statuses(0, real_stop_condition, this->sequences[0]);
this->thread_pool->wait();

for (size_t tid = 1; tid < n_threads; tid++)
threads[tid].join();
this->thread_pool->unset_func_exec();

if (this->is_no_copy_mode() && !this->is_part_of_pipeline)
{
Expand Down

0 comments on commit 1218e05

Please sign in to comment.