diff --git a/include/sparta/WorkQueue.h b/include/sparta/WorkQueue.h index e3070cc..c15179f 100644 --- a/include/sparta/WorkQueue.h +++ b/include/sparta/WorkQueue.h @@ -16,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -95,19 +94,22 @@ class Semaphore { }; struct StateCounters { - std::atomic_uint num_non_empty; + std::atomic_uint num_non_empty_initial; + std::atomic_uint num_non_empty_additional; std::atomic_uint num_running; const unsigned int num_all; // Mutexes aren't move-able. std::unique_ptr waiter; explicit StateCounters(unsigned int num) - : num_non_empty(0), + : num_non_empty_initial(0), + num_non_empty_additional(0), num_running(0), num_all(num), waiter(new Semaphore(0)) {} StateCounters(StateCounters&& other) - : num_non_empty(other.num_non_empty.load()), + : num_non_empty_initial(other.num_non_empty_initial.load()), + num_non_empty_additional(other.num_non_empty_additional.load()), num_running(other.num_running.load()), num_all(other.num_all), waiter(std::move(other.waiter)) {} @@ -131,48 +133,96 @@ class WorkerState final { */ void push_task(Input task) { assert(m_can_push_task); - std::lock_guard guard(m_queue_mtx); - if (m_queue.empty()) { - ++m_state_counters->num_non_empty; - } - if (m_state_counters->num_running < m_state_counters->num_all) { + auto* node = new Node{std::move(task), m_additional_tasks.load()}; + do { + if (node->prev == nullptr) { + // Increment before updating additional_tasks, as some other thread + // might steal our new additional task immediately + m_state_counters->num_non_empty_additional.fetch_add(1); + } + // Try to add a new head to the list + } while (!m_additional_tasks.compare_exchange_strong(node->prev, node)); + if (m_state_counters->num_running.load() < m_state_counters->num_all) { m_state_counters->waiter->give(1u); // May consider waking all. } - m_queue.push(std::move(task)); } size_t worker_id() const { return m_id; } void set_running(bool running) { if (m_running && !running) { - assert(m_state_counters->num_running > 0); - --m_state_counters->num_running; + auto num = m_state_counters->num_running.fetch_sub(1); + assert(num > 0); } else if (!m_running && running) { - ++m_state_counters->num_running; + m_state_counters->num_running.fetch_add(1); } m_running = running; }; + ~WorkerState() { + for (auto* erased = m_erased.load(); erased != nullptr;) { + auto* prev = erased->prev; + delete erased; + erased = prev; + } + } + private: boost::optional pop_task(WorkerState* other) { - std::lock_guard guard(m_queue_mtx); - if (!m_queue.empty()) { + auto i = m_next_initial_task.load(); + auto size = m_initial_tasks.size(); + // If i < size, (try to) increment. + while (i < size && !m_next_initial_task.compare_exchange_strong(i, i + 1)) { + } + // If we successfully incremented, we can pop. + if (i < size) { + if (size - 1 == i) { + auto num = m_state_counters->num_non_empty_initial.fetch_sub(1); + assert(num > 0); + } other->set_running(true); - if (m_queue.size() == 1) { - assert(m_state_counters->num_non_empty > 0); - --m_state_counters->num_non_empty; + return boost::optional(std::move(m_initial_tasks.at(i))); + } + + auto* node = m_additional_tasks.load(); + // Try to remove head from list + while (node != nullptr) { + bool exchanged = + m_additional_tasks.compare_exchange_strong(node, node->prev); + if (exchanged) { + // We successfully dequeued an element, + // node holds the element we intend to remove. + if (node->prev == nullptr) { + auto num = m_state_counters->num_non_empty_additional.fetch_sub(1); + assert(num > 0); + } + // We can't just delete the node right here, as there may be racing + // pop_tasks that read the `->prev` field above. So we stack it away in + // a different list that gets destructed later. + node->prev = m_erased.load(); + while (!m_erased.compare_exchange_strong(node->prev, node)) { + } + other->set_running(true); + return boost::optional(std::move(node->task)); } - auto task = std::move(m_queue.front()); - m_queue.pop(); - return boost::optional(std::move(task)); + // Otherwise, we depend on the behaviour of + // compare_exchange_strong to update `node` with the actual + // contained value. } + return boost::none; } size_t m_id; bool m_running{false}; - std::queue m_queue; - std::mutex m_queue_mtx; + std::vector m_initial_tasks; + std::atomic m_next_initial_task{0}; + struct Node { + Input task; + Node* prev; + }; + std::atomic m_additional_tasks{nullptr}; + std::atomic m_erased{nullptr}; workqueue_impl::StateCounters* m_state_counters; const bool m_can_push_task{false}; @@ -252,13 +302,13 @@ template void WorkQueue::add_item(Input task) { m_insert_idx = (m_insert_idx + 1) % m_num_threads; assert(m_insert_idx < m_states.size()); - m_states[m_insert_idx]->m_queue.push(std::move(task)); + m_states[m_insert_idx]->m_initial_tasks.push_back(std::move(task)); } template void WorkQueue::add_item(Input task, size_t worker_id) { assert(worker_id < m_states.size()); - m_states[worker_id]->m_queue.push(std::move(task)); + m_states[worker_id]->m_initial_tasks.push_back(std::move(task)); } /* @@ -267,8 +317,9 @@ void WorkQueue::add_item(Input task, size_t worker_id) { */ template void WorkQueue::run_all() { - m_state_counters.num_non_empty = 0; - m_state_counters.num_running = 0; + m_state_counters.num_non_empty_initial.store(0, std::memory_order_relaxed); + m_state_counters.num_non_empty_additional.store(0, std::memory_order_relaxed); + m_state_counters.num_running.store(0, std::memory_order_relaxed); m_state_counters.waiter->take_all(); std::mutex exception_mutex; std::exception_ptr exception; @@ -300,8 +351,9 @@ void WorkQueue::run_all() { // Let the thread quit if all the threads are not running and there // is no task in any queue. - if (m_state_counters.num_running == 0 && - m_state_counters.num_non_empty == 0) { + if (m_state_counters.num_running.load() == 0 && + m_state_counters.num_non_empty_initial.load() == 0 && + m_state_counters.num_non_empty_additional.load() == 0) { // Wake up everyone who might be waiting, so they can quit. m_state_counters.waiter->give(m_state_counters.num_all); return; @@ -336,8 +388,9 @@ void WorkQueue::run_all() { }; for (size_t i = 0; i < m_num_threads; ++i) { - if (!m_states[i]->m_queue.empty()) { - ++m_state_counters.num_non_empty; + if (!m_states[i]->m_initial_tasks.empty()) { + m_state_counters.num_non_empty_initial.fetch_add( + 1, std::memory_order_relaxed); } } @@ -352,7 +405,10 @@ void WorkQueue::run_all() { } for (size_t i = 0; i < m_num_threads; ++i) { - assert(m_states[i]->m_queue.empty()); + assert(m_states[i]->m_next_initial_task.load(std::memory_order_relaxed) == + m_states[i]->m_initial_tasks.size()); + assert(m_states[i]->m_additional_tasks.load(std::memory_order_relaxed) == + nullptr); } } diff --git a/test/WorkQueueTest.cpp b/test/WorkQueueTest.cpp index 234832a..2dc7f73 100644 --- a/test/WorkQueueTest.cpp +++ b/test/WorkQueueTest.cpp @@ -97,6 +97,33 @@ TEST(WorkQueueTest, checkDynamicallyAddingTasks) { EXPECT_EQ(55, result); } +// Similar checkDynamicallyAddingTasks, but do much more work. +TEST(WorkQueueTest, stress) { + for (size_t num_threads = 8; num_threads <= 128; num_threads *= 2) { + std::atomic result{0}; + auto wq = sparta::work_queue( + [&](sparta::WorkerState* worker_state, int a) { + if (a > 0) { + worker_state->push_task(a - 1); + result++; + } + }, + num_threads, + /*push_tasks_while_running=*/true); + const size_t N = 200; + for (size_t i = 0; i <= N; i++) { + wq.add_item(10 * i); + } + wq.run_all(); + + // 10 * N + 10 * (N - 1) + ... + 10 + // = 10 * (N + (N - 1) + ... + 1 + 0) + // = 10 * (N * (N + 1) / 2) + // = 201000 // for N = 200 + EXPECT_EQ(201000, result); + } +} + TEST(WorkQueueTest, preciseScheduling) { std::array array = {0};