diff --git a/include/svs/index/vamana/dynamic_index.h b/include/svs/index/vamana/dynamic_index.h index e59eeff..1026e1c 100644 --- a/include/svs/index/vamana/dynamic_index.h +++ b/include/svs/index/vamana/dynamic_index.h @@ -307,7 +307,6 @@ class MutableVamanaIndex { /// mutating the graph. void set_full_search_history(bool enable) { use_full_search_history_ = enable; } - ///// Index translation. /// diff --git a/include/svs/lib/threads/threadpool.h b/include/svs/lib/threads/threadpool.h index d2de115..31db19c 100644 --- a/include/svs/lib/threads/threadpool.h +++ b/include/svs/lib/threads/threadpool.h @@ -66,19 +66,17 @@ concept ThreadPool = requires(Pool& pool, const Pool& const_pool, std::function< // clang-format on template void parallel_for(Pool& pool, F&& f) { - pool.parallel_for( - thunks::wrap(ThreadCount{pool.size()}, f), pool.size() - ); // Current partitioning methods will create n partitions where n equals to the number - // of threads. Delegate the wrapped function to threadpool + pool.parallel_for(thunks::wrap(ThreadCount{pool.size()}, f), pool.size()); } +// Current partitioning methods create n partitions where n equals the +// number of threads. Adjust the number of partitions to match the problem size +// if the problem size is smaller than the number of threads. template void parallel_for(Pool& pool, T&& arg, F&& f) { if (!arg.empty()) { - pool.parallel_for( - thunks::wrap(ThreadCount{pool.size()}, f, std::forward(arg)), pool.size() - ); // Current partitioning methods will create n partitions where n equals to the - // number of threads. 
Delegate the wrapped function to threadpool + size_t n = std::min(arg.size(), pool.size()); + pool.parallel_for(thunks::wrap(ThreadCount{n}, f, std::forward(arg)), n); } } @@ -200,7 +198,6 @@ template class NativeThreadPoolBase { } } - private: void manage_exception_during_run(const std::string& thread_0_message = {}) { auto message = std::string{}; auto inserter = std::back_inserter(message); @@ -255,6 +252,37 @@ auto create_on_nodes(InterNUMAThreadPool& threadpool, F&& f) } #endif +///// +///// A thread pool that switches between single-threaded and multi-threaded execution on +/// each `parallel_for` call. +///// - If `n == 1`, the task is executed on the calling thread without any locking +/// mechanism. +///// - For any other `n`, the tasks are delegated to the internal `NativeThreadPool` for +/// parallel execution. +///// +class SwitchNativeThreadPool { + public: + SwitchNativeThreadPool(size_t num_threads) + : threadpool_{num_threads} {} + + size_t size() const { return threadpool_.size(); } + + void parallel_for(std::function f, size_t n) { + if (n == 1) { + try { + f(0); + } catch (const std::exception& error) { + threadpool_.manage_exception_during_run(error.what()); + } + } else { + threadpool_.parallel_for(std::move(f), n); + } + } + + private: + NativeThreadPool threadpool_; +}; + ///// ///// A handy refernce wrapper for situations where we only want to share a thread pool ///// diff --git a/tests/integration/exhaustive.cpp b/tests/integration/exhaustive.cpp index 3ae6535..4d08aee 100644 --- a/tests/integration/exhaustive.cpp +++ b/tests/integration/exhaustive.cpp @@ -327,6 +327,13 @@ CATCH_TEST_CASE("Flat Orchestrator Search", "[integration][exhaustive][orchestra svs::threads::QueueThreadPoolWrapper>( index, queries, test_dataset::groundtruth_cosine() ); + test_flat< + decltype(index), + decltype(queries), + decltype(test_dataset::groundtruth_cosine()), + svs::threads::SwitchNativeThreadPool>( + index, queries, test_dataset::groundtruth_cosine() + ); } 
CATCH_SECTION("Cosine With Different Thread Pools From Data") { @@ -359,5 +366,12 @@ CATCH_TEST_CASE("Flat Orchestrator Search", "[integration][exhaustive][orchestra svs::threads::DefaultThreadPool>( index, queries, test_dataset::groundtruth_cosine() ); + test_flat< + decltype(index), + decltype(queries), + decltype(test_dataset::groundtruth_cosine()), + svs::threads::SwitchNativeThreadPool>( + index, queries, test_dataset::groundtruth_cosine() + ); } } diff --git a/tests/integration/vamana/index_search.cpp b/tests/integration/vamana/index_search.cpp index a683230..3b96d59 100644 --- a/tests/integration/vamana/index_search.cpp +++ b/tests/integration/vamana/index_search.cpp @@ -339,5 +339,10 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") { run_tests( index, queries, groundtruth, expected_results.config_and_recall_ ); + + index.set_threadpool(svs::threads::SwitchNativeThreadPool(2)); + run_tests( + index, queries, groundtruth, expected_results.config_and_recall_ + ); } } diff --git a/tests/svs/lib/threads/threadpool.cpp b/tests/svs/lib/threads/threadpool.cpp index 0ddda0a..5d2a910 100644 --- a/tests/svs/lib/threads/threadpool.cpp +++ b/tests/svs/lib/threads/threadpool.cpp @@ -271,5 +271,115 @@ CATCH_TEST_CASE("Thread Pool", "[core][threads][threadpool]") { return v == 2; })); } + + ///// + ///// SwitchNativeThreadPool + ///// + CATCH_SECTION("SwitchNativeThreadPool") { + auto pool = svs::threads::SwitchNativeThreadPool(num_threads); + start_time = std::chrono::steady_clock::now(); + svs::threads::parallel_for(pool, svs::threads::StaticPartition{v.size()}, f); + stop_time = std::chrono::steady_clock::now(); + time_seconds = std::chrono::duration(stop_time - start_time).count(); + std::cout << "SwitchNativeThreadPool: " << time_seconds << " seconds" + << std::endl; + + start_time = std::chrono::steady_clock::now(); + svs::threads::parallel_for(pool, svs::threads::StaticPartition{v.size()}, f); + stop_time = 
std::chrono::steady_clock::now(); + time_seconds = std::chrono::duration(stop_time - start_time).count(); + std::cout << "SwitchNativeThreadPool: " << time_seconds << " seconds" + << std::endl; + + CATCH_REQUIRE(std::all_of(v.begin(), v.end(), [](const uint64_t& v) { + return v == 2; + })); + } } + // CATCH_SECTION("SwitchNativeThreadPool and NativeThreadPool with Parallel Calls") { + // constexpr size_t num_external_threads = 4; + // constexpr size_t num_internal_threads = 2; + // constexpr size_t num_elements = 50000000; + + // auto start_time = std::chrono::steady_clock::now(); + // auto stop_time = std::chrono::steady_clock::now(); + // float time_seconds, switch_time_seconds; + + //{ + // std::vector> v; + // std::vector sum(num_external_threads, 0); + // v.resize(num_external_threads); + // for (auto& vv : v) { + // vv.resize(num_elements, 1); + //} + + // std::vector external_threads; + // auto pool = svs::threads::NativeThreadPool(num_internal_threads); + // start_time = std::chrono::steady_clock::now(); + + //// NativeThreadPool will block external parallelism due to internal lock. 
+ // for (size_t i = 0; i < num_external_threads; ++i) { + // external_threads.emplace_back([&v, &pool, &sum, i]() { + // svs::threads::parallel_for( + // pool, + // svs::threads::StaticPartition{1}, + //[i, &vv = v[i], &sum](const auto& /*unused*/, size_t /*unused*/) { + // for (auto val : vv) { + // sum[i] += val; + //} + //} + //); + //}); + //} + + // for (auto& t : external_threads) { + // t.join(); + //} + // stop_time = std::chrono::steady_clock::now(); + // time_seconds = std::chrono::duration(stop_time - start_time).count(); + + // for (auto s : sum) { + // CATCH_REQUIRE(s == num_elements); + //} + //} + + //{ + // std::vector> v; + // std::vector sum(num_external_threads, 0); + // v.resize(num_external_threads); + // for (auto& vv : v) { + // vv.resize(num_elements, 1); + //} + + // std::vector external_threads; + // auto switch_pool = svs::threads::SwitchNativeThreadPool(num_internal_threads); + // start_time = std::chrono::steady_clock::now(); + + // for (size_t i = 0; i < num_external_threads; ++i) { + // external_threads.emplace_back([&v, &switch_pool, &sum, i]() { + // svs::threads::parallel_for( + // switch_pool, + // svs::threads::StaticPartition{1}, + //[i, &vv = v[i], &sum](const auto& /*unused*/, size_t /*unused*/) { + // for (auto val : vv) { + // sum[i] += val; + //} + //} + //); + //}); + //} + + // for (auto& t : external_threads) { + // t.join(); + //} + // stop_time = std::chrono::steady_clock::now(); + // switch_time_seconds = + // std::chrono::duration(stop_time - start_time).count(); + + // for (auto s : sum) { + // CATCH_REQUIRE(s == num_elements); + //} + //} + // CATCH_REQUIRE(switch_time_seconds < time_seconds); + //} }