Skip to content

Commit

Permalink
Merge pull request #3 from DeveloperPaul123/feature/minor-improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
DeveloperPaul123 authored Dec 29, 2021
2 parents 6d01b06 + b7490c3 commit c1b91fd
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 53 deletions.
87 changes: 42 additions & 45 deletions include/thread_pool/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace dp {
return std::forward<T>(v);
}

// Bind F and args... into a nullary one-shot lambda. Lambda captures by value.
// bind F and parameter pack into a nullary one shot. Lambda captures by value.
template <typename... Args, typename F>
auto bind(F &&f, Args &&...args) {
return [f = decay_copy(std::forward<F>(f)),
Expand All @@ -28,63 +28,58 @@ namespace dp {
};
}

template <class Queue, class U = typename Queue::value_type>
concept is_valid_queue = requires(Queue q) {
{ q.empty() } -> std::convertible_to<bool>;
{ q.front() } -> std::convertible_to<U &>;
{ q.back() } -> std::convertible_to<U &>;
q.pop();
};

static_assert(detail::is_valid_queue<std::queue<int>>);
static_assert(detail::is_valid_queue<dp::thread_safe_queue<int>>);
} // namespace detail

template <template <class T> class Queue, typename FunctionType = std::function<void()>>
template <typename FunctionType = std::function<void()>>
requires std::invocable<FunctionType> &&
std::is_same_v<void, std::invoke_result_t<FunctionType>> &&
detail::is_valid_queue<Queue<FunctionType>>
class thread_pool_impl {
std::is_same_v<void, std::invoke_result_t<FunctionType>>
class thread_pool {
public:
thread_pool_impl(
const unsigned int &number_of_threads = std::thread::hardware_concurrency()) {
thread_pool(const unsigned int &number_of_threads = std::thread::hardware_concurrency())
: queues_(number_of_threads) {
for (std::size_t i = 0; i < number_of_threads; ++i) {
queues_.push_back(std::make_unique<task_pair>());
threads_.emplace_back([&, id = i](std::stop_token stop_tok) {
do {
// check if we have task
if (queues_[id]->tasks.empty()) {
if (queues_[id].tasks.empty()) {
// no tasks, so we wait instead of spinning
queues_[id]->semaphore.acquire();
queues_[id].semaphore.acquire();
}

// ensure we have a task before getting task
// since the dtor releases the semaphore as well
if (!queues_[id]->tasks.empty()) {
if (!queues_[id].tasks.empty()) {
// get the task
auto &task = queues_[id]->tasks.front();
auto &task = queues_[id].tasks.front();
// invoke the task
std::invoke(std::move(task));
// decrement in-flight counter
--in_flight_;
// remove task from the queue
queues_[id]->tasks.pop();
queues_[id].tasks.pop();
}
} while (!stop_tok.stop_requested());
});
}
}

~thread_pool_impl() {
~thread_pool() {
// wait for tasks to complete first
do {
std::this_thread::yield();
} while (in_flight_ > 0);

// stop all threads
for (std::size_t i = 0; i < threads_.size(); ++i) {
threads_[i].request_stop();
queues_[i]->semaphore.release();
queues_[i].semaphore.release();
threads_[i].join();
}
}

/// thread pool is non-copyable
thread_pool_impl(const thread_pool_impl &) = delete;
thread_pool_impl &operator=(const thread_pool_impl &) = delete;
thread_pool(const thread_pool &) = delete;
thread_pool &operator=(const thread_pool &) = delete;

/**
* @brief Enqueue a task into the thread pool that returns a result.
Expand All @@ -98,11 +93,21 @@ namespace dp {
template <typename Function, typename... Args,
typename ReturnType = std::invoke_result_t<Function &&, Args &&...>>
requires std::invocable<Function, Args...>
[[nodiscard]] std::future<ReturnType> enqueue(Function &&f, Args &&...args) {
// use shared promise here so that we don't break the promise later
[[nodiscard]] std::future<ReturnType> enqueue(Function f, Args... args) {
/*
* use shared promise here so that we don't break the promise later (until C++23)
*
* with C++23 we can do the following:
*
* std::promise<ReturnType> promise;
* auto future = promise.get_future();
* auto task = [func = std::move(f), ... largs = std::move(args),
promise = std::move(promise)]() mutable {...};
*/
auto shared_promise = std::make_shared<std::promise<ReturnType>>();
auto task = [func = std::move(f), ... largs = std::move(args),
promise = shared_promise]() { promise->set_value(func(largs...)); };

// get the future before enqueuing the task
auto future = shared_promise->get_future();
// enqueue the task
Expand All @@ -125,33 +130,25 @@ namespace dp {
}

private:
using semaphore_type = std::binary_semaphore;
using task_type = FunctionType;
struct task_pair {
semaphore_type semaphore{0};
Queue<task_type> tasks{};
struct task_queue {
std::binary_semaphore semaphore{0};
dp::thread_safe_queue<FunctionType> tasks{};
};

template <typename Function>
void enqueue_task(Function &&f) {
const std::size_t i = count_++ % queues_.size();
queues_[i]->tasks.push(std::forward<Function>(f));
queues_[i]->semaphore.release();
++in_flight_;
queues_[i].tasks.push(std::forward<Function>(f));
queues_[i].semaphore.release();
}

std::vector<std::jthread> threads_;
// have to use unique_ptr here because std::binary_semaphore is not move/copy
// assignable/constructible
std::vector<std::unique_ptr<task_pair>> queues_;
std::deque<task_queue> queues_;
std::size_t count_ = 0;
std::atomic<int64_t> in_flight_{0};
};

/**
* @brief Thread pool class capable of queuing detached tasks and value returning tasks.
* @details This is a default alias for the dp::thread_pool_impl
*/
using thread_pool = thread_pool_impl<dp::thread_safe_queue>;

/**
* @example mandelbrot/source/main.cpp
* Example showing how to use thread pool with tasks that return a value. Outputs a PPM image of
Expand Down
16 changes: 10 additions & 6 deletions include/thread_pool/thread_safe_queue.h
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
#pragma once

#include <condition_variable>
#include <deque>
#include <mutex>
#include <queue>

namespace dp {
template <typename T>
class thread_safe_queue {
public:
using value_type = T;
using size_type = typename std::queue<T>::size_type;
using size_type = typename std::deque<T>::size_type;

thread_safe_queue() = default;

void push(T&& value) {
std::lock_guard lock(mutex_);
data_.push(std::forward<T>(value));
{
std::lock_guard lock(mutex_);
data_.push_back(std::forward<T>(value));
}
condition_variable_.notify_all();
}

bool empty() {
std::lock_guard lock(mutex_);
return data_.empty();
Expand All @@ -42,12 +46,12 @@ namespace dp {
void pop() {
std::unique_lock lock(mutex_);
condition_variable_.wait(lock, [this] { return !data_.empty(); });
data_.pop();
data_.pop_front();
}

private:
using mutex_type = std::mutex;
std::queue<T> data_;
std::deque<T> data_;
mutable mutex_type mutex_{};
std::condition_variable condition_variable_{};
};
Expand Down
21 changes: 19 additions & 2 deletions test/source/thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

#include <string>

TEST_CASE("Basic Return Types") {
TEST_CASE("Basic task return types") {
dp::thread_pool pool(2);
// TODO
auto future_value = pool.enqueue([](const int& value) { return value; }, 30);
auto future_negative = pool.enqueue([](int x) -> int { return x - 20; }, 3);

Expand All @@ -30,3 +29,21 @@ TEST_CASE("Ensure input params are properly passed") {
CHECK(j == futures[j].get());
}
}

TEST_CASE("Ensure work completes upon destruction") {
std::atomic<int> counter;
std::vector<std::future<int>> futures;
const auto total_tasks = 20;
{
dp::thread_pool pool(4);
for (auto i = 0; i < total_tasks; i++) {
auto task = [index = i, &counter]() {
counter++;
return index;
};
futures.push_back(pool.enqueue(task));
}
}

CHECK_EQ(counter.load(), total_tasks);
}

0 comments on commit c1b91fd

Please sign in to comment.