Skip to content

Commit

Permalink
add a new thread pool that switches between single- and multi- thread…
Browse files Browse the repository at this point in the history
…ed execution
  • Loading branch information
dian-lun-lin committed Jan 22, 2025
1 parent 1b0ba34 commit 2ec9e2d
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 1 deletion.
25 changes: 25 additions & 0 deletions include/svs/lib/threads/threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,31 @@ auto create_on_nodes(InterNUMAThreadPool& threadpool, F&& f)
}
#endif

/////
///// A thread pool that dynamically switches between single-threaded and multi-threaded execution.
///// - If `n == 1`, the task will be executed on the main thread without any locking mechanism.
///// - For `n > 1`, the tasks will be delegated to the internal `NativeThreadPool` for parallel execution.
/////
class SwitchNativeThreadPool {
public:
SwitchNativeThreadPool(size_t num_threads): threadpool_{num_threads} {
}

size_t size() const { return threadpool_.size(); }

void parallel_for(std::function<void(size_t)> f, size_t n) {
if(n == 1) {
f(0);
}
else {
threadpool_.parallel_for(std::move(f), n);
}
}

private:
NativeThreadPool threadpool_;
};

/////
///// A handy refernce wrapper for situations where we only want to share a thread pool
/////
Expand Down
14 changes: 14 additions & 0 deletions tests/integration/exhaustive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,13 @@ CATCH_TEST_CASE("Flat Orchestrator Search", "[integration][exhaustive][orchestra
svs::threads::QueueThreadPoolWrapper>(
index, queries, test_dataset::groundtruth_cosine()
);
test_flat<
decltype(index),
decltype(queries),
decltype(test_dataset::groundtruth_cosine()),
svs::threads::SwitchNativeThreadPool>(
index, queries, test_dataset::groundtruth_cosine()
);
}

CATCH_SECTION("Cosine With Different Thread Pools From Data") {
Expand Down Expand Up @@ -359,5 +366,12 @@ CATCH_TEST_CASE("Flat Orchestrator Search", "[integration][exhaustive][orchestra
svs::threads::DefaultThreadPool>(
index, queries, test_dataset::groundtruth_cosine()
);
test_flat<
decltype(index),
decltype(queries),
decltype(test_dataset::groundtruth_cosine()),
svs::threads::SwitchNativeThreadPool>(
index, queries, test_dataset::groundtruth_cosine()
);
}
}
5 changes: 5 additions & 0 deletions tests/integration/vamana/index_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,5 +339,10 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") {
run_tests<svs::threads::CppAsyncThreadPool>(
index, queries, groundtruth, expected_results.config_and_recall_
);

index.set_threadpool(svs::threads::SwitchNativeThreadPool(2));
run_tests<svs::threads::SwitchNativeThreadPool>(
index, queries, groundtruth, expected_results.config_and_recall_
);
}
}
101 changes: 100 additions & 1 deletion tests/svs/lib/threads/threadpool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include <atomic>
#include <chrono>
#include <memory>
// #include <omp.h>
//#include <omp.h>
#include <random>
#include <thread>
#include <tuple>
Expand Down Expand Up @@ -271,5 +271,104 @@ CATCH_TEST_CASE("Thread Pool", "[core][threads][threadpool]") {
return v == 2;
}));
}

/////
///// SwitchNativeThreadPool
/////
CATCH_SECTION("SwitchNativeThreadPool") {
auto pool = svs::threads::SwitchNativeThreadPool(num_threads);
start_time = std::chrono::steady_clock::now();
svs::threads::parallel_for(pool, svs::threads::StaticPartition{v.size()}, f);
stop_time = std::chrono::steady_clock::now();
time_seconds = std::chrono::duration<float>(stop_time - start_time).count();
std::cout << "SwitchNativeThreadPool: " << time_seconds << " seconds" << std::endl;

start_time = std::chrono::steady_clock::now();
svs::threads::parallel_for(pool, svs::threads::StaticPartition{v.size()}, f);
stop_time = std::chrono::steady_clock::now();
time_seconds = std::chrono::duration<float>(stop_time - start_time).count();
std::cout << "SwitchNativeThreadPool: " << time_seconds << " seconds" << std::endl;

CATCH_REQUIRE(std::all_of(v.begin(), v.end(), [](const uint64_t& v) {
return v == 2;
}));
}
}
CATCH_SECTION("SwitchNativeThreadPool and NativeThreadPool with Parallel Calls") {
constexpr size_t num_external_threads = 2;
constexpr size_t num_tasks = 1;

auto start_time = std::chrono::steady_clock::now();
auto stop_time = std::chrono::steady_clock::now();
float time_seconds, switch_time_seconds;
std::vector<std::vector<size_t>> v;

{
v.resize(num_external_threads);
for(auto& vv: v) {
vv.resize(10000000);
}

std::vector<std::thread> external_threads;
auto switch_pool = svs::threads::SwitchNativeThreadPool(1);
start_time = std::chrono::steady_clock::now();

for(size_t i = 0; i < num_external_threads; ++i) {
external_threads.emplace_back([&v, &switch_pool, i]() { switch_pool.parallel_for([&vv = v[i]](size_t n) {
CATCH_REQUIRE(n == 0);
for(auto& val: vv) {
val = 2;
}
}, num_tasks); });
}

for(auto& t: external_threads) {
t.join();
}
stop_time = std::chrono::steady_clock::now();
switch_time_seconds = std::chrono::duration<float>(stop_time - start_time).count();

for(auto& vv: v) {
CATCH_REQUIRE(std::all_of(vv.begin(), vv.end(), [](const size_t& v) {
return v == 2;
}));
}
}

v.clear();

{
v.resize(num_external_threads);
for(auto& vv: v) {
vv.resize(10000000);
}

std::vector<std::thread> external_threads;
auto pool = svs::threads::NativeThreadPool(1);
start_time = std::chrono::steady_clock::now();

// NativeThreadPool will block external parallelism due to internal lock.
for(size_t i = 0; i < num_external_threads; ++i) {
external_threads.emplace_back([&v, &pool, i]() { pool.parallel_for([&vv=v[i]](size_t n) {
CATCH_REQUIRE(n == 0);
for(auto& val: vv) {
val = 2;
}
}, num_tasks); });
}

for(auto& t: external_threads) {
t.join();
}
stop_time = std::chrono::steady_clock::now();
time_seconds = std::chrono::duration<float>(stop_time - start_time).count();

for(auto& vv: v) {
CATCH_REQUIRE(std::all_of(vv.begin(), vv.end(), [](const size_t& v) {
return v == 2;
}));
}
}
CATCH_REQUIRE(switch_time_seconds < time_seconds);
}
}

0 comments on commit 2ec9e2d

Please sign in to comment.