Making WorkQueue push_task/pop_task lock-free
Summary:
This is a relatively straightforward change that makes the `push_task` and `pop_task` operations lock-free.

To achieve this, we split the `m_queue_mtx`-protected `m_queue` into two parts:
- A vector of initial tasks, `m_initial_tasks`. This is drained by atomically incrementing the `m_next_initial_task` counter until it reaches the (immutable) `m_initial_tasks.size()`.
- A linked list of additionally pushed tasks, `m_additional_tasks`. This is mutated via atomic `compare_exchange` operations. One complication is that we can't eagerly destroy popped nodes, as racing pop operations might still reference them, so we maintain another linked list of erased nodes that eventually get destroyed in the destructor (see the sketch below).

This is a behavior-preserving change.
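
The scheme is, in effect, a Treiber stack for the additional tasks combined with an atomic cursor over the initial vector. Below is a minimal, self-contained sketch of that scheme, not SPARTA's actual code: the name `LockFreeTaskPool` is made up, `std::optional` stands in for `boost::optional`, and the worker/state-counter bookkeeping is omitted.

#include <atomic>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

template <typename Task>
class LockFreeTaskPool {
 public:
  explicit LockFreeTaskPool(std::vector<Task> initial)
      : m_initial_tasks(std::move(initial)) {}

  ~LockFreeTaskPool() {
    free_list(m_additional_tasks.load());
    free_list(m_erased.load());
  }

  // Lock-free push: CAS a new node onto the head of the additional list.
  // On CAS failure, node->prev is reloaded with the current head, so the
  // loop simply retries.
  void push(Task task) {
    auto* node = new Node{std::move(task), m_additional_tasks.load()};
    while (!m_additional_tasks.compare_exchange_strong(node->prev, node)) {
    }
  }

  std::optional<Task> pop() {
    // Phase 1: drain the immutable initial vector by atomically claiming
    // the next index.
    auto i = m_next_initial_task.load();
    const auto size = m_initial_tasks.size();
    while (i < size &&
           !m_next_initial_task.compare_exchange_strong(i, i + 1)) {
    }
    if (i < size) {
      return std::move(m_initial_tasks[i]);
    }
    // Phase 2: try to unlink the head of the additional list.
    auto* node = m_additional_tasks.load();
    while (node != nullptr) {
      if (m_additional_tasks.compare_exchange_strong(node, node->prev)) {
        Task task = std::move(node->task);
        // Racing pops may still be reading node->prev, so defer deletion:
        // stash the node on an "erased" list freed in the destructor.
        node->prev = m_erased.load();
        while (!m_erased.compare_exchange_strong(node->prev, node)) {
        }
        return task;
      }
      // CAS failure reloaded `node`; retry with the new head.
    }
    return std::nullopt;
  }

 private:
  struct Node {
    Task task;
    Node* prev;
  };

  static void free_list(Node* n) {
    while (n != nullptr) {
      Node* prev = n->prev;
      delete n;
      n = prev;
    }
  }

  std::vector<Task> m_initial_tasks;
  std::atomic<std::size_t> m_next_initial_task{0};
  std::atomic<Node*> m_additional_tasks{nullptr};
  std::atomic<Node*> m_erased{nullptr};
};

Note that deferring deletion to the destructor not only keeps racing pops safe when they dereference `node->prev`; it also sidesteps the classic ABA problem on the head pointer, since no node is ever freed (and hence no node address recycled) while the pool is live.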

Reviewed By: yuxuanchen1997

Differential Revision: D50952675

fbshipit-source-id: 4d438b295b117ea8c78d700c3a3f943dbda6736b
Nikolai Tillmann authored and facebook-github-bot committed Nov 17, 2023
1 parent 1169daf commit 036aca0
Showing 2 changed files with 115 additions and 32 deletions.
120 changes: 88 additions & 32 deletions include/sparta/WorkQueue.h
@@ -16,7 +16,6 @@
 #include <exception>
 #include <mutex>
 #include <numeric>
-#include <queue>
 #include <random>
 #include <thread>
 #include <utility>
@@ -95,19 +94,22 @@ class Semaphore {
 };

 struct StateCounters {
-  std::atomic_uint num_non_empty;
+  std::atomic_uint num_non_empty_initial;
+  std::atomic_uint num_non_empty_additional;
   std::atomic_uint num_running;
   const unsigned int num_all;
   // Mutexes aren't move-able.
   std::unique_ptr<Semaphore> waiter;

   explicit StateCounters(unsigned int num)
-      : num_non_empty(0),
+      : num_non_empty_initial(0),
+        num_non_empty_additional(0),
         num_running(0),
         num_all(num),
         waiter(new Semaphore(0)) {}
   StateCounters(StateCounters&& other)
-      : num_non_empty(other.num_non_empty.load()),
+      : num_non_empty_initial(other.num_non_empty_initial.load()),
+        num_non_empty_additional(other.num_non_empty_additional.load()),
         num_running(other.num_running.load()),
         num_all(other.num_all),
         waiter(std::move(other.waiter)) {}
@@ -131,48 +133,96 @@ class WorkerState final {
    */
   void push_task(Input task) {
     assert(m_can_push_task);
-    std::lock_guard<std::mutex> guard(m_queue_mtx);
-    if (m_queue.empty()) {
-      ++m_state_counters->num_non_empty;
-    }
-    if (m_state_counters->num_running < m_state_counters->num_all) {
+    auto* node = new Node{std::move(task), m_additional_tasks.load()};
+    do {
+      if (node->prev == nullptr) {
+        // Increment before updating additional_tasks, as some other thread
+        // might steal our new additional task immediately
+        m_state_counters->num_non_empty_additional.fetch_add(1);
+      }
+      // Try to add a new head to the list
+    } while (!m_additional_tasks.compare_exchange_strong(node->prev, node));
+    if (m_state_counters->num_running.load() < m_state_counters->num_all) {
       m_state_counters->waiter->give(1u); // May consider waking all.
     }
-    m_queue.push(std::move(task));
   }

   size_t worker_id() const { return m_id; }

   void set_running(bool running) {
     if (m_running && !running) {
-      assert(m_state_counters->num_running > 0);
-      --m_state_counters->num_running;
+      auto num = m_state_counters->num_running.fetch_sub(1);
+      assert(num > 0);
     } else if (!m_running && running) {
-      ++m_state_counters->num_running;
+      m_state_counters->num_running.fetch_add(1);
     }
     m_running = running;
   };

+  ~WorkerState() {
+    for (auto* erased = m_erased.load(); erased != nullptr;) {
+      auto* prev = erased->prev;
+      delete erased;
+      erased = prev;
+    }
+  }
+
  private:
   boost::optional<Input> pop_task(WorkerState<Input>* other) {
-    std::lock_guard<std::mutex> guard(m_queue_mtx);
-    if (!m_queue.empty()) {
+    auto i = m_next_initial_task.load();
+    auto size = m_initial_tasks.size();
+    // If i < size, (try to) increment.
+    while (i < size &&
+           !m_next_initial_task.compare_exchange_strong(i, i + 1)) {
+    }
+    // If we successfully incremented, we can pop.
+    if (i < size) {
+      if (size - 1 == i) {
+        auto num = m_state_counters->num_non_empty_initial.fetch_sub(1);
+        assert(num > 0);
+      }
       other->set_running(true);
-      if (m_queue.size() == 1) {
-        assert(m_state_counters->num_non_empty > 0);
-        --m_state_counters->num_non_empty;
+      return boost::optional<Input>(std::move(m_initial_tasks.at(i)));
     }
+
+    auto* node = m_additional_tasks.load();
+    // Try to remove head from list
+    while (node != nullptr) {
+      bool exchanged =
+          m_additional_tasks.compare_exchange_strong(node, node->prev);
+      if (exchanged) {
+        // We successfully dequeued an element,
+        // node holds the element we intend to remove.
+        if (node->prev == nullptr) {
+          auto num = m_state_counters->num_non_empty_additional.fetch_sub(1);
+          assert(num > 0);
+        }
+        // We can't just delete the node right here, as there may be racing
+        // pop_tasks that read the `->prev` field above. So we stack it away in
+        // a different list that gets destructed later.
+        node->prev = m_erased.load();
+        while (!m_erased.compare_exchange_strong(node->prev, node)) {
+        }
+        other->set_running(true);
+        return boost::optional<Input>(std::move(node->task));
+      }
-      auto task = std::move(m_queue.front());
-      m_queue.pop();
-      return boost::optional<Input>(std::move(task));
+      // Otherwise, we depend on the behaviour of
+      // compare_exchange_strong to update `node` with the actual
+      // contained value.
     }
+
     return boost::none;
   }

   size_t m_id;
   bool m_running{false};
-  std::queue<Input> m_queue;
-  std::mutex m_queue_mtx;
+  std::vector<Input> m_initial_tasks;
+  std::atomic<size_t> m_next_initial_task{0};
+  struct Node {
+    Input task;
+    Node* prev;
+  };
+  std::atomic<Node*> m_additional_tasks{nullptr};
+  std::atomic<Node*> m_erased{nullptr};
   workqueue_impl::StateCounters* m_state_counters;
   const bool m_can_push_task{false};

@@ -252,13 +302,13 @@ template <class Input, typename Executor>
 void WorkQueue<Input, Executor>::add_item(Input task) {
   m_insert_idx = (m_insert_idx + 1) % m_num_threads;
   assert(m_insert_idx < m_states.size());
-  m_states[m_insert_idx]->m_queue.push(std::move(task));
+  m_states[m_insert_idx]->m_initial_tasks.push_back(std::move(task));
 }

 template <class Input, typename Executor>
 void WorkQueue<Input, Executor>::add_item(Input task, size_t worker_id) {
   assert(worker_id < m_states.size());
-  m_states[worker_id]->m_queue.push(std::move(task));
+  m_states[worker_id]->m_initial_tasks.push_back(std::move(task));
 }

 /*
@@ -267,8 +317,9 @@ void WorkQueue<Input, Executor>::add_item(Input task, size_t worker_id) {
  */
 template <class Input, typename Executor>
 void WorkQueue<Input, Executor>::run_all() {
-  m_state_counters.num_non_empty = 0;
-  m_state_counters.num_running = 0;
+  m_state_counters.num_non_empty_initial.store(0, std::memory_order_relaxed);
+  m_state_counters.num_non_empty_additional.store(0, std::memory_order_relaxed);
+  m_state_counters.num_running.store(0, std::memory_order_relaxed);
   m_state_counters.waiter->take_all();
   std::mutex exception_mutex;
   std::exception_ptr exception;
@@ -300,8 +351,9 @@ void WorkQueue<Input, Executor>::run_all() {

         // Let the thread quit if all the threads are not running and there
         // is no task in any queue.
-        if (m_state_counters.num_running == 0 &&
-            m_state_counters.num_non_empty == 0) {
+        if (m_state_counters.num_running.load() == 0 &&
+            m_state_counters.num_non_empty_initial.load() == 0 &&
+            m_state_counters.num_non_empty_additional.load() == 0) {
           // Wake up everyone who might be waiting, so they can quit.
           m_state_counters.waiter->give(m_state_counters.num_all);
           return;
@@ -336,8 +388,9 @@ void WorkQueue<Input, Executor>::run_all() {
   };

   for (size_t i = 0; i < m_num_threads; ++i) {
-    if (!m_states[i]->m_queue.empty()) {
-      ++m_state_counters.num_non_empty;
+    if (!m_states[i]->m_initial_tasks.empty()) {
+      m_state_counters.num_non_empty_initial.fetch_add(
+          1, std::memory_order_relaxed);
     }
   }

@@ -352,7 +405,10 @@ void WorkQueue<Input, Executor>::run_all() {
   }

   for (size_t i = 0; i < m_num_threads; ++i) {
-    assert(m_states[i]->m_queue.empty());
+    assert(m_states[i]->m_next_initial_task.load(std::memory_order_relaxed) ==
+           m_states[i]->m_initial_tasks.size());
+    assert(m_states[i]->m_additional_tasks.load(std::memory_order_relaxed) ==
+           nullptr);
   }
 }

27 changes: 27 additions & 0 deletions test/WorkQueueTest.cpp
@@ -97,6 +97,33 @@ TEST(WorkQueueTest, checkDynamicallyAddingTasks) {
   EXPECT_EQ(55, result);
 }

+// Similar to checkDynamicallyAddingTasks, but does much more work.
+TEST(WorkQueueTest, stress) {
+  for (size_t num_threads = 8; num_threads <= 128; num_threads *= 2) {
+    std::atomic<int> result{0};
+    auto wq = sparta::work_queue<int>(
+        [&](sparta::WorkerState<int>* worker_state, int a) {
+          if (a > 0) {
+            worker_state->push_task(a - 1);
+            result++;
+          }
+        },
+        num_threads,
+        /*push_tasks_while_running=*/true);
+    const size_t N = 200;
+    for (size_t i = 0; i <= N; i++) {
+      wq.add_item(10 * i);
+    }
+    wq.run_all();
+
+    // 10 * N + 10 * (N - 1) + ... + 10
+    // = 10 * (N + (N - 1) + ... + 1 + 0)
+    // = 10 * (N * (N + 1) / 2)
+    // = 201000 // for N = 200
+    EXPECT_EQ(201000, result);
+  }
+}
+
 TEST(WorkQueueTest, preciseScheduling) {
   std::array<int, NUM_INTS> array = {0};
