diff --git a/cpp/bench/ann/src/common/ann_types.hpp b/cpp/bench/ann/src/common/ann_types.hpp index f192d56e89..f82afb35e9 100644 --- a/cpp/bench/ann/src/common/ann_types.hpp +++ b/cpp/bench/ann/src/common/ann_types.hpp @@ -64,10 +64,17 @@ inline auto parse_memory_type(const std::string& memory_type) -> MemoryType } } -struct AlgoProperty { +class AlgoProperty { + public: + inline AlgoProperty() {} + inline AlgoProperty(MemoryType dataset_memory_type_, MemoryType query_memory_type_) + : dataset_memory_type(dataset_memory_type_), query_memory_type(query_memory_type_) + { + } MemoryType dataset_memory_type; // neighbors/distances should have same memory type as queries MemoryType query_memory_type; + virtual ~AlgoProperty() = default; }; class AnnBase { diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index 50d3ef00b1..5c755108af 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -28,6 +28,7 @@ #include "thread_pool.hpp" #include +#include #include #include #include @@ -42,6 +43,7 @@ namespace raft::bench::ann { static inline std::unique_ptr current_algo{nullptr}; +static inline std::shared_ptr current_algo_props{nullptr}; using kv_series = std::vector>>; @@ -170,10 +172,13 @@ template void bench_search(::benchmark::State& state, Configuration::Index index, std::size_t search_param_ix, - std::shared_ptr> dataset) + std::shared_ptr> dataset, + Objective metric_objective) { - std::ptrdiff_t batch_offset = 0; - std::atomic queries_processed(0); + std::ptrdiff_t batch_offset = 0; + std::size_t queries_processed = 0; + + double total_time = 0; const auto& sp_json = index.search_params[search_param_ix]; @@ -190,381 +195,381 @@ void bench_search(::benchmark::State& state, throw std::runtime_error("Index file is missing. Run the benchmark in the build mode first."); return; } - // algo is static to cache it between close search runs to save time on index loading - static std::string index_file = ""; - if (index.file != index_file) { - current_algo.reset(); - index_file = index.file; - } - ANN* algo; // TODO: Just have one thread load this. - std::unique_ptr::AnnSearchParam> search_param; - try { - if (!current_algo || (algo = dynamic_cast*>(current_algo.get())) == nullptr) { - auto ualgo = ann::create_algo( - index.algo, dataset->distance(), dataset->dim(), index.build_param, index.dev_list); - algo = ualgo.get(); - algo->load(index_file); - current_algo = std::move(ualgo); + /** + * Make sure the first thread loads the algo and dataset + */ + if (state.thread_index() == 0) { + // algo is static to cache it between close search runs to save time on index loading + static std::string index_file = ""; + if (index.file != index_file) { + current_algo.reset(); + index_file = index.file; } - search_param = ann::create_search_param(index.algo, sp_json); - // search_param->metric_objective = metric_objective; - } catch (const std::exception& e) { - throw std::runtime_error("Failed to create an algo: " + std::string(e.what())); - } - algo->set_search_param(*search_param); - const auto algo_property = parse_algo_property(algo->get_preference(), sp_json); - const T* query_set = dataset->query_set(algo_property.query_memory_type); - - // TODO: Have 1 thread create and load these. - buf distances{algo_property.query_memory_type, k * query_set_size}; - buf neighbors{algo_property.query_memory_type, k * query_set_size}; - - if (search_param->needs_dataset()) { + std::unique_ptr::AnnSearchParam> search_param; + ANN* algo; try { - algo->set_search_dataset(dataset->base_set(algo_property.dataset_memory_type), - dataset->base_set_size()); - } catch (const std::exception& ex) { - throw std::runtime_error("The algorithm '" + index.name + - "' requires the base set, but it's not available. " + - "Exception: " + std::string(ex.what())); - return; + if (!current_algo || (algo = dynamic_cast*>(current_algo.get())) == nullptr) { + auto ualgo = ann::create_algo( + index.algo, dataset->distance(), dataset->dim(), index.build_param, index.dev_list); + algo = ualgo.get(); + algo->load(index_file); + current_algo = std::move(ualgo); + } + search_param = ann::create_search_param(index.algo, sp_json); + search_param->metric_objective = metric_objective; + } catch (const std::exception& e) { + state.SkipWithError("Failed to create an algo: " + std::string(e.what())); + } + algo->set_search_param(*search_param); + auto algo_property = parse_algo_property(algo->get_preference(), sp_json); + current_algo_props = std::make_shared(algo_property.dataset_memory_type, + algo_property.query_memory_type); + if (search_param->needs_dataset()) { + try { + algo->set_search_dataset(dataset->base_set(current_algo_props->dataset_memory_type), + dataset->base_set_size()); + } catch (const std::exception& ex) { + state.SkipWithError("The algorithm '" + index.name + + "' requires the base set, but it's not available. " + + "Exception: " + std::string(ex.what())); + return; + } } } + const auto algo_property = *current_algo_props; + const T* query_set = dataset->query_set(algo_property.query_memory_type); + + /** + * Each thread will manage its own outputs + */ + std::shared_ptr> distances = + std::make_shared>(algo_property.query_memory_type, k * query_set_size); + std::shared_ptr> neighbors = + std::make_shared>(algo_property.query_memory_type, k * query_set_size); + cuda_timer gpu_timer; { - /** - * When the objective is throughput, we want to overlap batches - * as much as possible and measure the end-to-end time from start - * to finish. - * - * When the objective is latency, we want to measure each batch - * individually. Latency is better measured in single-query batches - * but larger batches are still allowed in this mode in order to - * compare against the resulting batch sizes in throughput mode. - */ - nvtx_case nvtx{state.name()}; - // Multithreading starts in the benchmark loop for (auto _ : state) { [[maybe_unused]] auto ntx_lap = nvtx.lap(); [[maybe_unused]] auto gpu_lap = gpu_timer.lap(); + + ANN* algo = dynamic_cast*>(current_algo.get()); + auto start = std::chrono::high_resolution_clock::now(); // run the search try { algo->search(query_set + batch_offset * dataset->dim(), n_queries, k, - neighbors.data + batch_offset * k, - distances.data + batch_offset * k, + neighbors->data + batch_offset * k, + distances->data + batch_offset * k, gpu_timer.stream()); } catch (const std::exception& e) { state.SkipWithError(std::string(e.what())); } + + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed_seconds = std::chrono::duration_cast>(end - start); // advance to the next batch batch_offset = (batch_offset + n_queries) % query_set_size; queries_processed += n_queries; + state.SetIterationTime(elapsed_seconds.count()); + total_time += elapsed_seconds.count(); } } - state.counters.insert({{"total_queries", queries_processed.load()}}); state.SetItemsProcessed(queries_processed); if (cudart.found()) { - state.counters.insert({{"GPU Time", gpu_timer.total_time() / state.iterations()}, - {"GPU QPS", queries_processed.load() / gpu_timer.total_time()}}); - } - if (state.skipped()) { return; } + state.counters.insert({{"GPU", gpu_timer.total_time() / double(state.iterations())}, - if (state.thread_index() == 0) { - state.counters.insert({{"k", k}, {"batch_size", n_queries}}); - // evaluate recall - if (dataset->max_k() >= k) { - const std::int32_t* gt = dataset->gt_set(); - const std::uint32_t max_k = dataset->max_k(); - buf neighbors_host = neighbors.move(MemoryType::Host); - - std::size_t rows = std::min(queries_processed.load(), query_set_size); - std::size_t match_count = 0; - std::size_t total_count = rows * static_cast(k); - for (std::size_t i = 0; i < rows; i++) { - for (std::uint32_t j = 0; j < k; j++) { - auto act_idx = std::int32_t(neighbors_host.data[i * k + j]); - for (std::uint32_t l = 0; l < k; l++) { - auto exp_idx = gt[i * max_k + l]; - if (act_idx == exp_idx) { - match_count++; - break; +// // Using gpu_timer.total_time() isn't really the most fair comparison. +// {"GPU QPS", queries_processed / total_time }}); + } + + // This will be the total number of queries across all threads + state.counters.insert({{"total_queries", queries_processed}}); + + if (state.skipped()) { + return; } + + if (state.thread_index() == 0) { + state.counters.insert({{"k", k}, {"n_queries", n_queries}}); + + // evaluate recall + if (dataset->max_k() >= k) { + const std::int32_t* gt = dataset->gt_set(); + const std::uint32_t max_k = dataset->max_k(); + buf neighbors_host = neighbors->move(MemoryType::Host); + std::size_t rows = std::min(queries_processed, query_set_size); + std::size_t match_count = 0; + std::size_t total_count = rows * static_cast(k); + for (std::size_t i = 0; i < rows; i++) { + for (std::uint32_t j = 0; j < k; j++) { + auto act_idx = std::int32_t(neighbors_host.data[i * k + j]); + for (std::uint32_t l = 0; l < k; l++) { + auto exp_idx = gt[i * max_k + l]; + if (act_idx == exp_idx) { + match_count++; + break; + } } } } + double actual_recall = static_cast(match_count) / static_cast(total_count); + state.counters.insert({{"Recall", actual_recall}}); } - double actual_recall = static_cast(match_count) / static_cast(total_count); - state.counters.insert({{"Recall", actual_recall}}); } } -} -inline void printf_usage() -{ - ::benchmark::PrintDefaultHelp(); - fprintf( - stdout, - " [--build|--search] \n" - " [--overwrite]\n" - " [--data_prefix=]\n" - " [--index_prefix=]\n" - " [--override_kv=]\n" - " [--mode=\n" - " .json\n" - "\n" - "Note the non-standard benchmark parameters:\n" - " --build: build mode, will build index\n" - " --search: search mode, will search using the built index\n" - " one and only one of --build and --search should be specified\n" - " --overwrite: force overwriting existing index files\n" - " --data_prefix=:" - " prepend to dataset file paths specified in the .json (default = 'data/').\n" - " --index_prefix=:" - " prepend to index file paths specified in the .json (default = 'index/').\n" - " --override_kv=:" - " override a build/search key one or more times multiplying the number of configurations;" - " you can use this parameter multiple times to get the Cartesian product of benchmark" - " configs.\n" - " --mode=" - " run the benchmarks in latency (accumulate times spent in each batch) or " - " throughput (pipeline batches and measure end-to-end) mode\n"); -} - -template -void register_build(std::shared_ptr> dataset, - std::vector indices, - bool force_overwrite) -{ - for (auto index : indices) { - auto suf = static_cast(index.build_param["override_suffix"]); - auto file_suf = suf; - index.build_param.erase("override_suffix"); - std::replace(file_suf.begin(), file_suf.end(), '/', '-'); - index.file += file_suf; - auto* b = ::benchmark::RegisterBenchmark( - index.name + suf, bench_build, dataset, index, force_overwrite); - b->Unit(benchmark::kSecond); - b->UseRealTime(); + inline void printf_usage() + { + ::benchmark::PrintDefaultHelp(); + fprintf( + stdout, + " [--build|--search] \n" + " [--overwrite]\n" + " [--data_prefix=]\n" + " [--index_prefix=]\n" + " [--override_kv=]\n" + " [--mode=\n" + " .json\n" + "\n" + "Note the non-standard benchmark parameters:\n" + " --build: build mode, will build index\n" + " --search: search mode, will search using the built index\n" + " one and only one of --build and --search should be specified\n" + " --overwrite: force overwriting existing index files\n" + " --data_prefix=:" + " prepend to dataset file paths specified in the .json (default = " + "'data/').\n" + " --index_prefix=:" + " prepend to index file paths specified in the .json (default = " + "'index/').\n" + " --override_kv=:" + " override a build/search key one or more times multiplying the number of configurations;" + " you can use this parameter multiple times to get the Cartesian product of benchmark" + " configs.\n" + " --mode=" + " run the benchmarks in latency (accumulate times spent in each batch) or " + " throughput (pipeline batches and measure end-to-end) mode\n"); } -} -template -void register_search(std::shared_ptr> dataset, - std::vector indices, - Objective metric_objective) -{ - for (auto index : indices) { - for (std::size_t i = 0; i < index.search_params.size(); i++) { - auto suf = static_cast(index.search_params[i]["override_suffix"]); - index.search_params[i].erase("override_suffix"); - - auto* b = ::benchmark::RegisterBenchmark(index.name + suf, bench_search, index, i, dataset) - ->Unit(benchmark::kMillisecond) - ->ThreadRange(1, 16) - ->UseRealTime(); - std::cout << "Done registering index " << i << std::endl; - } + template + void register_build(std::shared_ptr> dataset, + std::vector indices, + bool force_overwrite) + { + for (auto index : indices) { + auto suf = static_cast(index.build_param["override_suffix"]); + auto file_suf = suf; + index.build_param.erase("override_suffix"); + std::replace(file_suf.begin(), file_suf.end(), '/', '-'); + index.file += file_suf; + auto* b = ::benchmark::RegisterBenchmark( + index.name + suf, bench_build, dataset, index, force_overwrite); + b->Unit(benchmark::kSecond); + b->UseRealTime(); + } } -} -template -void dispatch_benchmark(const Configuration& conf, - bool force_overwrite, - bool build_mode, - bool search_mode, - std::string data_prefix, - std::string index_prefix, - kv_series override_kv, - Objective metric_objective) -{ - if (cudart.found()) { - for (auto [key, value] : cuda_info()) { - ::benchmark::AddCustomContext(key, value); - } - } - const auto dataset_conf = conf.get_dataset_conf(); - auto base_file = combine_path(data_prefix, dataset_conf.base_file); - auto query_file = combine_path(data_prefix, dataset_conf.query_file); - auto gt_file = dataset_conf.groundtruth_neighbors_file; - if (gt_file.has_value()) { gt_file.emplace(combine_path(data_prefix, gt_file.value())); } - auto dataset = std::make_shared>(dataset_conf.name, - base_file, - dataset_conf.subset_first_row, - dataset_conf.subset_size, - query_file, - dataset_conf.distance, - gt_file); - ::benchmark::AddCustomContext("dataset", dataset_conf.name); - ::benchmark::AddCustomContext("distance", dataset_conf.distance); - std::vector indices = conf.get_indices(); - if (build_mode) { - if (file_exists(base_file)) { - log_info("Using the dataset file '%s'", base_file.c_str()); - ::benchmark::AddCustomContext("n_records", std::to_string(dataset->base_set_size())); - ::benchmark::AddCustomContext("dim", std::to_string(dataset->dim())); - } else { - log_warn("Dataset file '%s' does not exist; benchmarking index building is impossible.", - base_file.c_str()); - } - std::vector more_indices{}; - for (auto& index : indices) { - for (auto param : apply_overrides(index.build_param, override_kv)) { - auto modified_index = index; - modified_index.build_param = param; - modified_index.file = combine_path(index_prefix, modified_index.file); - more_indices.push_back(modified_index); - } - } - register_build(dataset, more_indices, force_overwrite); - } else if (search_mode) { - if (file_exists(query_file)) { - log_info("Using the query file '%s'", query_file.c_str()); - ::benchmark::AddCustomContext("max_n_queries", std::to_string(dataset->query_set_size())); - ::benchmark::AddCustomContext("dim", std::to_string(dataset->dim())); - if (gt_file.has_value()) { - if (file_exists(*gt_file)) { - log_info("Using the ground truth file '%s'", gt_file->c_str()); - ::benchmark::AddCustomContext("max_k", std::to_string(dataset->max_k())); - } else { - log_warn("Ground truth file '%s' does not exist; the recall won't be reported.", - gt_file->c_str()); + template + void register_search(std::shared_ptr> dataset, + std::vector indices, + Objective metric_objective) + { + for (auto index : indices) { + for (std::size_t i = 0; i < index.search_params.size(); i++) { + auto suf = static_cast(index.search_params[i]["override_suffix"]); + index.search_params[i].erase("override_suffix"); + + auto* b = ::benchmark::RegisterBenchmark( + index.name + suf, bench_search, index, i, dataset, metric_objective) + ->Unit(benchmark::kMillisecond) + ->ThreadRange(1, 32) + ->UseManualTime(); + } } - } else { - log_warn( - "Ground truth file is not provided; the recall won't be reported. NB: use " - "the 'groundtruth_neighbors_file' alongside the 'query_file' key to specify the path to " - "the ground truth in your conf.json."); - } - } else { - log_warn("Query file '%s' does not exist; benchmarking search is impossible.", - query_file.c_str()); - } - for (auto& index : indices) { - index.search_params = apply_overrides(index.search_params, override_kv); - index.file = combine_path(index_prefix, index.file); - } - register_search(dataset, indices, metric_objective); } -} -inline auto parse_bool_flag(const char* arg, const char* pat, bool& result) -> bool -{ - if (strcmp(arg, pat) == 0) { - result = true; - return true; + template + void dispatch_benchmark(const Configuration& conf, + bool force_overwrite, + bool build_mode, + bool search_mode, + std::string data_prefix, + std::string index_prefix, + kv_series override_kv, + Objective metric_objective) + { + if (cudart.found()) { + for (auto [key, value] : cuda_info()) { + ::benchmark::AddCustomContext(key, value); + } + } + const auto dataset_conf = conf.get_dataset_conf(); + auto base_file = combine_path(data_prefix, dataset_conf.base_file); + auto query_file = combine_path(data_prefix, dataset_conf.query_file); + auto gt_file = dataset_conf.groundtruth_neighbors_file; + if (gt_file.has_value()) { gt_file.emplace(combine_path(data_prefix, gt_file.value())); } + auto dataset = std::make_shared>(dataset_conf.name, + base_file, + dataset_conf.subset_first_row, + dataset_conf.subset_size, + query_file, + dataset_conf.distance, + gt_file); + ::benchmark::AddCustomContext("dataset", dataset_conf.name); + ::benchmark::AddCustomContext("distance", dataset_conf.distance); + std::vector indices = conf.get_indices(); + if (build_mode) { + if (file_exists(base_file)) { + log_info("Using the dataset file '%s'", base_file.c_str()); + ::benchmark::AddCustomContext("n_records", std::to_string(dataset->base_set_size())); + ::benchmark::AddCustomContext("dim", std::to_string(dataset->dim())); + } else { + log_warn("Dataset file '%s' does not exist; benchmarking index building is impossible.", + base_file.c_str()); + } + std::vector more_indices{}; + for (auto& index : indices) { + for (auto param : apply_overrides(index.build_param, override_kv)) { + auto modified_index = index; + modified_index.build_param = param; + modified_index.file = combine_path(index_prefix, modified_index.file); + more_indices.push_back(modified_index); + } + } + register_build(dataset, more_indices, force_overwrite); + } else if (search_mode) { + if (file_exists(query_file)) { + log_info("Using the query file '%s'", query_file.c_str()); + ::benchmark::AddCustomContext("max_n_queries", + std::to_string(dataset->query_set_size())); + ::benchmark::AddCustomContext("dim", std::to_string(dataset->dim())); + if (gt_file.has_value()) { + if (file_exists(*gt_file)) { + log_info("Using the ground truth file '%s'", gt_file->c_str()); + ::benchmark::AddCustomContext("max_k", std::to_string(dataset->max_k())); + } else { + log_warn("Ground truth file '%s' does not exist; the recall won't be reported.", + gt_file->c_str()); + } + } else { + log_warn( + "Ground truth file is not provided; the recall won't be reported. NB: use " + "the 'groundtruth_neighbors_file' alongside the 'query_file' key to specify the " + "path to " + "the ground truth in your conf.json."); + } + } else { + log_warn("Query file '%s' does not exist; benchmarking search is impossible.", + query_file.c_str()); + } + for (auto& index : indices) { + index.search_params = apply_overrides(index.search_params, override_kv); + index.file = combine_path(index_prefix, index.file); + } + register_search(dataset, indices, metric_objective); + } } - return false; -} -inline auto parse_string_flag(const char* arg, const char* pat, std::string& result) -> bool -{ - auto n = strlen(pat); - if (strncmp(pat, arg, strlen(pat)) == 0) { - result = arg + n + 1; - return true; + inline auto parse_bool_flag(const char* arg, const char* pat, bool& result)->bool + { + if (strcmp(arg, pat) == 0) { + result = true; + return true; + } + return false; } - return false; -} -inline auto run_main(int argc, char** argv) -> int -{ - bool force_overwrite = false; - bool build_mode = false; - bool search_mode = false; - std::string data_prefix = "data"; - std::string index_prefix = "index"; - std::string new_override_kv = ""; - std::string mode = "latency"; - kv_series override_kv{}; - - char arg0_default[] = "benchmark"; // NOLINT - char* args_default = arg0_default; - if (!argv) { - argc = 1; - argv = &args_default; - } - if (argc == 1) { - printf_usage(); - return -1; + inline auto parse_string_flag(const char* arg, const char* pat, std::string& result)->bool + { + auto n = strlen(pat); + if (strncmp(pat, arg, strlen(pat)) == 0) { + result = arg + n + 1; + return true; + } + return false; } - char* conf_path = argv[--argc]; - std::ifstream conf_stream(conf_path); - - for (int i = 1; i < argc; i++) { - if (parse_bool_flag(argv[i], "--overwrite", force_overwrite) || - parse_bool_flag(argv[i], "--build", build_mode) || - parse_bool_flag(argv[i], "--search", search_mode) || - parse_string_flag(argv[i], "--data_prefix", data_prefix) || - parse_string_flag(argv[i], "--index_prefix", index_prefix) || - parse_string_flag(argv[i], "--mode", mode) || - parse_string_flag(argv[i], "--override_kv", new_override_kv)) { - if (!new_override_kv.empty()) { - auto kvv = split(new_override_kv, ':'); - auto key = kvv[0]; - std::vector vals{}; - for (std::size_t j = 1; j < kvv.size(); j++) { - vals.push_back(nlohmann::json::parse(kvv[j])); + inline auto run_main(int argc, char** argv)->int + { + bool force_overwrite = false; + bool build_mode = false; + bool search_mode = false; + std::string data_prefix = "data"; + std::string index_prefix = "index"; + std::string new_override_kv = ""; + std::string mode = "latency"; + kv_series override_kv{}; + + char arg0_default[] = "benchmark"; // NOLINT + char* args_default = arg0_default; + if (!argv) { + argc = 1; + argv = &args_default; + } + if (argc == 1) { + printf_usage(); + return -1; } - override_kv.emplace_back(key, vals); - new_override_kv = ""; - } - for (int j = i; j < argc - 1; j++) { - argv[j] = argv[j + 1]; - } - argc--; - i--; - } - } - Objective metric_objective = Objective::LATENCY; - if (mode == "throughput") { metric_objective = Objective::THROUGHPUT; } + char* conf_path = argv[--argc]; + std::ifstream conf_stream(conf_path); + + for (int i = 1; i < argc; i++) { + if (parse_bool_flag(argv[i], "--overwrite", force_overwrite) || + parse_bool_flag(argv[i], "--build", build_mode) || + parse_bool_flag(argv[i], "--search", search_mode) || + parse_string_flag(argv[i], "--data_prefix", data_prefix) || + parse_string_flag(argv[i], "--index_prefix", index_prefix) || + parse_string_flag(argv[i], "--mode", mode) || + parse_string_flag(argv[i], "--override_kv", new_override_kv)) { + if (!new_override_kv.empty()) { + auto kvv = split(new_override_kv, ':'); + auto key = kvv[0]; + std::vector vals{}; + for (std::size_t j = 1; j < kvv.size(); j++) { + vals.push_back(nlohmann::json::parse(kvv[j])); + } + override_kv.emplace_back(key, vals); + new_override_kv = ""; + } + for (int j = i; j < argc - 1; j++) { + argv[j] = argv[j + 1]; + } + argc--; + i--; + } + } - if (build_mode == search_mode) { - log_error("One and only one of --build and --search should be specified"); - printf_usage(); - return -1; - } + Objective metric_objective = Objective::LATENCY; + if (mode == "throughput") { metric_objective = Objective::THROUGHPUT; } - if (!conf_stream) { - log_error("Can't open configuration file: %s", conf_path); - return -1; - } + if (build_mode == search_mode) { + log_error("One and only one of --build and --search should be specified"); + printf_usage(); + return -1; + } - if (cudart.needed() && !cudart.found()) { - log_warn("cudart library is not found, GPU-based indices won't work."); - } + if (!conf_stream) { + log_error("Can't open configuration file: %s", conf_path); + return -1; + } + + if (cudart.needed() && !cudart.found()) { + log_warn("cudart library is not found, GPU-based indices won't work."); + } - Configuration conf(conf_stream); - std::string dtype = conf.get_dataset_conf().dtype; - - if (dtype == "float") { - dispatch_benchmark(conf, - force_overwrite, - build_mode, - search_mode, - data_prefix, - index_prefix, - override_kv, - metric_objective); - } else if (dtype == "uint8") { - dispatch_benchmark(conf, - force_overwrite, - build_mode, - search_mode, - data_prefix, - index_prefix, - override_kv, - metric_objective); - } else if (dtype == "int8") { - dispatch_benchmark(conf, + Configuration conf(conf_stream); + std::string dtype = conf.get_dataset_conf().dtype; + + if (dtype == "float") { + dispatch_benchmark(conf, force_overwrite, build_mode, search_mode, @@ -572,19 +577,37 @@ inline auto run_main(int argc, char** argv) -> int index_prefix, override_kv, metric_objective); - } else { - log_error("datatype '%s' is not supported", dtype.c_str()); - return -1; - } + } else if (dtype == "uint8") { + dispatch_benchmark(conf, + force_overwrite, + build_mode, + search_mode, + data_prefix, + index_prefix, + override_kv, + metric_objective); + } else if (dtype == "int8") { + dispatch_benchmark(conf, + force_overwrite, + build_mode, + search_mode, + data_prefix, + index_prefix, + override_kv, + metric_objective); + } else { + log_error("datatype '%s' is not supported", dtype.c_str()); + return -1; + } - ::benchmark::Initialize(&argc, argv, printf_usage); - if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return -1; - ::benchmark::RunSpecifiedBenchmarks(); - ::benchmark::Shutdown(); - // Release a possibly cached ANN object, so that it cannot be alive longer than the handle to a - // shared library it depends on (dynamic benchmark executable). - current_algo.reset(); - return 0; -} + ::benchmark::Initialize(&argc, argv, printf_usage); + if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return -1; + ::benchmark::RunSpecifiedBenchmarks(); + ::benchmark::Shutdown(); + // Release a possibly cached ANN object, so that it cannot be alive longer than the handle + // to a shared library it depends on (dynamic benchmark executable). + current_algo.reset(); + return 0; + } }; // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h index 81c50f2c2e..66776ee3bc 100644 --- a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h +++ b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h @@ -65,7 +65,7 @@ class HnswLib : public ANN { using typename ANN::AnnSearchParam; struct SearchParam : public AnnSearchParam { int ef; - int num_threads = omp_get_num_procs(); + int num_threads = 1; }; HnswLib(Metric metric, int dim, const BuildParam& param);