Skip to content

Commit

Permalink
Add a new thread pool that switches between single- and multi- thread…
Browse files Browse the repository at this point in the history
…ed execution (#73)

A thread pool that dynamically switches between single-threaded and
multi-threaded execution.
- If `n == 1`, the task will be executed on the main thread without any
locking mechanism that exists in `NativeThreadPool`.
- For `n > 1`, the tasks will be delegated to the internal
`NativeThreadPool` for parallel execution.
  • Loading branch information
dian-lun-lin authored Jan 24, 2025
1 parent 1b0ba34 commit fc44f44
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 10 deletions.
1 change: 0 additions & 1 deletion include/svs/index/vamana/dynamic_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ class MutableVamanaIndex {
/// mutating the graph.
void set_full_search_history(bool enable) { use_full_search_history_ = enable; }


///// Index translation.

///
Expand Down
46 changes: 37 additions & 9 deletions include/svs/lib/threads/threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,17 @@ concept ThreadPool = requires(Pool& pool, const Pool& const_pool, std::function<
// clang-format on

/// Dispatch ``f`` through ``pool`` exactly once.
///
/// @param pool Thread pool satisfying the ``ThreadPool`` concept.
/// @param f Callable to execute; it is wrapped by ``thunks::wrap`` together with a
///     ``ThreadCount`` equal to ``pool.size()`` before being handed to the pool.
template <ThreadPool Pool, typename F> void parallel_for(Pool& pool, F&& f) {
    // Current partitioning methods create n partitions where n equals the number of
    // threads, so the wrapped function is delegated to the pool with `pool.size()`
    // partitions. Only one dispatch is issued: a second call would run the work twice.
    pool.parallel_for(thunks::wrap(ThreadCount{pool.size()}, f), pool.size());
}

// Current partitioning methods create n partitions where n equals the
// number of threads. Adjust the number of partitions to match the problem size
// if the problem size is smaller than the number of threads.
//
/// Dispatch ``f`` over the partition ``arg`` through ``pool`` exactly once.
///
/// @param pool Thread pool satisfying the ``ThreadPool`` concept.
/// @param arg Partition object; must provide ``empty()`` and ``size()``. Nothing is
///     dispatched when ``arg.empty()`` is true.
/// @param f Callable to execute; wrapped by ``thunks::wrap`` together with ``arg``.
template <ThreadPool Pool, typename T, typename F>
void parallel_for(Pool& pool, T&& arg, F&& f) {
    if (arg.empty()) {
        return;
    }

    // Never request more partitions than there are work items.
    const size_t n = std::min(arg.size(), pool.size());

    // `arg` is forwarded (and therefore possibly moved from) here, so this must be the
    // only dispatch in the function.
    pool.parallel_for(thunks::wrap(ThreadCount{n}, f, std::forward<T>(arg)), n);
}

Expand Down Expand Up @@ -200,7 +198,6 @@ template <typename Builder> class NativeThreadPoolBase {
}
}

private:
void manage_exception_during_run(const std::string& thread_0_message = {}) {
auto message = std::string{};
auto inserter = std::back_inserter(message);
Expand Down Expand Up @@ -255,6 +252,37 @@ auto create_on_nodes(InterNUMAThreadPool& threadpool, F&& f)
}
#endif

/////
///// A thread pool that dynamically switches between single-threaded and multi-threaded
///// execution on a per-call basis.
/////
///// - If `n == 1`, the task is executed directly on the calling thread without any
/////   locking mechanism.
///// - For any other `n`, the tasks are delegated to the internal `NativeThreadPool`.
/////
class SwitchNativeThreadPool {
  public:
    /// Forward ``num_threads`` to the constructor of the wrapped ``NativeThreadPool``.
    SwitchNativeThreadPool(size_t num_threads)
        : threadpool_{num_threads} {}

    /// Return the size reported by the wrapped ``NativeThreadPool``.
    size_t size() const { return threadpool_.size(); }

    /// Execute ``f`` for ``n`` partitions.
    ///
    /// @param f Callable receiving the partition index.
    /// @param n Number of partitions requested by the caller.
    void parallel_for(std::function<void(size_t)> f, size_t n) {
        // Anything other than a single partition goes straight to the wrapped pool.
        if (n != 1) {
            threadpool_.parallel_for(std::move(f), n);
            return;
        }

        // Single partition: invoke the callable inline with index 0.
        try {
            f(0);
        } catch (const std::exception& error) {
            // Hand the message to the wrapped pool's exception handling. Exceptions not
            // derived from `std::exception` are not caught here and propagate as-is.
            threadpool_.manage_exception_during_run(error.what());
        }
    }

  private:
    NativeThreadPool threadpool_;
};

/////
///// A handy reference wrapper for situations where we only want to share a thread pool
/////
Expand Down
14 changes: 14 additions & 0 deletions tests/integration/exhaustive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,13 @@ CATCH_TEST_CASE("Flat Orchestrator Search", "[integration][exhaustive][orchestra
svs::threads::QueueThreadPoolWrapper>(
index, queries, test_dataset::groundtruth_cosine()
);
test_flat<
decltype(index),
decltype(queries),
decltype(test_dataset::groundtruth_cosine()),
svs::threads::SwitchNativeThreadPool>(
index, queries, test_dataset::groundtruth_cosine()
);
}

CATCH_SECTION("Cosine With Different Thread Pools From Data") {
Expand Down Expand Up @@ -359,5 +366,12 @@ CATCH_TEST_CASE("Flat Orchestrator Search", "[integration][exhaustive][orchestra
svs::threads::DefaultThreadPool>(
index, queries, test_dataset::groundtruth_cosine()
);
test_flat<
decltype(index),
decltype(queries),
decltype(test_dataset::groundtruth_cosine()),
svs::threads::SwitchNativeThreadPool>(
index, queries, test_dataset::groundtruth_cosine()
);
}
}
5 changes: 5 additions & 0 deletions tests/integration/vamana/index_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,5 +339,10 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") {
run_tests<svs::threads::CppAsyncThreadPool>(
index, queries, groundtruth, expected_results.config_and_recall_
);

index.set_threadpool(svs::threads::SwitchNativeThreadPool(2));
run_tests<svs::threads::SwitchNativeThreadPool>(
index, queries, groundtruth, expected_results.config_and_recall_
);
}
}
110 changes: 110 additions & 0 deletions tests/svs/lib/threads/threadpool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,5 +271,115 @@ CATCH_TEST_CASE("Thread Pool", "[core][threads][threadpool]") {
return v == 2;
}));
}

/////
///// SwitchNativeThreadPool
/////
CATCH_SECTION("SwitchNativeThreadPool") {
    auto pool = svs::threads::SwitchNativeThreadPool(num_threads);

    // One timed pass: dispatch `f` over `StaticPartition{v.size()}` on `pool`, then
    // print the elapsed time. The enclosing test's timing variables are updated in
    // place, exactly as a hand-unrolled pass would.
    auto timed_pass = [&]() {
        start_time = std::chrono::steady_clock::now();
        svs::threads::parallel_for(pool, svs::threads::StaticPartition{v.size()}, f);
        stop_time = std::chrono::steady_clock::now();
        time_seconds = std::chrono::duration<float>(stop_time - start_time).count();
        std::cout << "SwitchNativeThreadPool: " << time_seconds << " seconds"
                  << std::endl;
    };

    // The same pool is exercised twice.
    timed_pass();
    timed_pass();

    // After both passes every element is expected to equal 2.
    CATCH_REQUIRE(std::all_of(v.begin(), v.end(), [](const uint64_t& v) {
        return v == 2;
    }));
}
}
// CATCH_SECTION("SwitchNativeThreadPool and NativeThreadPool with Parallel Calls") {
// constexpr size_t num_external_threads = 4;
// constexpr size_t num_internal_threads = 2;
// constexpr size_t num_elements = 50000000;

// auto start_time = std::chrono::steady_clock::now();
// auto stop_time = std::chrono::steady_clock::now();
// float time_seconds, switch_time_seconds;

//{
// std::vector<std::vector<size_t>> v;
// std::vector<size_t> sum(num_external_threads, 0);
// v.resize(num_external_threads);
// for (auto& vv : v) {
// vv.resize(num_elements, 1);
//}

// std::vector<std::thread> external_threads;
// auto pool = svs::threads::NativeThreadPool(num_internal_threads);
// start_time = std::chrono::steady_clock::now();

//// NativeThreadPool will block external parallelism due to internal lock.
// for (size_t i = 0; i < num_external_threads; ++i) {
// external_threads.emplace_back([&v, &pool, &sum, i]() {
// svs::threads::parallel_for(
// pool,
// svs::threads::StaticPartition{1},
//[i, &vv = v[i], &sum](const auto& [>unused*/, size_t /*unused<]) {
// for (auto val : vv) {
// sum[i] += val;
//}
//}
//);
//});
//}

// for (auto& t : external_threads) {
// t.join();
//}
// stop_time = std::chrono::steady_clock::now();
// time_seconds = std::chrono::duration<float>(stop_time - start_time).count();

// for (auto s : sum) {
// CATCH_REQUIRE(s == num_elements);
//}
//}

//{
// std::vector<std::vector<size_t>> v;
// std::vector<size_t> sum(num_external_threads, 0);
// v.resize(num_external_threads);
// for (auto& vv : v) {
// vv.resize(num_elements, 1);
//}

// std::vector<std::thread> external_threads;
// auto switch_pool = svs::threads::SwitchNativeThreadPool(num_internal_threads);
// start_time = std::chrono::steady_clock::now();

// for (size_t i = 0; i < num_external_threads; ++i) {
// external_threads.emplace_back([&v, &switch_pool, &sum, i]() {
// svs::threads::parallel_for(
// switch_pool,
// svs::threads::StaticPartition{1},
//[i, &vv = v[i], &sum](const auto& [>unused*/, size_t /*unused<]) {
// for (auto val : vv) {
// sum[i] += val;
//}
//}
//);
//});
//}

// for (auto& t : external_threads) {
// t.join();
//}
// stop_time = std::chrono::steady_clock::now();
// switch_time_seconds =
// std::chrono::duration<float>(stop_time - start_time).count();

// for (auto s : sum) {
// CATCH_REQUIRE(s == num_elements);
//}
//}
// CATCH_REQUIRE(switch_time_seconds < time_seconds);
//}
}

0 comments on commit fc44f44

Please sign in to comment.