Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a new thread pool that switches between single- and multi- threaded execution #73

Merged
merged 5 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion include/svs/index/vamana/dynamic_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ class MutableVamanaIndex {
/// mutating the graph.
void set_full_search_history(bool enable) { use_full_search_history_ = enable; }


///// Index translation.

///
Expand Down
46 changes: 37 additions & 9 deletions include/svs/lib/threads/threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,17 @@ concept ThreadPool = requires(Pool& pool, const Pool& const_pool, std::function<
// clang-format on

/// Delegate the wrapped function `f` to `pool`. Current partitioning methods
/// create n partitions where n equals the number of threads in the pool.
template <ThreadPool Pool, typename F> void parallel_for(Pool& pool, F&& f) {
    pool.parallel_for(thunks::wrap(ThreadCount{pool.size()}, f), pool.size());
}

// Current partitioning methods create n partitions where n equals the
// number of threads. Adjust the number of partitions to match the problem size
// if the problem size is smaller than the number of threads.
template <ThreadPool Pool, typename T, typename F>
void parallel_for(Pool& pool, T&& arg, F&& f) {
    // Empty work-sets are a no-op: nothing to partition, nothing to launch.
    if (!arg.empty()) {
        size_t n = std::min(arg.size(), pool.size());
        pool.parallel_for(thunks::wrap(ThreadCount{n}, f, std::forward<T>(arg)), n);
    }
}

Expand Down Expand Up @@ -200,7 +198,6 @@ template <typename Builder> class NativeThreadPoolBase {
}
}

private:
void manage_exception_during_run(const std::string& thread_0_message = {}) {
auto message = std::string{};
auto inserter = std::back_inserter(message);
Expand Down Expand Up @@ -255,6 +252,37 @@ auto create_on_nodes(InterNUMAThreadPool& threadpool, F&& f)
}
#endif

/////
///// A thread pool that dynamically switches between single-threaded and
///// multi-threaded execution.
///// - If `n == 1`, the task will be executed on the main thread without any
/////   locking mechanism.
///// - For `n > 1`, the tasks will be delegated to the internal
/////   `NativeThreadPool` for parallel execution.
/////
class SwitchNativeThreadPool {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the NativeThreadPoolBase ::parallel_for() implementation, a task is executed in the current thread in case of n == 1.
Why we cannot just fix existing code rather than making one more class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default NativeThreadPool is designed to be thread-safe. We want to ensure that no external threads call parallel_for in parallel, even with n==1. That is, the lock in NativeThreadPool will be always at the beginning of parallel_for.

It would be better to have a separate implementation that focuses on this dynamic switching capability.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, thanks. Looking to NativeThreadPool::size() code I did not realize that NativeThreadPool is thread-safe.

public:
SwitchNativeThreadPool(size_t num_threads)
: threadpool_{num_threads} {}

size_t size() const { return threadpool_.size(); }

void parallel_for(std::function<void(size_t)> f, size_t n) {
if (n == 1) {
try {
f(0);
} catch (const std::exception& error) {
threadpool_.manage_exception_during_run(error.what());
}
} else {
threadpool_.parallel_for(std::move(f), n);
}
}

private:
NativeThreadPool threadpool_;
};

/////
///// A handy reference wrapper for situations where we only want to share a thread pool
/////
Expand Down
14 changes: 14 additions & 0 deletions tests/integration/exhaustive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,13 @@ CATCH_TEST_CASE("Flat Orchestrator Search", "[integration][exhaustive][orchestra
svs::threads::QueueThreadPoolWrapper>(
index, queries, test_dataset::groundtruth_cosine()
);
// Also exercise the orchestrator search through the new dynamically-switching
// thread pool (SwitchNativeThreadPool), mirroring the pools tested above.
test_flat<
decltype(index),
decltype(queries),
decltype(test_dataset::groundtruth_cosine()),
svs::threads::SwitchNativeThreadPool>(
index, queries, test_dataset::groundtruth_cosine()
);
}

CATCH_SECTION("Cosine With Different Thread Pools From Data") {
Expand Down Expand Up @@ -359,5 +366,12 @@ CATCH_TEST_CASE("Flat Orchestrator Search", "[integration][exhaustive][orchestra
svs::threads::DefaultThreadPool>(
index, queries, test_dataset::groundtruth_cosine()
);
// Also exercise the data-constructed index with the dynamically-switching
// thread pool (SwitchNativeThreadPool), mirroring the pools tested above.
test_flat<
decltype(index),
decltype(queries),
decltype(test_dataset::groundtruth_cosine()),
svs::threads::SwitchNativeThreadPool>(
index, queries, test_dataset::groundtruth_cosine()
);
}
}
5 changes: 5 additions & 0 deletions tests/integration/vamana/index_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,5 +339,10 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") {
run_tests<svs::threads::CppAsyncThreadPool>(
index, queries, groundtruth, expected_results.config_and_recall_
);

// Swap the index over to the dynamically-switching pool (2 workers) and rerun
// the search/recall suite to verify identical results through this pool type.
index.set_threadpool(svs::threads::SwitchNativeThreadPool(2));
run_tests<svs::threads::SwitchNativeThreadPool>(
index, queries, groundtruth, expected_results.config_and_recall_
);
}
}
110 changes: 110 additions & 0 deletions tests/svs/lib/threads/threadpool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,5 +271,115 @@ CATCH_TEST_CASE("Thread Pool", "[core][threads][threadpool]") {
return v == 2;
}));
}

/////
///// SwitchNativeThreadPool
/////
CATCH_SECTION("SwitchNativeThreadPool") {
    auto pool = svs::threads::SwitchNativeThreadPool(num_threads);
    // Submit the same workload twice to exercise repeated dispatch through the
    // switching pool; each pass applies `f` across the whole of `v` and the
    // elapsed time is reported for manual inspection.
    for (size_t pass = 0; pass < 2; ++pass) {
        start_time = std::chrono::steady_clock::now();
        svs::threads::parallel_for(pool, svs::threads::StaticPartition{v.size()}, f);
        stop_time = std::chrono::steady_clock::now();
        time_seconds = std::chrono::duration<float>(stop_time - start_time).count();
        std::cout << "SwitchNativeThreadPool: " << time_seconds << " seconds"
                  << std::endl;
    }
    // After both passes every element must have been visited exactly twice.
    CATCH_REQUIRE(std::all_of(v.begin(), v.end(), [](const uint64_t& v) {
        return v == 2;
    }));
}
}
// CATCH_SECTION("SwitchNativeThreadPool and NativeThreadPool with Parallel Calls") {
// constexpr size_t num_external_threads = 4;
// constexpr size_t num_internal_threads = 2;
// constexpr size_t num_elements = 50000000;

// auto start_time = std::chrono::steady_clock::now();
// auto stop_time = std::chrono::steady_clock::now();
// float time_seconds, switch_time_seconds;

//{
// std::vector<std::vector<size_t>> v;
// std::vector<size_t> sum(num_external_threads, 0);
// v.resize(num_external_threads);
// for (auto& vv : v) {
// vv.resize(num_elements, 1);
//}

// std::vector<std::thread> external_threads;
// auto pool = svs::threads::NativeThreadPool(num_internal_threads);
// start_time = std::chrono::steady_clock::now();

//// NativeThreadPool will block external parallelism due to internal lock.
// for (size_t i = 0; i < num_external_threads; ++i) {
// external_threads.emplace_back([&v, &pool, &sum, i]() {
// svs::threads::parallel_for(
// pool,
// svs::threads::StaticPartition{1},
//[i, &vv = v[i], &sum](const auto& /*unused*/, size_t /*unused*/) {
// for (auto val : vv) {
// sum[i] += val;
//}
//}
//);
//});
//}

// for (auto& t : external_threads) {
// t.join();
//}
// stop_time = std::chrono::steady_clock::now();
// time_seconds = std::chrono::duration<float>(stop_time - start_time).count();

// for (auto s : sum) {
// CATCH_REQUIRE(s == num_elements);
//}
//}

//{
// std::vector<std::vector<size_t>> v;
// std::vector<size_t> sum(num_external_threads, 0);
// v.resize(num_external_threads);
// for (auto& vv : v) {
// vv.resize(num_elements, 1);
//}

// std::vector<std::thread> external_threads;
// auto switch_pool = svs::threads::SwitchNativeThreadPool(num_internal_threads);
// start_time = std::chrono::steady_clock::now();

// for (size_t i = 0; i < num_external_threads; ++i) {
// external_threads.emplace_back([&v, &switch_pool, &sum, i]() {
// svs::threads::parallel_for(
// switch_pool,
// svs::threads::StaticPartition{1},
//[i, &vv = v[i], &sum](const auto& /*unused*/, size_t /*unused*/) {
// for (auto val : vv) {
// sum[i] += val;
//}
//}
//);
//});
//}

// for (auto& t : external_threads) {
// t.join();
//}
// stop_time = std::chrono::steady_clock::now();
// switch_time_seconds =
// std::chrono::duration<float>(stop_time - start_time).count();

// for (auto s : sum) {
// CATCH_REQUIRE(s == num_elements);
//}
//}
// CATCH_REQUIRE(switch_time_seconds < time_seconds);
//}
}
Loading