From c245530b5471c34dbd0d9dfe5715b7a9e3acd0d7 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Thu, 16 Jan 2025 20:55:21 +0800 Subject: [PATCH 1/8] introduce check id exist for index (#338) Signed-off-by: LHT129 --- include/vsag/index.h | 14 +++++++++++- include/vsag/index_features.h | 2 ++ src/algorithm/hgraph.cpp | 37 +++++++++++++++++++------------ src/algorithm/hgraph.h | 7 +++++- src/index/brute_force.cpp | 41 +++++++++++++++++++++++++---------- src/index/brute_force.h | 3 +++ src/index/hgraph_index.cpp | 1 + src/index/hgraph_index.h | 13 +++++++---- src/index/hnsw.cpp | 5 ++++- src/index/hnsw.h | 5 +++++ tests/test_brute_force.cpp | 15 +++++++++++++ tests/test_hgraph.cpp | 15 +++++++++++++ tests/test_hnsw_new.cpp | 9 ++++++++ tests/test_index.cpp | 26 ++++++++++++++++++++++ tests/test_index.h | 3 +++ 15 files changed, 163 insertions(+), 33 deletions(-) diff --git a/include/vsag/index.h b/include/vsag/index.h index 96812420..128eab54 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -341,7 +341,7 @@ class Index { * @return number of bytes estimate used. */ [[nodiscard]] virtual uint64_t - EstimateMemory(const uint64_t num_elements) const { + EstimateMemory(uint64_t num_elements) const { throw std::runtime_error("Index not support estimate the memory by element counts"); } @@ -366,6 +366,18 @@ class Index { throw std::runtime_error("Index not support range search"); } + /** + * @brief Check if a specific ID exists in the index. + * + * @param id The ID to check for existence in the index. + * @return True if the ID exists, otherwise false. + * @throws std::runtime_error if the index does not support checking ID existence. 
+ */ + [[nodiscard]] virtual bool + CheckIdExist(int64_t id) const { + throw std::runtime_error("Index not support check id exist"); + } + public: virtual ~Index() = default; }; diff --git a/include/vsag/index_features.h b/include/vsag/index_features.h index 32be1571..5042f03b 100644 --- a/include/vsag/index_features.h +++ b/include/vsag/index_features.h @@ -59,6 +59,8 @@ enum IndexFeature { SUPPORT_ESTIMATE_MEMORY, /**< Supports estimate memory usage by data count */ + SUPPORT_CHECK_ID_EXIST, /**< Supports check whether given id exists in index */ + INDEX_FEATURE_COUNT /** must be last one */ }; } // namespace vsag diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index ad3740d2..3a22325e 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -185,7 +185,7 @@ HGraph::KnnSearch(const DatasetPtr& query, } uint64_t -HGraph::EstimateMemory(const uint64_t num_elements) const { +HGraph::EstimateMemory(uint64_t num_elements) const { uint64_t estimate_memory = 0; auto block_size = Options::Instance().block_size_limit(); auto element_count = @@ -853,24 +853,33 @@ void HGraph::init_features() { // Common Init // Build & Add - feature_list_.SetFeatures({IndexFeature::SUPPORT_BUILD, - IndexFeature::SUPPORT_BUILD_WITH_MULTI_THREAD, - IndexFeature::SUPPORT_ADD_AFTER_BUILD}); + feature_list_.SetFeatures({ + IndexFeature::SUPPORT_BUILD, + IndexFeature::SUPPORT_BUILD_WITH_MULTI_THREAD, + IndexFeature::SUPPORT_ADD_AFTER_BUILD, + }); // search - feature_list_.SetFeatures({IndexFeature::SUPPORT_KNN_SEARCH, - IndexFeature::SUPPORT_RANGE_SEARCH, - IndexFeature::SUPPORT_KNN_SEARCH_WITH_ID_FILTER, - IndexFeature::SUPPORT_RANGE_SEARCH_WITH_ID_FILTER}); + feature_list_.SetFeatures({ + IndexFeature::SUPPORT_KNN_SEARCH, + IndexFeature::SUPPORT_RANGE_SEARCH, + IndexFeature::SUPPORT_KNN_SEARCH_WITH_ID_FILTER, + IndexFeature::SUPPORT_RANGE_SEARCH_WITH_ID_FILTER, + }); // concurrency feature_list_.SetFeature(IndexFeature::SUPPORT_SEARCH_CONCURRENT); // serialize - 
feature_list_.SetFeatures({IndexFeature::SUPPORT_DESERIALIZE_BINARY_SET, - IndexFeature::SUPPORT_DESERIALIZE_FILE, - IndexFeature::SUPPORT_DESERIALIZE_READER_SET, - IndexFeature::SUPPORT_SERIALIZE_BINARY_SET, - IndexFeature::SUPPORT_SERIALIZE_FILE}); + feature_list_.SetFeatures({ + IndexFeature::SUPPORT_DESERIALIZE_BINARY_SET, + IndexFeature::SUPPORT_DESERIALIZE_FILE, + IndexFeature::SUPPORT_DESERIALIZE_READER_SET, + IndexFeature::SUPPORT_SERIALIZE_BINARY_SET, + IndexFeature::SUPPORT_SERIALIZE_FILE, + }); // other - feature_list_.SetFeatures({IndexFeature::SUPPORT_ESTIMATE_MEMORY}); + feature_list_.SetFeatures({ + IndexFeature::SUPPORT_ESTIMATE_MEMORY, + IndexFeature::SUPPORT_CHECK_ID_EXIST, + }); // About Train auto name = this->basic_flatten_codes_->GetQuantizerName(); diff --git a/src/algorithm/hgraph.h b/src/algorithm/hgraph.h index ed44840d..af6e36f7 100644 --- a/src/algorithm/hgraph.h +++ b/src/algorithm/hgraph.h @@ -84,7 +84,7 @@ class HGraph { } uint64_t - EstimateMemory(const uint64_t num_elements) const; + EstimateMemory(uint64_t num_elements) const; // TODO(LHT): implement inline int64_t @@ -98,6 +98,11 @@ class HGraph { bool CheckFeature(IndexFeature feature) const; + bool + CheckIdExist(LabelType id) const { + return this->label_lookup_.find(id) != this->label_lookup_.end(); + } + inline void SetBuildThreadsCount(uint64_t count) { this->build_thread_count_ = count; diff --git a/src/index/brute_force.cpp b/src/index/brute_force.cpp index 88e4cec8..d6ae1264 100644 --- a/src/index/brute_force.cpp +++ b/src/index/brute_force.cpp @@ -292,28 +292,45 @@ BruteForce::init_feature_list() { if (name != QUANTIZATION_TYPE_VALUE_FP32) { feature_list_.SetFeature(IndexFeature::NEED_TRAIN); } else { - feature_list_.SetFeatures({IndexFeature::SUPPORT_ADD_FROM_EMPTY, - IndexFeature::SUPPORT_RANGE_SEARCH, - IndexFeature::SUPPORT_CAL_DISTANCE_BY_ID, - IndexFeature::SUPPORT_RANGE_SEARCH_WITH_ID_FILTER}); + feature_list_.SetFeatures({ + IndexFeature::SUPPORT_ADD_FROM_EMPTY, 
+ IndexFeature::SUPPORT_RANGE_SEARCH, + IndexFeature::SUPPORT_CAL_DISTANCE_BY_ID, + IndexFeature::SUPPORT_RANGE_SEARCH_WITH_ID_FILTER, + }); } // Add & Build - feature_list_.SetFeatures({IndexFeature::SUPPORT_BUILD, IndexFeature::SUPPORT_ADD_AFTER_BUILD}); + feature_list_.SetFeatures({ + IndexFeature::SUPPORT_BUILD, + IndexFeature::SUPPORT_ADD_AFTER_BUILD, + }); // Search feature_list_.SetFeatures({ IndexFeature::SUPPORT_KNN_SEARCH, - IndexFeature::SUPPORT_KNN_SEARCH_WITH_ID_FILTER, }); // concurrency - feature_list_.SetFeatures({IndexFeature::SUPPORT_SEARCH_CONCURRENT}); + feature_list_.SetFeatures({ + IndexFeature::SUPPORT_SEARCH_CONCURRENT, + }); // serialize - feature_list_.SetFeatures({IndexFeature::SUPPORT_DESERIALIZE_BINARY_SET, - IndexFeature::SUPPORT_DESERIALIZE_FILE, - IndexFeature::SUPPORT_DESERIALIZE_READER_SET, - IndexFeature::SUPPORT_SERIALIZE_BINARY_SET, - IndexFeature::SUPPORT_SERIALIZE_FILE}); + feature_list_.SetFeatures({ + IndexFeature::SUPPORT_DESERIALIZE_BINARY_SET, + IndexFeature::SUPPORT_DESERIALIZE_FILE, + IndexFeature::SUPPORT_DESERIALIZE_READER_SET, + IndexFeature::SUPPORT_SERIALIZE_BINARY_SET, + IndexFeature::SUPPORT_SERIALIZE_FILE, + }); + // others + feature_list_.SetFeatures({ + IndexFeature::SUPPORT_ESTIMATE_MEMORY, + IndexFeature::SUPPORT_CHECK_ID_EXIST, + }); +} +bool +BruteForce::CheckIdExist(int64_t id) const { + return this->label_table_->CheckLabel(id); } } // namespace vsag diff --git a/src/index/brute_force.h b/src/index/brute_force.h index 9e29844a..f1a920c2 100644 --- a/src/index/brute_force.h +++ b/src/index/brute_force.h @@ -140,6 +140,9 @@ class BruteForce : public Index { [[nodiscard]] bool CheckFeature(IndexFeature feature) const override; + [[nodiscard]] bool + CheckIdExist(int64_t id) const override; + private: std::vector build(const DatasetPtr& data); diff --git a/src/index/hgraph_index.cpp b/src/index/hgraph_index.cpp index c604518e..6b9341ef 100644 --- a/src/index/hgraph_index.cpp +++ b/src/index/hgraph_index.cpp 
@@ -25,4 +25,5 @@ HGraphIndex::~HGraphIndex() { this->hgraph_.reset(); this->allocator_.reset(); } + } // namespace vsag diff --git a/src/index/hgraph_index.h b/src/index/hgraph_index.h index 9295eb93..39b6613d 100644 --- a/src/index/hgraph_index.h +++ b/src/index/hgraph_index.h @@ -122,26 +122,31 @@ class HGraphIndex : public Index { SAFE_CALL(return this->hgraph_->Deserialize(reader_set)); } - int64_t + [[nodiscard]] int64_t GetNumElements() const override { return this->hgraph_->GetNumElements(); } - int64_t + [[nodiscard]] int64_t GetMemoryUsage() const override { return this->hgraph_->GetMemoryUsage(); } [[nodiscard]] uint64_t - EstimateMemory(const uint64_t num_elements) const override { + EstimateMemory(uint64_t num_elements) const override { return this->hgraph_->EstimateMemory(num_elements); } - bool + [[nodiscard]] bool CheckFeature(IndexFeature feature) const override { return this->hgraph_->CheckFeature(feature); } + [[nodiscard]] bool + CheckIdExist(int64_t id) const override { + return this->hgraph_->CheckIdExist(id); + } + private: std::unique_ptr hgraph_{nullptr}; diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index e08d0817..fbad63da 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -975,7 +975,10 @@ HNSW::init_feature_list() { IndexFeature::SUPPORT_SERIALIZE_BINARY_SET, IndexFeature::SUPPORT_SERIALIZE_FILE}); // other - feature_list_.SetFeature(IndexFeature::SUPPORT_CAL_DISTANCE_BY_ID); + feature_list_.SetFeatures({ + IndexFeature::SUPPORT_CAL_DISTANCE_BY_ID, + IndexFeature::SUPPORT_CHECK_ID_EXIST, + }); } } // namespace vsag diff --git a/src/index/hnsw.h b/src/index/hnsw.h index 1a170ea8..f8e4a474 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -148,6 +148,11 @@ class HNSW : public Index { [[nodiscard]] bool CheckFeature(IndexFeature feature) const override; + [[nodiscard]] bool + CheckIdExist(int64_t id) const override { + return this->alg_hnsw_->isValidLabel(id); + } + public: tl::expected Serialize() const override { 
diff --git a/tests/test_brute_force.cpp b/tests/test_brute_force.cpp index ea237e68..6cc5c3c0 100644 --- a/tests/test_brute_force.cpp +++ b/tests/test_brute_force.cpp @@ -176,6 +176,9 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::BruteForceTestIndex, if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH_WITH_ID_FILTER)) { TestFilterSearch(index, dataset, search_param, recall, true); } + if (index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + TestCheckIdExist(index, dataset); + } } vsag::Options::Instance().set_block_size_limit(origin_size); } @@ -213,6 +216,9 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::BruteForceTestIndex, if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH_WITH_ID_FILTER)) { TestFilterSearch(index, dataset, search_param, recall, true); } + if (index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + TestCheckIdExist(index, dataset); + } } vsag::Options::Instance().set_block_size_limit(origin_size); } @@ -248,6 +254,9 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::BruteForceTestIndex, "BruteForce Add", "[ if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH_WITH_ID_FILTER)) { TestFilterSearch(index, dataset, search_param, recall, true); } + if (index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + TestCheckIdExist(index, dataset); + } } vsag::Options::Instance().set_block_size_limit(origin_size); } @@ -285,6 +294,9 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::BruteForceTestIndex, if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH_WITH_ID_FILTER)) { TestFilterSearch(index, dataset, search_param, recall, true); } + if (index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + TestCheckIdExist(index, dataset); + } } vsag::Options::Instance().set_block_size_limit(origin_size); } @@ -386,6 +398,9 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::BruteForceTestIndex, if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH_WITH_ID_FILTER)) { TestFilterSearch(index, dataset, search_param, recall, true); } + if 
(index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + TestCheckIdExist(index, dataset); + } } vsag::Options::Instance().set_block_size_limit(origin_size); } diff --git a/tests/test_hgraph.cpp b/tests/test_hgraph.cpp index a4aab697..c9b69823 100644 --- a/tests/test_hgraph.cpp +++ b/tests/test_hgraph.cpp @@ -263,6 +263,9 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH_WITH_ID_FILTER)) { TestFilterSearch(index, dataset, search_param, recall, true); } + if (index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + TestCheckIdExist(index, dataset); + } } vsag::Options::Instance().set_block_size_limit(origin_size); } @@ -298,6 +301,9 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Build", "[ft][hg if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH_WITH_ID_FILTER)) { TestFilterSearch(index, dataset, search_param, recall, true); } + if (index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + TestCheckIdExist(index, dataset); + } } vsag::Options::Instance().set_block_size_limit(origin_size); } @@ -333,6 +339,9 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Add", "[ft][hgra if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH_WITH_ID_FILTER)) { TestFilterSearch(index, dataset, search_param, recall, true); } + if (index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + TestCheckIdExist(index, dataset); + } } vsag::Options::Instance().set_block_size_limit(origin_size); } @@ -368,6 +377,9 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Concurrent Add", if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH_WITH_ID_FILTER)) { TestFilterSearch(index, dataset, search_param, recall, true); } + if (index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + TestCheckIdExist(index, dataset); + } } vsag::Options::Instance().set_block_size_limit(origin_size); } @@ -465,6 +477,9 @@ 
TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Duplicate Build" if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH_WITH_ID_FILTER)) { TestFilterSearch(index, dataset, search_param, recall, true); } + if (index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + TestCheckIdExist(index, dataset); + } } vsag::Options::Instance().set_block_size_limit(origin_size); } diff --git a/tests/test_hnsw_new.cpp b/tests/test_hnsw_new.cpp index 58ec376a..196faad6 100644 --- a/tests/test_hnsw_new.cpp +++ b/tests/test_hnsw_new.cpp @@ -240,6 +240,9 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Build", "[ft][hnsw]" TestRangeSearch(index, dataset, search_param, 0.99, 10, true); TestRangeSearch(index, dataset, search_param, 0.49, 5, true); TestFilterSearch(index, dataset, search_param, 0.99, true); + if (index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + TestCheckIdExist(index, dataset); + } } vsag::Options::Instance().set_block_size_limit(origin_size); } @@ -262,6 +265,9 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Add", "[ft][hnsw]") TestRangeSearch(index, dataset, search_param, 0.99, 10, true); TestRangeSearch(index, dataset, search_param, 0.49, 5, true); TestFilterSearch(index, dataset, search_param, 0.99, true); + if (index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + TestCheckIdExist(index, dataset); + } vsag::Options::Instance().set_block_size_limit(origin_size); } @@ -285,6 +291,9 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Concurrent Add", "[f TestRangeSearch(index, dataset, search_param, 0.95, 10, true); TestRangeSearch(index, dataset, search_param, 0.45, 5, true); TestFilterSearch(index, dataset, search_param, 0.95, true); + if (index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + TestCheckIdExist(index, dataset); + } vsag::Options::Instance().set_block_size_limit(origin_size); } diff --git a/tests/test_index.cpp b/tests/test_index.cpp index 
85803240..ad79a474 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -104,6 +104,10 @@ TestIndex::TestUpdateId(const IndexPtr& index, auto succ_update_res = index->UpdateId(ids[i], update_id_map[ids[i]]); REQUIRE(succ_update_res.has_value()); if (expected_success) { + if (index->CheckFeature(vsag::IndexFeature::SUPPORT_CHECK_ID_EXIST)) { + REQUIRE(index->CheckIdExist(ids[i]) == false); + REQUIRE(index->CheckIdExist(update_id_map[ids[i]]) == true); + } REQUIRE(succ_update_res.value()); } @@ -676,4 +680,26 @@ TestIndex::TestEstimateMemory(const std::string& index_name, } } +void +TestIndex::TestCheckIdExist(const TestIndex::IndexPtr& index, const TestDatasetPtr& dataset) { + auto data_count = dataset->base_->GetNumElements(); + auto* ids = dataset->base_->GetIds(); + int N = 10; + for (int i = 0; i < N; ++i) { + auto good_id = ids[random() % data_count]; + REQUIRE(index->CheckIdExist(good_id) == true); + } + std::unordered_set exist_ids(ids, ids + data_count); + int bad_id = 97; + while (N > 0) { + for (; bad_id < data_count * N; ++bad_id) { + if (exist_ids.count(bad_id) == 0) { + break; + } + } + REQUIRE(index->CheckIdExist(bad_id) == false); + --N; + } +} + } // namespace fixtures diff --git a/tests/test_index.h b/tests/test_index.h index 1800634d..4bb846c7 100644 --- a/tests/test_index.h +++ b/tests/test_index.h @@ -150,6 +150,9 @@ class TestIndex { TestEstimateMemory(const std::string& index_name, const std::string& build_param, const TestDatasetPtr& dataset); + + static void + TestCheckIdExist(const IndexPtr& index, const TestDatasetPtr& dataset); }; } // namespace fixtures From 03dc3f64d69f3c873499225dc8a05133c664c908 Mon Sep 17 00:00:00 2001 From: ShawnShawnYou <58975154+ShawnShawnYou@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:28:42 +0800 Subject: [PATCH 2/8] fix asan issue in multi-threading (#335) Signed-off-by: zhongxiaoyao.zxy --- src/algorithm/hnswlib/algorithm_interface.h | 3 ++ src/algorithm/hnswlib/hnswalg.cpp | 13 +++++ 
src/algorithm/hnswlib/hnswalg.h | 4 ++ src/algorithm/hnswlib/hnswalg_static.h | 13 +++++ src/index/hnsw.cpp | 57 +++++++++++---------- src/index/hnsw_test.cpp | 13 +++++ tests/test_index_old.cpp | 2 +- tests/test_multi_thread.cpp | 41 ++++++++++----- 8 files changed, 105 insertions(+), 41 deletions(-) diff --git a/src/algorithm/hnswlib/algorithm_interface.h b/src/algorithm/hnswlib/algorithm_interface.h index 6d6e5b7e..0bd2ae49 100644 --- a/src/algorithm/hnswlib/algorithm_interface.h +++ b/src/algorithm/hnswlib/algorithm_interface.h @@ -69,6 +69,9 @@ class AlgorithmInterface { virtual const float* getDataByLabel(LabelType label) const = 0; + virtual void + copyDataByLabel(LabelType label, void* data_point) = 0; + virtual std::priority_queue> bruteForce(const void* data_point, int64_t k) = 0; diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index 578951a2..cadee13b 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -934,6 +934,19 @@ HierarchicalNSW::getDataByLabel(LabelType label) const { return data_ptr; } +void +HierarchicalNSW::copyDataByLabel(LabelType label, void* data_point) { + std::unique_lock lock_table(label_lookup_lock_); + + auto search = label_lookup_.find(label); + if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { + throw std::runtime_error("Label not found"); + } + InnerIdType internal_id = search->second; + + memcpy(data_point, getDataByInternalId(internal_id), data_size_); +} + /* * Marks an element with the given label deleted, does NOT really change the current graph. 
*/ diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index 440bd6d6..1228c5ed 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -313,6 +313,10 @@ class HierarchicalNSW : public AlgorithmInterface { const float* getDataByLabel(LabelType label) const override; + + void + copyDataByLabel(LabelType label, void* data_point) override; + /* * Marks an element with the given label deleted, does NOT really change the current graph. */ diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index 0e51d551..570e9373 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -262,6 +262,19 @@ class StaticHierarchicalNSW : public AlgorithmInterface { return dist; } + void + copyDataByLabel(LabelType label, void* data_point) override { + std::unique_lock lock_table(label_lookup_lock); + + auto search = label_lookup_.find(label); + if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { + throw std::runtime_error("Label not found"); + } + InnerIdType internal_id = search->second; + + memcpy(data_point, getDataByInternalId(internal_id), data_size_); + } + bool isValidLabel(LabelType label) override { std::unique_lock lock_table(label_lookup_lock); diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index fbad63da..98824afa 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -645,34 +645,33 @@ HNSW::update_vector(int64_t id, const DatasetPtr& new_base, bool force_update) { get_vectors(new_base, &new_base_vec, &data_size); if (not force_update) { - const void* base_data; + std::shared_ptr base_data(new int8_t[data_size]); auto base = Dataset::Make(); - // check if id exists - base_data = - std::reinterpret_pointer_cast(alg_hnsw_)->getDataByLabel( - id); - set_dataset(base, base_data, 1); + // check if id exists and get copied base data + std::reinterpret_pointer_cast(alg_hnsw_)->copyDataByLabel( + id, 
base_data.get()); + set_dataset(base, base_data.get(), 1); // search neighbors - auto result = this->knn_search(base, - vsag::UPDATE_CHECK_SEARCH_K, - fmt::format(R"({{ - "hnsw": - {{ - "ef_search": {} - }} - }})", - vsag::UPDATE_CHECK_SEARCH_L), - nullptr); + auto neighbors = *this->knn_search(base, + vsag::UPDATE_CHECK_SEARCH_K, + fmt::format(R"({{ + "hnsw": + {{ + "ef_search": {} + }} + }})", + vsag::UPDATE_CHECK_SEARCH_L), + nullptr); // check whether the neighborhood relationship is same float self_dist = std::reinterpret_pointer_cast(alg_hnsw_) ->getDistanceByLabel(id, new_base_vec); - for (int i = 0; i < result.value()->GetDim(); i++) { + for (int i = 0; i < neighbors->GetDim(); i++) { float neighbor_dist = std::reinterpret_pointer_cast(alg_hnsw_) - ->getDistanceByLabel(result.value()->GetIds()[i], new_base_vec); + ->getDistanceByLabel(neighbors->GetIds()[i], new_base_vec); if (neighbor_dist < self_dist) { return false; } @@ -836,8 +835,6 @@ HNSW::pretrain(const std::vector& base_tag_ids, uint32_t data_size = 0; uint32_t add_edges = 0; int64_t topk_neighbor_tag_id; - const void* topk_data; - const void* base_data; auto base = Dataset::Make(); auto generated_query = Dataset::Make(); if (type_ == DataTypes::DATA_TYPE_INT8) { @@ -845,14 +842,17 @@ HNSW::pretrain(const std::vector& base_tag_ids, } else { data_size = dim_ * 4; } + std::shared_ptr base_data(new int8_t[data_size]); + std::shared_ptr topk_data(new int8_t[data_size]); std::shared_ptr generated_data(new int8_t[data_size]); set_dataset(generated_query, generated_data.get(), 1); for (const int64_t& base_tag_id : base_tag_ids) { try { - base_data = (const void*)this->alg_hnsw_->getDataByLabel(base_tag_id); - set_dataset(base, base_data, 1); + std::reinterpret_pointer_cast(alg_hnsw_)->copyDataByLabel( + base_tag_id, base_data.get()); + set_dataset(base, base_data.get(), 1); } catch (const std::runtime_error& e) { LOG_ERROR_AND_RETURNS(ErrorType::INVALID_ARGUMENT, fmt::format("failed to pretrain(invalid 
argument): base tag id " @@ -877,17 +877,18 @@ HNSW::pretrain(const std::vector& base_tag_ids, if (topk_neighbor_tag_id == base_tag_id) { continue; } - topk_data = (const void*)this->alg_hnsw_->getDataByLabel(topk_neighbor_tag_id); + + std::reinterpret_pointer_cast(alg_hnsw_)->copyDataByLabel( + topk_neighbor_tag_id, topk_data.get()); for (int d = 0; d < dim_; d++) { if (type_ == DataTypes::DATA_TYPE_INT8) { - generated_data.get()[d] = - vsag::GENERATE_OMEGA * (float)(((int8_t*)base_data)[d]) + - (1 - vsag::GENERATE_OMEGA) * (float)(((int8_t*)topk_data)[d]); + generated_data.get()[d] = vsag::GENERATE_OMEGA * (float)(base_data[d]) + + (1 - vsag::GENERATE_OMEGA) * (float)(topk_data[d]); } else { ((float*)generated_data.get())[d] = - vsag::GENERATE_OMEGA * ((float*)base_data)[d] + - (1 - vsag::GENERATE_OMEGA) * ((float*)topk_data)[d]; + vsag::GENERATE_OMEGA * ((float*)base_data.get())[d] + + (1 - vsag::GENERATE_OMEGA) * ((float*)topk_data.get())[d]; } } diff --git a/src/index/hnsw_test.cpp b/src/index/hnsw_test.cpp index cdd0ce60..9b294c6b 100644 --- a/src/index/hnsw_test.cpp +++ b/src/index/hnsw_test.cpp @@ -902,10 +902,16 @@ TEST_CASE("get data by label", "[ut][hnsw]") { SECTION("hnsw test") { DefaultAllocator allocator; auto* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 100, &allocator); + std::shared_ptr base_data(new int8_t[dim * sizeof(float)]); alg_hnsw->init_memory_space(); alg_hnsw->addPoint(base_vectors.data(), 0); fixtures::dist_t distance = alg_hnsw->getDistanceByLabel(0, alg_hnsw->getDataByLabel(0)); + + alg_hnsw->copyDataByLabel(0, base_data.get()); + fixtures::dist_t distance_validate = alg_hnsw->getDistanceByLabel(0, base_data.get()); + REQUIRE(distance == 0); + REQUIRE(distance == distance_validate); REQUIRE_THROWS(alg_hnsw->getDistanceByLabel(-1, base_vectors.data())); delete alg_hnsw; } @@ -913,11 +919,18 @@ TEST_CASE("get data by label", "[ut][hnsw]") { SECTION("static hnsw test") { DefaultAllocator allocator; auto* alg_hnsw_static = new 
hnswlib::StaticHierarchicalNSW(&space, 100, &allocator); + std::shared_ptr base_data(new int8_t[dim * sizeof(float)]); alg_hnsw_static->init_memory_space(); alg_hnsw_static->addPoint(base_vectors.data(), 0); fixtures::dist_t distance = alg_hnsw_static->getDistanceByLabel(0, alg_hnsw_static->getDataByLabel(0)); + + alg_hnsw_static->copyDataByLabel(0, base_data.get()); + fixtures::dist_t distance_validate = + alg_hnsw_static->getDistanceByLabel(0, base_data.get()); + REQUIRE(distance == 0); + REQUIRE(distance == distance_validate); REQUIRE_THROWS(alg_hnsw_static->getDistanceByLabel(-1, base_vectors.data())); delete alg_hnsw_static; } diff --git a/tests/test_index_old.cpp b/tests/test_index_old.cpp index 0950468a..7165fa5e 100644 --- a/tests/test_index_old.cpp +++ b/tests/test_index_old.cpp @@ -1121,7 +1121,7 @@ TEST_CASE("int8 + freshhnsw + feedback", "[ft][index][hnsw]") { constexpr auto search_parameters_json = R"( {{ "hnsw": {{ - "ef_search": 100, + "ef_search": 50, "use_conjugate_graph_search": {} }} }} diff --git a/tests/test_multi_thread.cpp b/tests/test_multi_thread.cpp index dce603b5..b9351502 100644 --- a/tests/test_multi_thread.cpp +++ b/tests/test_multi_thread.cpp @@ -325,9 +325,10 @@ TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hn std::vector> insert_results; std::vector> feedback_results; + std::vector> pretrain_results; std::vector> search_results; - for (int64_t i = 0; i < max_elements; ++i) { + for (int64_t i = 0; i < max_elements / 2; ++i) { // insert insert_results.push_back(pool.enqueue([&ids, &data, &index, dim, i]() -> int64_t { auto dataset = vsag::Dataset::Make(); @@ -341,7 +342,20 @@ TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hn })); } - for (int64_t i = 0; i < max_elements; ++i) { + for (int64_t i = 0; i < max_elements / 2; ++i) { + // insert + int64_t insert_i = i + max_elements / 2; + insert_results.push_back(pool.enqueue([&ids, &data, &index, dim, insert_i]() -> int64_t { 
+ auto dataset = vsag::Dataset::Make(); + dataset->Dim(dim) + ->NumElements(1) + ->Ids(ids.get() + insert_i) + ->Int8Vectors(data.get() + insert_i * dim) + ->Owner(false); + auto add_res = index->Add(dataset); + return add_res.value().size(); + })); + // feedback feedback_results.push_back( pool.enqueue([&index, &data, i, dim, k, str_parameters]() -> uint64_t { @@ -351,6 +365,12 @@ TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hn return feedback_res.value(); })); + // pretrain + pretrain_results.push_back(pool.enqueue([&index, &ids, i, k, str_parameters]() -> uint32_t { + auto pretrain_res = index->Pretrain({ids[i]}, k, str_parameters); + return pretrain_res.value(); + })); + // search search_results.push_back(pool.enqueue([&index, &data, i, dim, k, str_parameters]() -> bool { auto query = vsag::Dataset::Make(); @@ -360,15 +380,12 @@ TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hn })); } - for (auto& res : insert_results) { - REQUIRE(res.get() == 0); - } - - for (auto& res : feedback_results) { - REQUIRE(res.get() >= 0); - } - - for (auto& res : search_results) { - REQUIRE(res.get()); + for (int64_t i = 0; i < max_elements; ++i) { + REQUIRE(insert_results[i].get() == 0); + if (i < max_elements / 2) { + REQUIRE(pretrain_results[i].get() >= 0); + REQUIRE(feedback_results[i].get() >= 0); + REQUIRE(search_results[i].get() >= 0); + } } } From e587206d3b8ae9bb9ca8d9b8947364dc31a7b99d Mon Sep 17 00:00:00 2001 From: LHT129 Date: Fri, 17 Jan 2025 11:33:27 +0800 Subject: [PATCH 3/8] add prefetch implement (#334) Signed-off-by: LHT129 --- src/data_cell/flatten_datacell.h | 2 +- src/io/memory_block_io.h | 9 +-- src/io/memory_io.h | 9 +-- src/prefetch.cpp | 100 +++++++++++++++++++++++++++++++ src/prefetch.h | 32 ++++++++++ src/simd/basic_func.cpp | 11 ++++ src/simd/basic_func.h | 6 ++ src/simd/generic.cpp | 3 + src/simd/sse.cpp | 8 +++ 9 files changed, 165 insertions(+), 15 deletions(-) create mode 100644 
src/prefetch.cpp create mode 100644 src/prefetch.h diff --git a/src/data_cell/flatten_datacell.h b/src/data_cell/flatten_datacell.h index 966e81c4..c8e2437e 100644 --- a/src/data_cell/flatten_datacell.h +++ b/src/data_cell/flatten_datacell.h @@ -67,7 +67,7 @@ class FlattenDataCell : public FlattenInterface { void Prefetch(InnerIdType id) override { - io_->Prefetch(id * code_size_); + io_->Prefetch(id * code_size_, code_size_); }; [[nodiscard]] std::string diff --git a/src/io/memory_block_io.h b/src/io/memory_block_io.h index c3f26364..c46fd9ca 100644 --- a/src/io/memory_block_io.h +++ b/src/io/memory_block_io.h @@ -15,10 +15,6 @@ #pragma once -#if defined(ENABLE_SSE) -#include //todo -#endif - #include #include #include @@ -31,6 +27,7 @@ #include "index/index_common_param.h" #include "inner_string_params.h" #include "memory_block_io_parameter.h" +#include "prefetch.h" #include "vsag/allocator.h" namespace vsag { @@ -194,9 +191,7 @@ MemoryBlockIO::MultiReadImpl(uint8_t* datas, } void MemoryBlockIO::PrefetchImpl(uint64_t offset, uint64_t cache_line) { -#if defined(ENABLE_SSE) - _mm_prefetch(get_data_ptr(offset), _MM_HINT_T0); // todo -#endif + PrefetchLines(get_data_ptr(offset), cache_line); } void diff --git a/src/io/memory_io.h b/src/io/memory_io.h index 0943751d..24ef1313 100644 --- a/src/io/memory_io.h +++ b/src/io/memory_io.h @@ -15,16 +15,13 @@ #pragma once -#if defined(ENABLE_SSE) -#include //todo -#endif - #include #include #include "basic_io.h" #include "index/index_common_param.h" #include "memory_io_parameter.h" +#include "prefetch.h" #include "vsag/allocator.h" namespace vsag { @@ -128,9 +125,7 @@ MemoryIO::MultiReadImpl(uint8_t* datas, uint64_t* sizes, uint64_t* offsets, uint } void MemoryIO::PrefetchImpl(uint64_t offset, uint64_t cache_line) { -#if defined(ENABLE_SSE) - _mm_prefetch(this->start_ + offset, _MM_HINT_T0); // todo -#endif + PrefetchLines(this->start_ + offset, cache_line); } void MemoryIO::SerializeImpl(StreamWriter& writer) { diff --git 
a/src/prefetch.cpp b/src/prefetch.cpp new file mode 100644 index 00000000..36eb48f1 --- /dev/null +++ b/src/prefetch.cpp @@ -0,0 +1,100 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "prefetch.h" +namespace vsag { + +#define PREFETCH_LINE(X) \ + case X: \ + PrefetchImpl(data); \ + break; + +template <> +void +PrefetchImpl<0>(const void* data){}; + +void +PrefetchLines(const void* data, uint64_t size) { + uint64_t n = std::min(size / 64, 63UL); + switch (n) { + PREFETCH_LINE(0); + PREFETCH_LINE(1); + PREFETCH_LINE(2); + PREFETCH_LINE(3); + PREFETCH_LINE(4); + PREFETCH_LINE(5); + PREFETCH_LINE(6); + PREFETCH_LINE(7); + PREFETCH_LINE(8); + PREFETCH_LINE(9); + PREFETCH_LINE(10); + PREFETCH_LINE(11); + PREFETCH_LINE(12); + PREFETCH_LINE(13); + PREFETCH_LINE(14); + PREFETCH_LINE(15); + PREFETCH_LINE(16); + PREFETCH_LINE(17); + PREFETCH_LINE(18); + PREFETCH_LINE(19); + PREFETCH_LINE(20); + PREFETCH_LINE(21); + PREFETCH_LINE(22); + PREFETCH_LINE(23); + PREFETCH_LINE(24); + PREFETCH_LINE(25); + PREFETCH_LINE(26); + PREFETCH_LINE(27); + PREFETCH_LINE(28); + PREFETCH_LINE(29); + PREFETCH_LINE(30); + PREFETCH_LINE(31); + PREFETCH_LINE(32); + PREFETCH_LINE(33); + PREFETCH_LINE(34); + PREFETCH_LINE(35); + PREFETCH_LINE(36); + PREFETCH_LINE(37); + PREFETCH_LINE(38); + PREFETCH_LINE(39); + PREFETCH_LINE(40); + PREFETCH_LINE(41); + PREFETCH_LINE(42); + PREFETCH_LINE(43); + PREFETCH_LINE(44); + PREFETCH_LINE(45); 
+ PREFETCH_LINE(46); + PREFETCH_LINE(47); + PREFETCH_LINE(48); + PREFETCH_LINE(49); + PREFETCH_LINE(50); + PREFETCH_LINE(51); + PREFETCH_LINE(52); + PREFETCH_LINE(53); + PREFETCH_LINE(54); + PREFETCH_LINE(55); + PREFETCH_LINE(56); + PREFETCH_LINE(57); + PREFETCH_LINE(58); + PREFETCH_LINE(59); + PREFETCH_LINE(60); + PREFETCH_LINE(61); + PREFETCH_LINE(62); + PREFETCH_LINE(63); + default: + break; + } +} +} // namespace vsag diff --git a/src/prefetch.h b/src/prefetch.h new file mode 100644 index 00000000..5923c472 --- /dev/null +++ b/src/prefetch.h @@ -0,0 +1,32 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include + +#include "simd/simd.h" + +namespace vsag { +template +void +PrefetchImpl(const void* data) { + Prefetch(data); + PrefetchImpl(static_cast(data) + 64); +} + +void +PrefetchLines(const void* data, uint64_t size); + +} // namespace vsag diff --git a/src/simd/basic_func.cpp b/src/simd/basic_func.cpp index 46bdf611..40e1f382 100644 --- a/src/simd/basic_func.cpp +++ b/src/simd/basic_func.cpp @@ -156,4 +156,15 @@ GetPQDistanceFloat256() { return generic::PQDistanceFloat256; } PQDistanceFunc PQDistanceFloat256 = GetPQDistanceFloat256(); + +static PrefetchFunc +GetPrefetch() { + if (SimdStatus::SupportSSE()) { +#if defined(ENABLE_SSE) + return sse::Prefetch; +#endif + } + return generic::Prefetch; +} +PrefetchFunc Prefetch = GetPrefetch(); } // namespace vsag diff --git a/src/simd/basic_func.h b/src/simd/basic_func.h index 7112bf2c..35e2401a 100644 --- a/src/simd/basic_func.h +++ b/src/simd/basic_func.h @@ -32,6 +32,8 @@ float INT8InnerProductDistance(const void* pVect1, const void* pVect2, const void* qty_ptr); void PQDistanceFloat256(const void* single_dim_centers, float single_dim_val, void* result); +void +Prefetch(const void* data); } // namespace generic namespace sse { @@ -47,6 +49,8 @@ float INT8InnerProductDistance(const void* pVect1, const void* pVect2, const void* qty_ptr); void PQDistanceFloat256(const void* single_dim_centers, float single_dim_val, void* result); +void +Prefetch(const void* data); } // namespace sse namespace avx { @@ -104,4 +108,6 @@ extern DistanceFuncType INT8InnerProductDistance; using PQDistanceFunc = void (*)(const void* single_dim_centers, float single_dim_val, void* result); extern PQDistanceFunc PQDistanceFloat256; +using PrefetchFunc = void (*)(const void* data); +extern PrefetchFunc Prefetch; } // namespace vsag diff --git a/src/simd/generic.cpp b/src/simd/generic.cpp index 3ef347f1..5feeca41 100644 --- a/src/simd/generic.cpp +++ b/src/simd/generic.cpp @@ -307,4 +307,7 @@ DivScalar(const float* from, 
float* to, uint64_t dim, float scalar) { } } +void +Prefetch(const void* data){}; + } // namespace vsag::generic diff --git a/src/simd/sse.cpp b/src/simd/sse.cpp index 7e22c57d..21c954d7 100644 --- a/src/simd/sse.cpp +++ b/src/simd/sse.cpp @@ -430,4 +430,12 @@ Normalize(const float* from, float* to, uint64_t dim) { sse::DivScalar(from, to, dim, norm); return norm; } + +void +Prefetch(const void* data) { +#if defined(ENABLE_SSE) + _mm_prefetch(data, _MM_HINT_T0); +#endif +}; + } // namespace vsag::sse From d9121cdf930093551b74715dfecd8421c329d013 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Fri, 17 Jan 2025 14:28:04 +0800 Subject: [PATCH 4/8] introduce new eval tool: eval_performance (#234) - the same as test_performance but support more - add Monitor, EvalCase to extends the eval - use argparse for command line Signed-off-by: LHT129 --- tools/CMakeLists.txt | 8 +- tools/eval/CMakeLists.txt | 14 ++ tools/eval/build_eval_case.cpp | 95 ++++++++++++ tools/eval/build_eval_case.h | 52 +++++++ tools/eval/eval_case.cpp | 53 +++++++ tools/eval/eval_case.h | 72 +++++++++ tools/eval/eval_config.cpp | 58 +++++++ tools/eval/eval_config.h | 50 +++++++ tools/{ => eval}/eval_dataset.cpp | 5 +- tools/{ => eval}/eval_dataset.h | 25 +++- tools/eval/monitor/duration_monitor.cpp | 39 +++++ tools/eval/monitor/duration_monitor.h | 49 ++++++ tools/eval/monitor/latency_monitor.cpp | 88 +++++++++++ tools/eval/monitor/latency_monitor.h | 66 ++++++++ tools/eval/monitor/memory_peak_monitor.cpp | 53 +++++++ tools/eval/monitor/memory_peak_monitor.h | 53 +++++++ tools/eval/monitor/monitor.cpp | 22 +++ tools/eval/monitor/monitor.h | 59 ++++++++ tools/eval/monitor/recall_monitor.cpp | 93 ++++++++++++ tools/eval/monitor/recall_monitor.h | 60 ++++++++ tools/eval/search_eval_case.cpp | 166 +++++++++++++++++++++ tools/eval/search_eval_case.h | 83 +++++++++++ tools/eval_performance.cpp | 101 +++++++++++++ tools/test_performance.cpp | 3 +- 24 files changed, 1359 insertions(+), 8 deletions(-) create mode 
100644 tools/eval/CMakeLists.txt create mode 100644 tools/eval/build_eval_case.cpp create mode 100644 tools/eval/build_eval_case.h create mode 100644 tools/eval/eval_case.cpp create mode 100644 tools/eval/eval_case.h create mode 100644 tools/eval/eval_config.cpp create mode 100644 tools/eval/eval_config.h rename tools/{ => eval}/eval_dataset.cpp (98%) rename tools/{ => eval}/eval_dataset.h (86%) create mode 100644 tools/eval/monitor/duration_monitor.cpp create mode 100644 tools/eval/monitor/duration_monitor.h create mode 100644 tools/eval/monitor/latency_monitor.cpp create mode 100644 tools/eval/monitor/latency_monitor.h create mode 100644 tools/eval/monitor/memory_peak_monitor.cpp create mode 100644 tools/eval/monitor/memory_peak_monitor.h create mode 100644 tools/eval/monitor/monitor.cpp create mode 100644 tools/eval/monitor/monitor.h create mode 100644 tools/eval/monitor/recall_monitor.cpp create mode 100644 tools/eval/monitor/recall_monitor.h create mode 100644 tools/eval/search_eval_case.cpp create mode 100644 tools/eval/search_eval_case.h create mode 100644 tools/eval_performance.cpp diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 9d97940f..41b79aa4 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -1,6 +1,8 @@ -add_library (eval_dataset OBJECT eval_dataset.cpp eval_dataset.h) -add_dependencies (eval_dataset hdf5 spdlog) +add_subdirectory (eval) add_executable (test_performance test_performance.cpp) -target_link_libraries (test_performance PRIVATE vsag eval_dataset simd libhdf5_cpp.a libhdf5.a z) +target_link_libraries (test_performance PRIVATE vsag eval_obj simd libhdf5_cpp.a libhdf5.a z) + +add_executable (eval_performance eval_performance.cpp) +target_link_libraries (eval_performance PRIVATE vsag eval_obj argparse::argparse simd libhdf5_cpp.a libhdf5.a z) diff --git a/tools/eval/CMakeLists.txt b/tools/eval/CMakeLists.txt new file mode 100644 index 00000000..a79c8835 --- /dev/null +++ b/tools/eval/CMakeLists.txt @@ -0,0 +1,14 @@ 
+ +set (EVAL_SRC + eval_case.cpp + search_eval_case.cpp + build_eval_case.cpp + eval_dataset.cpp + eval_config.cpp + monitor/monitor.cpp + monitor/latency_monitor.cpp + monitor/recall_monitor.cpp + monitor/memory_peak_monitor.cpp + monitor/duration_monitor.cpp) +add_library (eval_obj OBJECT ${EVAL_SRC}) +add_dependencies (eval_obj hdf5 spdlog) diff --git a/tools/eval/build_eval_case.cpp b/tools/eval/build_eval_case.cpp new file mode 100644 index 00000000..c690da5c --- /dev/null +++ b/tools/eval/build_eval_case.cpp @@ -0,0 +1,95 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "build_eval_case.h" + +#include +#include + +#include "monitor/duration_monitor.h" +#include "monitor/memory_peak_monitor.h" + +namespace vsag::eval { + +BuildEvalCase::BuildEvalCase(const std::string& dataset_path, + const std::string& index_path, + vsag::IndexPtr index, + EvalConfig config) + : EvalCase(dataset_path, index_path, index), config_(std::move(config)) { + this->init_monitors(); +} + +void +BuildEvalCase::init_monitors() { + if (config_.enable_memory) { + auto memory_peak_monitor = std::make_shared(); + this->monitors_.emplace_back(std::move(memory_peak_monitor)); + } + if (config_.enable_tps) { + auto duration_monitor = std::make_shared(); + this->monitors_.emplace_back(std::move(duration_monitor)); + } +} + +void +BuildEvalCase::Run() { + this->do_build(); + this->serialize(); + auto result = this->process_result(); + PrintResult(result); +} +void +BuildEvalCase::do_build() { + auto base = vsag::Dataset::Make(); + int64_t total_base = this->dataset_ptr_->GetNumberOfBase(); + std::vector ids(total_base); + std::iota(ids.begin(), ids.end(), 0); + base->NumElements(total_base)->Dim(this->dataset_ptr_->GetDim())->Ids(ids.data())->Owner(false); + if (this->dataset_ptr_->GetTrainDataType() == vsag::DATATYPE_FLOAT32) { + base->Float32Vectors((const float*)this->dataset_ptr_->GetTrain()); + } else if (this->dataset_ptr_->GetTrainDataType() == vsag::DATATYPE_INT8) { + base->Int8Vectors((const int8_t*)this->dataset_ptr_->GetTrain()); + } + for (auto& monitor : monitors_) { + monitor->Start(); + } + auto build_index = index_->Build(base); + for (auto& monitor : monitors_) { + monitor->Record(); + monitor->Stop(); + } +} +void +BuildEvalCase::serialize() { + std::ofstream outfile(this->index_path_, std::ios::binary); + this->index_->Serialize(outfile); +} + +EvalCase::JsonType +BuildEvalCase::process_result() { + JsonType result; + JsonType eval_result; + for (auto& monitor : this->monitors_) { + const auto& one_result = monitor->GetResult(); + 
EvalCase::MergeJsonType(one_result, eval_result); + } + result = eval_result; + result["tps"] = double(this->dataset_ptr_->GetNumberOfBase()) / double(result["duration(s)"]); + EvalCase::MergeJsonType(this->basic_info_, result); + result["index_info"] = JsonType::parse(config_.build_param); + return result; +} + +} // namespace vsag::eval diff --git a/tools/eval/build_eval_case.h b/tools/eval/build_eval_case.h new file mode 100644 index 00000000..656d3c2e --- /dev/null +++ b/tools/eval/build_eval_case.h @@ -0,0 +1,52 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "eval_case.h" +#include "monitor/monitor.h" +namespace vsag::eval { + +class BuildEvalCase : public EvalCase { +public: + BuildEvalCase(const std::string& dataset_path, + const std::string& index_path, + vsag::IndexPtr index, + EvalConfig config); + + ~BuildEvalCase() override = default; + + void + Run() override; + +private: + void + init_monitors(); + + void + do_build(); + + void + serialize(); + + JsonType + process_result(); + +private: + std::vector monitors_{}; + + EvalConfig config_; +}; +} // namespace vsag::eval diff --git a/tools/eval/eval_case.cpp b/tools/eval/eval_case.cpp new file mode 100644 index 00000000..513165a8 --- /dev/null +++ b/tools/eval/eval_case.cpp @@ -0,0 +1,53 @@ + + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "eval_case.h" + +#include + +#include "build_eval_case.h" +#include "search_eval_case.h" +#include "vsag/factory.h" +#include "vsag/options.h" + +namespace vsag::eval { + +EvalCase::EvalCase(std::string dataset_path, std::string index_path, vsag::IndexPtr index) + : dataset_path_(std::move(dataset_path)), index_path_(std::move(index_path)), index_(index) { + this->dataset_ptr_ = EvalDataset::Load(dataset_path_); + this->logger_ = vsag::Options::Instance().logger(); + this->basic_info_ = this->dataset_ptr_->GetInfo(); +} + +EvalCasePtr +EvalCase::MakeInstance(const EvalConfig& config) { + auto dataset_path = config.dataset_path; + auto index_path = config.index_path; + auto index_name = config.index_name; + auto create_params = config.build_param; + + auto index = vsag::Factory::CreateIndex(index_name, create_params); + + auto type = config.action_type; + if (type == "build") { + return std::make_shared(dataset_path, index_path, index.value(), config); + } else if (type == "search") { + return std::make_shared(dataset_path, index_path, index.value(), config); + } else { + return nullptr; + } +} +} // namespace vsag::eval diff --git a/tools/eval/eval_case.h b/tools/eval/eval_case.h new file mode 100644 index 00000000..a8d3279a --- /dev/null +++ b/tools/eval/eval_case.h @@ -0,0 +1,72 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "eval_config.h" +#include "eval_dataset.h" +#include "nlohmann/json.hpp" +#include "vsag/index.h" +#include "vsag/logger.h" + +namespace vsag::eval { + +class EvalCase; +using EvalCasePtr = std::shared_ptr; + +class EvalCase { +public: + using JsonType = nlohmann::json; + +public: + static EvalCasePtr + MakeInstance(const EvalConfig& parser); + + static void + MergeJsonType(const JsonType& input, JsonType& output) { + for (auto& [key, value] : input.items()) { + output[key] = value; + } + } + + static void + PrintResult(const JsonType& result) { + std::cout << result.dump(4) << std::endl; + } + +public: + explicit EvalCase(std::string dataset_path, std::string index_path, vsag::IndexPtr index); + + virtual ~EvalCase() = default; + + virtual void + Run() = 0; + + using Logger = vsag::Logger*; + +protected: + const std::string dataset_path_{}; + const std::string index_path_{}; + + EvalDatasetPtr dataset_ptr_{nullptr}; + + vsag::IndexPtr index_{nullptr}; + + Logger logger_{nullptr}; + + JsonType basic_info_{}; +}; + +} // namespace vsag::eval diff --git a/tools/eval/eval_config.cpp b/tools/eval/eval_config.cpp new file mode 100644 index 00000000..1eaaf665 --- /dev/null +++ b/tools/eval/eval_config.cpp @@ -0,0 +1,58 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "eval_config.h" + +namespace vsag::eval { +EvalConfig +EvalConfig::Load(argparse::ArgumentParser& parser) { + EvalConfig config; + config.dataset_path = parser.get("--datapath"); + config.action_type = parser.get("--type"); + config.build_param = parser.get("--create_params"); + config.index_name = parser.get("--index_name"); + config.index_path = parser.get("--index_path"); + + config.search_param = parser.get("--search_params"); + config.search_mode = parser.get("--search_mode"); + + config.top_k = parser.get("--topk"); + config.radius = parser.get("--range"); + + if (parser.get("--disable_recall")) { + config.enable_recall = false; + } + if (parser.get("--disable_percent_recall")) { + config.enable_percent_recall = false; + } + if (parser.get("--disable_memory")) { + config.enable_memory = false; + } + if (parser.get("--disable_latency")) { + config.enable_latency = false; + } + if (parser.get("--disable_qps")) { + config.enable_qps = false; + } + if (parser.get("--disable_tps")) { + config.enable_tps = false; + } + if (parser.get("--disable_percent_latency")) { + config.enable_percent_latency = false; + } + + return config; +} +} // namespace vsag::eval diff --git a/tools/eval/eval_config.h b/tools/eval/eval_config.h new file mode 100644 index 00000000..274671fc --- /dev/null +++ b/tools/eval/eval_config.h @@ -0,0 +1,50 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "argparse/argparse.hpp" + +namespace vsag::eval { +class EvalConfig { +public: + static EvalConfig + Load(argparse::ArgumentParser& parser); + +public: + std::string dataset_path; + std::string action_type; + std::string index_name; + std::string build_param; + std::string index_path; + + std::string search_param; + std::string search_mode; + int top_k{10}; + float radius{0.5f}; + + bool enable_recall{true}; + bool enable_percent_recall{true}; + bool enable_qps{true}; + bool enable_tps{true}; + bool enable_memory{true}; + bool enable_latency{true}; + bool enable_percent_latency{true}; + +private: + EvalConfig() = default; +}; + +} // namespace vsag::eval diff --git a/tools/eval_dataset.cpp b/tools/eval/eval_dataset.cpp similarity index 98% rename from tools/eval_dataset.cpp rename to tools/eval/eval_dataset.cpp index 29548ee1..24cbd5c3 100644 --- a/tools/eval_dataset.cpp +++ b/tools/eval/eval_dataset.cpp @@ -15,7 +15,7 @@ #include "eval_dataset.h" -namespace vsag { +namespace vsag::eval { EvalDatasetPtr EvalDataset::Load(const std::string& filename) { H5::H5File file(filename, H5F_ACC_RDONLY); @@ -41,6 +41,7 @@ EvalDataset::Load(const std::string& filename) { assert(train_shape.second == test_shape.second); auto obj = std::make_shared(); + obj->file_path_ = filename; obj->train_shape_ = train_shape; obj->test_shape_ = test_shape; obj->neighbors_shape_ = neighbors_shape; @@ -126,4 +127,4 @@ EvalDataset::Load(const std::string& filename) { return obj; } -} // namespace vsag +} // namespace vsag::eval diff --git a/tools/eval_dataset.h b/tools/eval/eval_dataset.h similarity index 86% rename from tools/eval_dataset.h rename to tools/eval/eval_dataset.h index 10976c46..a712b69a 100644 --- a/tools/eval_dataset.h +++ b/tools/eval/eval_dataset.h @@ -21,9 +21,10 @@ #include #include "H5Cpp.h" +#include "nlohmann/json.hpp" #include "vsag/constants.h" -namespace vsag { +namespace vsag::eval { class EvalDataset; using EvalDatasetPtr = 
std::shared_ptr; @@ -95,6 +96,25 @@ class EvalDataset { return test_labels_[query_id] == train_labels_[base_id]; } + std::string + GetFilePath() { + return this->file_path_; + } + + using JsonType = nlohmann::json; + JsonType + GetInfo() { + JsonType result; + JsonType temp; + temp["filepath"] = this->GetFilePath(); + temp["dim"] = this->GetDim(); + temp["base_count"] = this->GetNumberOfBase(); + temp["query_count"] = this->GetNumberOfQuery(); + temp["data_type"] = this->GetTrainDataType(); + result["dataset_info"] = temp; + return result; + }; + private: using shape_t = std::pair; static std::unordered_set @@ -144,5 +164,6 @@ class EvalDataset { size_t test_data_size_{}; std::string train_data_type_; std::string test_data_type_; + std::string file_path_; }; -} // namespace vsag +} // namespace vsag::eval diff --git a/tools/eval/monitor/duration_monitor.cpp b/tools/eval/monitor/duration_monitor.cpp new file mode 100644 index 00000000..64f8a859 --- /dev/null +++ b/tools/eval/monitor/duration_monitor.cpp @@ -0,0 +1,39 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "duration_monitor.h" + +namespace vsag::eval { + +DurationMonitor::DurationMonitor() : Monitor("duration_monitor") { +} + +void +DurationMonitor::Start() { + cur_time_ = Clock::now(); +} +void +DurationMonitor::Stop() { + auto end_time = Clock::now(); + this->duration_ = std::chrono::duration(end_time - cur_time_).count(); +} +Monitor::JsonType +DurationMonitor::GetResult() { + JsonType result; + result["duration(s)"] = this->duration_; + return result; +} + +} // namespace vsag::eval diff --git a/tools/eval/monitor/duration_monitor.h b/tools/eval/monitor/duration_monitor.h new file mode 100644 index 00000000..bf8c883a --- /dev/null +++ b/tools/eval/monitor/duration_monitor.h @@ -0,0 +1,49 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include +#include + +#include "monitor.h" + +namespace vsag::eval { + +class DurationMonitor : public Monitor { +public: + explicit DurationMonitor(); + + ~DurationMonitor() override = default; + + void + Start() override; + + void + Stop() override; + + JsonType + GetResult() override; + +private: + double duration_{0}; + + using Clock = std::chrono::high_resolution_clock; + decltype(Clock::now()) cur_time_{}; +}; + +} // namespace vsag::eval diff --git a/tools/eval/monitor/latency_monitor.cpp b/tools/eval/monitor/latency_monitor.cpp new file mode 100644 index 00000000..93acbcc2 --- /dev/null +++ b/tools/eval/monitor/latency_monitor.cpp @@ -0,0 +1,88 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "latency_monitor.h" + +namespace vsag::eval { + +LatencyMonitor::LatencyMonitor(uint64_t max_record_counts) : Monitor("latency_monitor") { + if (max_record_counts > 0) { + this->latency_records_.reserve(max_record_counts); + } +} +void +LatencyMonitor::Start() { + this->cur_time_ = Clock::now(); +} +void +LatencyMonitor::Stop() { + this->cur_time_ = Clock::now(); +} +Monitor::JsonType +LatencyMonitor::GetResult() { + JsonType result; + for (auto& metric : metrics_) { + this->cal_and_set_result(metric, result); + } + return result; +} +void +LatencyMonitor::Record(void* input) { + auto end_time = Clock::now(); + double duration = std::chrono::duration(end_time - cur_time_).count(); + this->latency_records_.emplace_back(duration); + this->cur_time_ = Clock::now(); +} +void +LatencyMonitor::SetMetrics(std::string metric) { + this->metrics_.emplace_back(std::move(metric)); +} +void +LatencyMonitor::cal_and_set_result(const std::string& metric, Monitor::JsonType& result) { + if (metric == "qps") { + auto val = this->cal_qps(); + result["qps"] = val; + } else if (metric == "avg_latency") { + auto val = this->cal_avg_latency(); + result["latency_avg(ms)"] = val; + } else if (metric == "percent_latency") { + std::vector percents = {50, 80, 90, 95, 99}; + for (auto& percent : percents) { + auto val = this->cal_latency_rate(percent * 0.01); + result["latency_detail(ms)"]["p" + std::to_string(int(percent))] = val; + } + } +} + +double +LatencyMonitor::cal_qps() { + double sum = + std::accumulate(this->latency_records_.begin(), this->latency_records_.end(), double(0)); + return static_cast(latency_records_.size()) * 1000.0 / sum; +} + +double +LatencyMonitor::cal_avg_latency() { + double sum = + std::accumulate(this->latency_records_.begin(), this->latency_records_.end(), double(0)); + return sum / static_cast(latency_records_.size()); +} +double +LatencyMonitor::cal_latency_rate(double rate) { + std::sort(this->latency_records_.begin(), 
this->latency_records_.end()); + auto pos = static_cast(rate * static_cast(this->latency_records_.size() - 1)); + return latency_records_[pos]; +} +} // namespace vsag::eval diff --git a/tools/eval/monitor/latency_monitor.h b/tools/eval/monitor/latency_monitor.h new file mode 100644 index 00000000..df45a619 --- /dev/null +++ b/tools/eval/monitor/latency_monitor.h @@ -0,0 +1,66 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "monitor.h" +namespace vsag::eval { + +class LatencyMonitor : public Monitor { +public: + explicit LatencyMonitor(uint64_t max_record_counts = 0); + + ~LatencyMonitor() override = default; + + void + Start() override; + + void + Stop() override; + + JsonType + GetResult() override; + + void + Record(void* input) override; + + void + SetMetrics(std::string metric); + +private: + void + cal_and_set_result(const std::string& metric, JsonType& result); + + double + cal_qps(); + + double + cal_avg_latency(); + + double + cal_latency_rate(double rate); + +private: + std::vector latency_records_; + + using Clock = std::chrono::high_resolution_clock; + decltype(Clock::now()) cur_time_; + + std::vector metrics_; +}; + +} // namespace vsag::eval diff --git a/tools/eval/monitor/memory_peak_monitor.cpp b/tools/eval/monitor/memory_peak_monitor.cpp new file mode 100644 index 00000000..9e5eb06c --- /dev/null +++ b/tools/eval/monitor/memory_peak_monitor.cpp @@ -0,0 
+1,53 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "memory_peak_monitor.h" + +namespace vsag::eval { + +static std::string +GetProcFileName(pid_t pid) { + return "/proc/" + std::to_string(pid) + "/statm"; +} + +MemoryPeakMonitor::MemoryPeakMonitor() : Monitor("memory_peak_monitor") { +} + +void +MemoryPeakMonitor::Start() { + this->pid_ = getpid(); + this->infile_.open(GetProcFileName(pid_)); +} +void +MemoryPeakMonitor::Stop() { +} +Monitor::JsonType +MemoryPeakMonitor::GetResult() { + JsonType result; + result["memory_peak(KB)"] = this->max_memory_ * sysconf(_SC_PAGESIZE) / 1024; + return result; +} +void +MemoryPeakMonitor::Record(void* input) { + uint64_t val1, val2; + this->infile_ >> val1 >> val2; + this->infile_.clear(); + this->infile_.seekg(0, std::ios::beg); + if (max_memory_ < val2) { + max_memory_ = val2; + } +} + +} // namespace vsag::eval diff --git a/tools/eval/monitor/memory_peak_monitor.h b/tools/eval/monitor/memory_peak_monitor.h new file mode 100644 index 00000000..1890934f --- /dev/null +++ b/tools/eval/monitor/memory_peak_monitor.h @@ -0,0 +1,53 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include + +#include "monitor.h" + +namespace vsag::eval { + +class MemoryPeakMonitor : public Monitor { +public: + explicit MemoryPeakMonitor(); + + ~MemoryPeakMonitor() override = default; + + void + Start() override; + + void + Stop() override; + + JsonType + GetResult() override; + + void + Record(void* input) override; + +private: + uint64_t max_memory_{0}; + + pid_t pid_{0}; + + std::ifstream infile_{}; +}; + +} // namespace vsag::eval diff --git a/tools/eval/monitor/monitor.cpp b/tools/eval/monitor/monitor.cpp new file mode 100644 index 00000000..5ff988b7 --- /dev/null +++ b/tools/eval/monitor/monitor.cpp @@ -0,0 +1,22 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "monitor.h" + +namespace vsag::eval { + +Monitor::Monitor(std::string name) : name_(std::move(name)) { +} +} // namespace vsag::eval diff --git a/tools/eval/monitor/monitor.h b/tools/eval/monitor/monitor.h new file mode 100644 index 00000000..b19f0aa2 --- /dev/null +++ b/tools/eval/monitor/monitor.h @@ -0,0 +1,59 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +namespace vsag::eval { + +class Monitor { +public: + using JsonType = nlohmann::json; + +public: + explicit Monitor(std::string name); + + virtual ~Monitor() = default; + + virtual void + Start() = 0; + + virtual void + Stop() = 0; + + virtual JsonType + GetResult() = 0; + + [[nodiscard]] std::string + GetName() const { + return name_; + } + +public: + virtual void + Record(void* input = nullptr){}; + +public: + std::string name_{}; +}; + +using MonitorPtr = std::shared_ptr; + +} // namespace vsag::eval diff --git a/tools/eval/monitor/recall_monitor.cpp b/tools/eval/monitor/recall_monitor.cpp new file mode 100644 index 00000000..771774c2 --- /dev/null +++ b/tools/eval/monitor/recall_monitor.cpp @@ -0,0 +1,93 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "recall_monitor.h" + +#include + +namespace vsag::eval { + +static double +get_recall(const int64_t* neighbors, const int64_t* ground_truth, size_t recall_num, size_t top_k) { + std::unordered_set neighbors_set(neighbors, neighbors + recall_num); + std::unordered_set intersection; + for (size_t i = 0; i < top_k; ++i) { + if (i < top_k && neighbors_set.count(ground_truth[i])) { + intersection.insert(ground_truth[i]); + } + } + return static_cast(intersection.size()) / static_cast(top_k); +} + +RecallMonitor::RecallMonitor(uint64_t max_record_counts) : Monitor("recall_monitor") { + if (max_record_counts > 0) { + this->recall_records_.reserve(max_record_counts); + } +} +void +RecallMonitor::Start() { +} + +void +RecallMonitor::Stop() { +} + +Monitor::JsonType +RecallMonitor::GetResult() { + JsonType result; + for (auto& metric : metrics_) { + this->cal_and_set_result(metric, result); + } + return result; +} +void +RecallMonitor::Record(void* input) { + auto [neighbors, gt, topk] = + *(reinterpret_cast*>(input)); + auto val = get_recall(neighbors, gt, topk, topk); + this->recall_records_.emplace_back(val); +} +void +RecallMonitor::SetMetrics(std::string metric) { + this->metrics_.emplace_back(std::move(metric)); +} +void +RecallMonitor::cal_and_set_result(const std::string& metric, Monitor::JsonType& result) { + if (metric == "avg_recall") { + auto val = this->cal_avg_recall(); + result["recall_avg"] = val; + } else if (metric == "percent_recall") { + std::vector percents = {0, 10, 30, 50, 70, 90}; + for (auto& percent : percents) { + auto 
val = this->cal_recall_rate(percent * 0.01); + result["recall_detail"]["p" + std::to_string(int(percent))] = val; + } + } +} + +double +RecallMonitor::cal_avg_recall() { + double sum = + std::accumulate(this->recall_records_.begin(), this->recall_records_.end(), double(0)); + return recall_records_.empty() ? 0.0 : sum / static_cast<double>(recall_records_.size()); +} + +double +RecallMonitor::cal_recall_rate(double rate) { + std::sort(this->recall_records_.begin(), this->recall_records_.end()); + auto pos = static_cast(rate * static_cast(this->recall_records_.size() - 1)); + return recall_records_.empty() ? 0.0 : recall_records_[pos]; // empty: nothing to index +} +} // namespace vsag::eval diff --git a/tools/eval/monitor/recall_monitor.h b/tools/eval/monitor/recall_monitor.h new file mode 100644 index 00000000..1977c557 --- /dev/null +++ b/tools/eval/monitor/recall_monitor.h @@ -0,0 +1,60 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include + +#include "monitor.h" +namespace vsag::eval { + +class RecallMonitor : public Monitor { +public: + explicit RecallMonitor(uint64_t max_record_counts = 0); + + ~RecallMonitor() override = default; + + void + Start() override; + + void + Stop() override; + + JsonType + GetResult() override; + + void + Record(void* input) override; + + void + SetMetrics(std::string metric); + +private: + void + cal_and_set_result(const std::string& metric, JsonType& result); + + double + cal_avg_recall(); + + double + cal_recall_rate(double rate); + +private: + std::vector recall_records_; + + std::vector metrics_; +}; + +} // namespace vsag::eval diff --git a/tools/eval/search_eval_case.cpp b/tools/eval/search_eval_case.cpp new file mode 100644 index 00000000..5d644375 --- /dev/null +++ b/tools/eval/search_eval_case.cpp @@ -0,0 +1,166 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "search_eval_case.h" + +#include +#include + +#include "monitor/latency_monitor.h" +#include "monitor/memory_peak_monitor.h" +#include "monitor/recall_monitor.h" + +namespace vsag::eval { + +SearchEvalCase::SearchEvalCase(const std::string& dataset_path, + const std::string& index_path, + vsag::IndexPtr index, + EvalConfig config) + : EvalCase(dataset_path, index_path, index), config_(std::move(config)) { + auto search_mode = config_.search_mode; // `config` is moved-from here, read the member + if (search_mode == "knn") { + this->search_type_ = SearchType::KNN; + } else if (search_mode == "range") { + this->search_type_ = SearchType::RANGE; + } else if (search_mode == "knn_filter") { + this->search_type_ = SearchType::KNN_FILTER; + } else if (search_mode == "range_filter") { + this->search_type_ = SearchType::RANGE_FILTER; + } + this->init_monitor(); +} + +void +SearchEvalCase::init_monitor() { + this->init_latency_monitor(); + this->init_recall_monitor(); + this->init_memory_monitor(); +} + +void +SearchEvalCase::init_latency_monitor() { + if (config_.enable_latency or config_.enable_qps or config_.enable_percent_latency) { + auto latency_monitor = + std::make_shared(this->dataset_ptr_->GetNumberOfQuery()); + if (config_.enable_qps) { + latency_monitor->SetMetrics("qps"); + } + if (config_.enable_latency) { + latency_monitor->SetMetrics("avg_latency"); + } + if (config_.enable_percent_latency) { + latency_monitor->SetMetrics("percent_latency"); + } + this->monitors_.emplace_back(std::move(latency_monitor)); + } +} + +void +SearchEvalCase::init_recall_monitor() { + if (config_.enable_recall or config_.enable_percent_recall) { + auto recall_monitor = + std::make_shared(this->dataset_ptr_->GetNumberOfQuery()); + if (config_.enable_recall) { + recall_monitor->SetMetrics("avg_recall"); + } + if (config_.enable_percent_recall) { + recall_monitor->SetMetrics("percent_recall"); + } + this->monitors_.emplace_back(std::move(recall_monitor)); + } +} + +void +SearchEvalCase::init_memory_monitor() { + if
(config_.enable_memory) { + auto memory_peak_monitor = std::make_shared(); + this->monitors_.emplace_back(std::move(memory_peak_monitor)); + } +} + +void +SearchEvalCase::Run() { + this->deserialize(); + switch (this->search_type_) { + case KNN: + this->do_knn_search(); + break; + case RANGE: + this->do_range_search(); + break; + case KNN_FILTER: + this->do_knn_filter_search(); + break; + case RANGE_FILTER: + this->do_range_filter_search(); + break; + } + auto result = this->process_result(); + eval::SearchEvalCase::PrintResult(result); +} +void +SearchEvalCase::deserialize() { + std::ifstream infile(this->index_path_, std::ios::binary); + this->index_->Deserialize(infile); +} +void +SearchEvalCase::do_knn_search() { + uint64_t topk = config_.top_k; + auto query_count = this->dataset_ptr_->GetNumberOfQuery(); + this->logger_->Debug("query count is " + std::to_string(query_count)); + for (auto& monitor : this->monitors_) { + monitor->Start(); + for (int64_t i = 0; i < query_count; ++i) { + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(this->dataset_ptr_->GetDim())->Owner(false); + if (this->dataset_ptr_->GetTestDataType() == vsag::DATATYPE_FLOAT32) { + query->Float32Vectors((const float*)this->dataset_ptr_->GetOneTest(i)); + } else if (this->dataset_ptr_->GetTestDataType() == vsag::DATATYPE_INT8) { + query->Int8Vectors((const int8_t*)this->dataset_ptr_->GetOneTest(i)); + } + auto result = this->index_->KnnSearch(query, topk, config_.search_param); + if (not result.has_value()) { + std::cerr << "query error: " << result.error().message << std::endl; + exit(-1); + } + int64_t* neighbors = dataset_ptr_->GetNeighbors(i); + const int64_t* ground_truth = result.value()->GetIds(); + auto record = std::make_tuple(neighbors, ground_truth, topk); + monitor->Record(&record); + } + monitor->Stop(); + } +} +void +SearchEvalCase::do_range_search() { +} +void +SearchEvalCase::do_knn_filter_search() { +} +void +SearchEvalCase::do_range_filter_search() { +} + 
+SearchEvalCase::JsonType +SearchEvalCase::process_result() { + JsonType result; + for (auto& monitor : this->monitors_) { + const auto& one_result = monitor->GetResult(); + EvalCase::MergeJsonType(one_result, result); + } + return result; +} + +} // namespace vsag::eval diff --git a/tools/eval/search_eval_case.h b/tools/eval/search_eval_case.h new file mode 100644 index 00000000..ac6321a0 --- /dev/null +++ b/tools/eval/search_eval_case.h @@ -0,0 +1,83 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "eval_case.h" +#include "monitor/monitor.h" + +namespace vsag::eval { + +class SearchEvalCase : public EvalCase { +public: + using JsonType = Monitor::JsonType; + +public: + SearchEvalCase(const std::string& dataset_path, + const std::string& index_path, + vsag::IndexPtr index, + EvalConfig config); + + ~SearchEvalCase() override = default; + + void + Run() override; + +private: + enum SearchType { + KNN, + RANGE, + KNN_FILTER, + RANGE_FILTER, + }; + + void + init_monitor(); + + void + init_latency_monitor(); + + void + init_recall_monitor(); + + void + init_memory_monitor(); + + void + deserialize(); + + void + do_knn_search(); + + void + do_range_search(); + + void + do_knn_filter_search(); + + void + do_range_filter_search(); + + JsonType + process_result(); + +private: + std::vector monitors_{}; + + SearchType search_type_{SearchType::KNN}; + + EvalConfig config_; +}; +} // namespace vsag::eval diff --git a/tools/eval_performance.cpp b/tools/eval_performance.cpp new file mode 100644 index 00000000..199bdd9b --- /dev/null +++ b/tools/eval_performance.cpp @@ -0,0 +1,101 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "argparse/argparse.hpp" +#include "eval/eval_case.h" +#include "eval/eval_config.h" + +void +CheckArgs(argparse::ArgumentParser& parser) { + auto mode = parser.get("--type"); + if (mode == "search") { + auto search_mode = parser.get("--search_params"); + if (search_mode.empty()) { + throw std::runtime_error(R"(When "--type" is "search", "--search_params" is required)"); + } + } +} + +void +ParseArgs(argparse::ArgumentParser& parser, int argc, char** argv) { + parser.add_argument("--datapath", "-d") + .required() + .help("The hdf5 file path for eval"); + parser.add_argument("--type", "-t") + .required() + .choices("build", "search") + .help(R"(The eval method to select, choose from {"build", "search"})"); + parser.add_argument("--index_name", "-n") + .required() + .help("The name of index for create index"); + parser.add_argument("--create_params", "-c") + .required() + .help("The param for create index"); + parser.add_argument("--index_path", "-i") + .default_value("/tmp/performance/index") + .help("The index path for load or save"); + parser.add_argument("--search_params", "-s") + .default_value("") + .help("The param for search"); + parser.add_argument("--search_mode") + .default_value("knn") + .choices("knn", "range", "knn_filter", "range_filter") + .help( + "The mode supported while use 'search' type," + " choose from {\"knn\", \"range\", \"knn_filter\", \"range_filter\"}"); + parser.add_argument("--topk") + .default_value(10) + .help("The topk value for knn search or knn_filter search") + .scan<'i', int>(); + parser.add_argument("--range") + .default_value(0.5f) + .help("The range value for range search or range_filter search") + .scan<'f', float>(); + parser.add_argument("--disable_recall").default_value(false).help("Disable average recall eval"); + parser.add_argument("--disable_percent_recall") + .default_value(false) + .help("Disable percent recall eval, include p0, p10, p30, p50, p70, p90"); +
parser.add_argument("--disable_qps").default_value(false).help("Disable qps eval"); + parser.add_argument("--disable_tps").default_value(false).help("Disable tps eval"); + parser.add_argument("--disable_memory").default_value(false).help("Disable memory eval"); + parser.add_argument("--disable_latency") + .default_value(false) + .help("Disable average latency eval"); + parser.add_argument("--disable_percent_latency") + .default_value(false) + .help("Disable percent latency eval, include p50, p80, p90, p95, p99"); + + try { + parser.parse_args(argc, argv); + CheckArgs(parser); + } catch (const std::runtime_error& err) { + std::cerr << err.what() << std::endl << parser; + throw; // do not let main() continue with unparsed or invalid arguments + } +} + +int +main(int argc, char** argv) { + argparse::ArgumentParser program("eval_performance"); + ParseArgs(program, argc, argv); + auto config = vsag::eval::EvalConfig::Load(program); + auto eval_case = vsag::eval::EvalCase::MakeInstance(config); + if (eval_case != nullptr) { + eval_case->Run(); + } +} diff --git a/tools/test_performance.cpp b/tools/test_performance.cpp index 783cb310..d547408f 100644 --- a/tools/test_performance.cpp +++ b/tools/test_performance.cpp @@ -23,7 +23,7 @@ #include #include -#include "eval_dataset.h" +#include "eval/eval_dataset.h" #include "nlohmann/json.hpp" #include "spdlog/spdlog.h" #include "vsag/vsag.h" @@ -31,6 +31,7 @@ using namespace nlohmann; using namespace spdlog; using namespace vsag; +using namespace vsag::eval; json run_test(const std::string& index_name, From 0e7044d3d89df4430b560c2719dde4089b79987e Mon Sep 17 00:00:00 2001 From: inabao <37021995+inabao@users.noreply.github.com> Date: Fri, 17 Jan 2025 17:50:40 +0800 Subject: [PATCH 5/8] implement optimized nndescent in diskann (#287) Signed-off-by: jinjiabao.jjb --- include/vsag/constants.h | 6 + src/allocator_wrapper.h | 5 + src/constants.cpp | 7 + src/impl/odescent_graph_builder.cpp | 375 +++++++++++++++++++++++ src/impl/odescent_graph_builder.h | 155 ++++++++++
src/impl/odescent_graph_builder_test.cpp | 109 +++++++ src/index/diskann.cpp | 33 +- src/index/diskann.h | 3 + src/index/diskann_zparameters.cpp | 57 +++- src/index/diskann_zparameters.h | 6 + tests/test_brute_force.cpp | 2 +- tests/test_diskann.cpp | 44 +++ tests/test_hgraph.cpp | 2 +- tools/test_performance.cpp | 3 +- 14 files changed, 792 insertions(+), 15 deletions(-) create mode 100644 src/impl/odescent_graph_builder.cpp create mode 100644 src/impl/odescent_graph_builder.h create mode 100644 src/impl/odescent_graph_builder_test.cpp diff --git a/include/vsag/constants.h b/include/vsag/constants.h index 055e9e7c..adb54d44 100644 --- a/include/vsag/constants.h +++ b/include/vsag/constants.h @@ -61,6 +61,12 @@ extern const char* const DISKANN_PARAMETER_USE_REFERENCE; extern const char* const DISKANN_PARAMETER_USE_OPQ; extern const char* const DISKANN_PARAMETER_USE_ASYNC_IO; extern const char* const DISKANN_PARAMETER_USE_BSA; +extern const char* const DISKANN_PARAMETER_GRAPH_TYPE; +extern const char* const DISKANN_PARAMETER_ALPHA; +extern const char* const DISKANN_PARAMETER_GRAPH_ITER_TURN; +extern const char* const DISKANN_PARAMETER_NEIGHBOR_SAMPLE_RATE; +extern const char* const DISKANN_GRAPH_TYPE_VAMANA; +extern const char* const DISKANN_GRAPH_TYPE_ODESCENT; extern const char* const DISKANN_PARAMETER_BEAM_SEARCH; extern const char* const DISKANN_PARAMETER_IO_LIMIT; diff --git a/src/allocator_wrapper.h b/src/allocator_wrapper.h index 5479aae8..cfdb38df 100644 --- a/src/allocator_wrapper.h +++ b/src/allocator_wrapper.h @@ -42,6 +42,11 @@ class AllocatorWrapper { return allocator_ == other.allocator_; } + bool + operator!=(const AllocatorWrapper& other) const noexcept { + return allocator_ != other.allocator_; + } + inline pointer allocate(size_type n, const_void_pointer hint = 0) { return static_cast(allocator_->Allocate(n * sizeof(value_type))); diff --git a/src/constants.cpp b/src/constants.cpp index f0220f11..8e122fc4 100644 --- a/src/constants.cpp +++ 
b/src/constants.cpp @@ -69,6 +69,13 @@ const char* const DISKANN_PARAMETER_BEAM_SEARCH = "beam_search"; const char* const DISKANN_PARAMETER_IO_LIMIT = "io_limit"; const char* const DISKANN_PARAMETER_EF_SEARCH = "ef_search"; const char* const DISKANN_PARAMETER_REORDER = "use_reorder"; +const char* const DISKANN_PARAMETER_GRAPH_TYPE = "graph_type"; +const char* const DISKANN_PARAMETER_ALPHA = "alpha"; +const char* const DISKANN_PARAMETER_GRAPH_ITER_TURN = "graph_iter_turn"; +const char* const DISKANN_PARAMETER_NEIGHBOR_SAMPLE_RATE = "neighbor_sample_rate"; + +const char* const DISKANN_GRAPH_TYPE_VAMANA = "vamana"; +const char* const DISKANN_GRAPH_TYPE_ODESCENT = "odescent"; const char* const HNSW_PARAMETER_EF_RUNTIME = "ef_search"; const char* const HNSW_PARAMETER_M = "max_degree"; diff --git a/src/impl/odescent_graph_builder.cpp b/src/impl/odescent_graph_builder.cpp new file mode 100644 index 00000000..6ced73bb --- /dev/null +++ b/src/impl/odescent_graph_builder.cpp @@ -0,0 +1,375 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "odescent_graph_builder.h" + +#include + +namespace vsag { + +class LinearCongruentialGenerator { +public: + LinearCongruentialGenerator() { + auto now = std::chrono::steady_clock::now(); + auto timestamp = + std::chrono::duration_cast(now.time_since_epoch()).count(); + current_ = static_cast(timestamp); + } + + float + NextFloat() { + current_ = (A * current_ + C) % M; + return static_cast(current_) / static_cast(M); + } + +private: + unsigned int current_; + static const uint32_t A = 1664525; + static const uint32_t C = 1013904223; + static const uint32_t M = 4294967295; // 2^32 - 1 +}; + +bool +ODescent::Build() { + if (is_build_) { + return false; + } + is_build_ = true; + data_num_ = flatten_interface_->TotalCount(); + min_in_degree_ = std::min(min_in_degree_, data_num_ - 1); + Vector(data_num_, allocator_).swap(points_lock_); + Vector> old_neighbors(allocator_); + Vector> new_neighbors(allocator_); + old_neighbors.resize(data_num_, UnorderedSet(allocator_)); + new_neighbors.resize(data_num_, UnorderedSet(allocator_)); + for (int i = 0; i < data_num_; ++i) { + old_neighbors[i].reserve(max_degree_); + new_neighbors[i].reserve(max_degree_); + } + init_graph(); + { + for (int i = 0; i < turn_; ++i) { + sample_candidates(old_neighbors, new_neighbors, sample_rate_); + update_neighbors(old_neighbors, new_neighbors); + repair_no_in_edge(); + } + if (pruning_) { + prune_graph(); + add_reverse_edges(); + } + } + return true; +} + +void +ODescent::SaveGraph(std::stringstream& out) { + size_t file_offset = 0; // we will use this if we want + out.seekp(file_offset, out.beg); + size_t index_size = 24; + uint32_t max_degree = 0; + out.write((char*)&index_size, sizeof(uint64_t)); + out.write((char*)&max_degree, sizeof(uint32_t)); + uint32_t ep_u32 = 0; + size_t num_frozen = 0; + out.write((char*)&ep_u32, sizeof(uint32_t)); + out.write((char*)&num_frozen, sizeof(size_t)); + // Note: at this point, either _nd == _max_points or any frozen points have + // been 
temporarily moved to _nd, so _nd + _num_frozen_points is the valid + // location limit. + auto final_graph = GetGraph(); + for (uint32_t i = 0; i < data_num_; i++) { + uint32_t gk = (uint32_t)final_graph[i].size(); + out.write((char*)&gk, sizeof(uint32_t)); + out.write((char*)final_graph[i].data(), gk * sizeof(uint32_t)); + max_degree = + final_graph[i].size() > max_degree ? (uint32_t)final_graph[i].size() : max_degree; + index_size += (size_t)(sizeof(uint32_t) * (gk + 1)); + } + out.seekp(file_offset, out.beg); + out.write((char*)&index_size, sizeof(uint64_t)); + out.write((char*)&max_degree, sizeof(uint32_t)); +} + +Vector> +ODescent::GetGraph() { + Vector> extract_graph(allocator_); + extract_graph.resize(data_num_, Vector(allocator_)); + for (int i = 0; i < data_num_; ++i) { + extract_graph[i].resize(graph[i].neighbors.size()); + for (int j = 0; j < graph[i].neighbors.size(); ++j) { + extract_graph[i][j] = graph[i].neighbors[j].id; + } + } + + return extract_graph; +} + +void +ODescent::init_graph() { + graph.resize(data_num_, Linklist(allocator_)); + + auto task = [&, this](int64_t start, int64_t end) { + std::random_device rd; + std::uniform_int_distribution k_generate(0, data_num_ - 1); + std::mt19937 rng(rd()); + for (int i = start; i < end; ++i) { + UnorderedSet ids_set(allocator_); + ids_set.insert(i); + graph[i].neighbors.reserve(max_degree_); + int max_neighbors = std::min(data_num_ - 1, max_degree_); + for (int j = 0; j < max_neighbors; ++j) { + uint32_t id = i; + if (data_num_ - 1 < max_degree_) { + id = (i + j + 1) % data_num_; + } else { + while (ids_set.find(id) != ids_set.end()) { + id = k_generate(rng); + } + } + ids_set.insert(id); + auto dist = get_distance(i, id); + graph[i].neighbors.emplace_back(id, dist); + graph[i].greast_neighbor_distance = + std::max(graph[i].greast_neighbor_distance, dist); + } + } + }; + parallelize_task(task); +} + +void +ODescent::update_neighbors(Vector>& old_neighbors, + Vector>& new_neighbors) { + Vector> 
futures(allocator_); + auto task = [&, this](int64_t start, int64_t end) { + for (int i = start; i < end; ++i) { + Vector new_neighbors_candidates(allocator_); + for (uint32_t node_id : new_neighbors[i]) { + for (int k = 0; k < new_neighbors_candidates.size(); ++k) { + auto neighbor_id = new_neighbors_candidates[k]; + float dist = get_distance(node_id, neighbor_id); + if (dist < graph[node_id].greast_neighbor_distance) { + std::lock_guard lock(points_lock_[node_id]); + graph[node_id].neighbors.emplace_back(neighbor_id, dist); + } + if (dist < graph[neighbor_id].greast_neighbor_distance) { + std::lock_guard lock(points_lock_[neighbor_id]); + graph[neighbor_id].neighbors.emplace_back(node_id, dist); + } + } + new_neighbors_candidates.push_back(node_id); + + for (uint32_t neighbor_id : old_neighbors[i]) { + if (node_id == neighbor_id) { + continue; + } + float dist = get_distance(neighbor_id, node_id); + if (dist < graph[node_id].greast_neighbor_distance) { + std::lock_guard lock(points_lock_[node_id]); + graph[node_id].neighbors.emplace_back(neighbor_id, dist); + } + if (dist < graph[neighbor_id].greast_neighbor_distance) { + std::lock_guard lock(points_lock_[neighbor_id]); + graph[neighbor_id].neighbors.emplace_back(node_id, dist); + } + } + } + old_neighbors[i].clear(); + new_neighbors[i].clear(); + } + }; + parallelize_task(task); + + auto resize_task = [&, this](int64_t start, int64_t end) { + for (uint32_t i = start; i < end; ++i) { + auto& neighbors = graph[i].neighbors; + std::sort(neighbors.begin(), neighbors.end()); + neighbors.erase(std::unique(neighbors.begin(), neighbors.end()), neighbors.end()); + if (neighbors.size() > max_degree_) { + neighbors.resize(max_degree_); + } + graph[i].greast_neighbor_distance = neighbors.back().distance; + } + }; + parallelize_task(resize_task); +} + +void +ODescent::add_reverse_edges() { + Vector reverse_graph(allocator_); + reverse_graph.resize(data_num_, Linklist(allocator_)); + for (int i = 0; i < data_num_; ++i) { + 
reverse_graph[i].neighbors.reserve(max_degree_); + } + for (int i = 0; i < data_num_; ++i) { + for (const auto& node : graph[i].neighbors) { + reverse_graph[node.id].neighbors.emplace_back(i, node.distance); + } + } + + auto task = [&, this](int64_t start, int64_t end) { + for (int i = start; i < end; ++i) { + auto& neighbors = graph[i].neighbors; + neighbors.insert(neighbors.end(), + reverse_graph[i].neighbors.begin(), + reverse_graph[i].neighbors.end()); + std::sort(neighbors.begin(), neighbors.end()); + neighbors.erase(std::unique(neighbors.begin(), neighbors.end()), neighbors.end()); + if (neighbors.size() > max_degree_) { + neighbors.resize(max_degree_); + } + } + }; + parallelize_task(task); +} + +void +ODescent::sample_candidates(Vector>& old_neighbors, + Vector>& new_neighbors, + float sample_rate) { + auto task = [&, this](int64_t start, int64_t end) { + LinearCongruentialGenerator r; + for (int i = start; i < end; ++i) { + auto& neighbors = graph[i].neighbors; + for (int j = 0; j < neighbors.size(); ++j) { + float current_state = r.NextFloat(); + if (current_state < sample_rate) { + if (neighbors[j].old) { + { + std::lock_guard lock(points_lock_[i]); + old_neighbors[i].insert(neighbors[j].id); + } + { + std::lock_guard inner_lock(points_lock_[neighbors[j].id]); + old_neighbors[neighbors[j].id].insert(i); + } + } else { + { + std::lock_guard lock(points_lock_[i]); + new_neighbors[i].insert(neighbors[j].id); + } + { + std::lock_guard inner_lock(points_lock_[neighbors[j].id]); + new_neighbors[neighbors[j].id].insert(i); + } + neighbors[j].old = true; + } + } + } + } + }; + parallelize_task(task); +} + +void +ODescent::repair_no_in_edge() { + Vector in_edges_count(data_num_, 0, allocator_); + for (int i = 0; i < data_num_; ++i) { + for (auto& neigbor : graph[i].neighbors) { + in_edges_count[neigbor.id]++; + } + } + + Vector replace_pos(data_num_, std::min(data_num_ - 1, max_degree_) - 1, allocator_); + for (int i = 0; i < data_num_; ++i) { + auto& link = 
graph[i].neighbors; + int need_replace_loc = 0; + while (in_edges_count[i] < min_in_degree_ && need_replace_loc < max_degree_) { + uint32_t need_replace_id = link[need_replace_loc].id; + bool has_connect = false; + for (auto& neigbor : graph[need_replace_id].neighbors) { + if (neigbor.id == i) { + has_connect = true; + break; + } + } + if (replace_pos[need_replace_id] > 0 && not has_connect) { + auto& replace_node = graph[need_replace_id].neighbors[replace_pos[need_replace_id]]; + auto replace_id = replace_node.id; + if (in_edges_count[replace_id] > min_in_degree_) { + in_edges_count[replace_id]--; + replace_node.id = i; + replace_node.distance = link[need_replace_loc].distance; + in_edges_count[i]++; + } + replace_pos[need_replace_id]--; + } + need_replace_loc++; + } + } +} + +void +ODescent::prune_graph() { + Vector in_edges_count(data_num_, 0, allocator_); + for (int i = 0; i < data_num_; ++i) { + for (int j = 0; j < graph[i].neighbors.size(); ++j) { + in_edges_count[graph[i].neighbors[j].id]++; + } + } + + auto task = [&, this](int64_t start, int64_t end) { + for (int loc = start; loc < end; ++loc) { + auto& neighbors = graph[loc].neighbors; + std::sort(neighbors.begin(), neighbors.end()); + neighbors.erase(std::unique(neighbors.begin(), neighbors.end()), neighbors.end()); + Vector candidates(allocator_); + candidates.reserve(max_degree_); + for (int i = 0; i < neighbors.size(); ++i) { + bool flag = true; + int cur_in_edge = 0; + { + std::lock_guard lock(points_lock_[neighbors[i].id]); + cur_in_edge = in_edges_count[neighbors[i].id]; + } + if (cur_in_edge > min_in_degree_) { + for (int j = 0; j < candidates.size(); ++j) { + if (get_distance(neighbors[i].id, candidates[j].id) * alpha_ < + neighbors[i].distance) { + flag = false; + { + std::lock_guard lock(points_lock_[neighbors[i].id]); + in_edges_count[neighbors[i].id]--; + } + break; + } + } + } + if (flag) { + candidates.push_back(neighbors[i]); + } + } + neighbors.swap(candidates); + if (neighbors.size() > 
max_degree_) { + neighbors.resize(max_degree_); + } + } + }; + parallelize_task(task); +} + +void +ODescent::parallelize_task(std::function task) { + Vector> futures(allocator_); + for (int64_t i = 0; i < data_num_; i += block_size_) { + int64_t end = std::min(i + block_size_, data_num_); // keep 64-bit, no narrowing + futures.push_back(thread_pool_->GeneralEnqueue(task, i, end)); + } + for (auto& future : futures) { + future.get(); + } +} + +} // namespace vsag diff --git a/src/impl/odescent_graph_builder.h b/src/impl/odescent_graph_builder.h new file mode 100644 index 00000000..a2d30502 --- /dev/null +++ b/src/impl/odescent_graph_builder.h @@ -0,0 +1,155 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "data_cell/flatten_datacell.h" +#include "logger.h" +#include "safe_allocator.h" +#include "simd/simd.h" +#include "utils.h" +#include "vsag/dataset.h" + +namespace vsag { + +struct Node { + bool old = false; + uint32_t id; + float distance; + + Node(uint32_t id, float distance) { + this->id = id; + this->distance = distance; + } + + Node(uint32_t id, float distance, bool old) { + this->id = id; + this->distance = distance; + this->old = old; + } + Node() { + } + + bool + operator<(const Node& other) const { + if (distance != other.distance) { + return distance < other.distance; + } + return old && not other.old; + } + + bool + operator==(const Node& other) const { + return id == other.id; + } +}; + +struct Linklist { + Vector neighbors; + float greast_neighbor_distance; + Linklist(Allocator*& allocator) + : neighbors(allocator), greast_neighbor_distance(std::numeric_limits::max()) { + } +}; + +class ODescent { +public: + ODescent(int64_t max_degree, + float alpha, + int64_t turn, + float sample_rate, + const FlattenInterfacePtr& flatten_interface, + Allocator* allocator, + SafeThreadPool* thread_pool, + bool pruning = true) + : max_degree_(max_degree), + alpha_(alpha), + turn_(turn), + sample_rate_(sample_rate), + flatten_interface_(flatten_interface), + pruning_(pruning), + allocator_(allocator), + graph(allocator), + points_lock_(allocator), + thread_pool_(thread_pool) { + } + + bool + Build(); + + void + SaveGraph(std::stringstream& out); + + Vector> + GetGraph(); + +private: + inline float + get_distance(uint32_t loc1, uint32_t loc2) { + return flatten_interface_->ComputePairVectors(loc1, loc2); + } + + void + init_graph(); + + void + update_neighbors(Vector>& old_neigbors, + Vector>& new_neigbors); + + void + add_reverse_edges(); + + void + sample_candidates(Vector>& old_neigbors, + Vector>& new_neigbors, + float sample_rate); + + void + repair_no_in_edge(); + + void + 
prune_graph(); + +private: + void + parallelize_task(std::function task); + + size_t dim_; + int64_t data_num_; + int64_t is_build_ = false; + + int64_t max_degree_; + float alpha_; + int64_t turn_; + Vector graph; + int64_t min_in_degree_ = 1; + int64_t block_size_{10000}; + Vector points_lock_; + SafeThreadPool* thread_pool_; + + bool pruning_{true}; + float sample_rate_{0.3}; + Allocator* allocator_; + + const FlattenInterfacePtr& flatten_interface_; +}; + +} // namespace vsag diff --git a/src/impl/odescent_graph_builder_test.cpp b/src/impl/odescent_graph_builder_test.cpp new file mode 100644 index 00000000..dbe7ce40 --- /dev/null +++ b/src/impl/odescent_graph_builder_test.cpp @@ -0,0 +1,109 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "odescent_graph_builder.h" + +#include +#include +#include + +#include "data_cell/flatten_interface.h" +#include "io/memory_io_parameter.h" +#include "quantization/fp32_quantizer_parameter.h" +#include "safe_allocator.h" + +size_t +calculate_overlap(const std::vector& vec1, vsag::Vector& vec2, int K) { + int size1 = std::min(K, static_cast(vec1.size())); + int size2 = std::min(K, static_cast(vec2.size())); + + std::vector top_k_vec1(vec1.begin(), vec1.begin() + size1); + std::vector top_k_vec2(vec2.begin(), vec2.begin() + size2); + + std::sort(top_k_vec1.rbegin(), top_k_vec1.rend()); + std::sort(top_k_vec2.rbegin(), top_k_vec2.rend()); + + std::set set1(top_k_vec1.begin(), top_k_vec1.end()); + std::set set2(top_k_vec2.begin(), top_k_vec2.end()); + + std::set intersection; + std::set_intersection(set1.begin(), + set1.end(), + set2.begin(), + set2.end(), + std::inserter(intersection, intersection.begin())); + return intersection.size(); +} + +TEST_CASE("build nndescent", "[ut][nndescent]") { + int64_t num_vectors = 2000; + size_t dim = 128; + int64_t max_degree = 32; + + auto vectors = new float[dim * num_vectors]; + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + for (int64_t i = 0; i < dim * num_vectors; ++i) { + vectors[i] = distrib_real(rng); + } + + std::vector>> ground_truths(num_vectors); + vsag::IndexCommonParam param; + param.dim_ = dim; + param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; + param.allocator_ = vsag::SafeAllocator::FactoryDefaultAllocator(); + param.thread_pool_ = vsag::SafeThreadPool::FactoryDefaultThreadPool(); + vsag::FlattenDataCellParamPtr flatten_param = + std::make_shared(); + flatten_param->quantizer_parameter_ = std::make_shared(); + flatten_param->io_parameter_ = std::make_shared(); + vsag::FlattenInterfacePtr flatten_interface_ptr = + vsag::FlattenInterface::MakeInstance(flatten_param, param); + 
flatten_interface_ptr->Train(vectors, num_vectors); + flatten_interface_ptr->BatchInsertVector(vectors, num_vectors); + + vsag::DatasetPtr dataset = vsag::Dataset::Make(); + dataset->NumElements(num_vectors)->Float32Vectors(vectors)->Dim(dim)->Owner(true); + vsag::ODescent graph(max_degree, + 1, + 30, + 0.3, + flatten_interface_ptr, + param.allocator_.get(), + param.thread_pool_.get(), + false); + graph.Build(); + + auto extract_graph = graph.GetGraph(); + + float hit_edge_count = 0; + for (int i = 0; i < num_vectors; ++i) { + for (int j = 0; j < num_vectors; ++j) { + if (i != j) { + ground_truths[i].emplace_back(flatten_interface_ptr->ComputePairVectors(i, j), j); + } + } + std::sort(ground_truths[i].begin(), ground_truths[i].end()); + std::vector truths_edges(max_degree); + for (int j = 0; j < max_degree; ++j) { + truths_edges[j] = ground_truths[i][j].second; + } + hit_edge_count += calculate_overlap(truths_edges, extract_graph[i], max_degree); + } + REQUIRE(hit_edge_count / (num_vectors * max_degree) > 0.95); +} diff --git a/src/index/diskann.cpp b/src/index/diskann.cpp index a2501ec1..b73a0862 100644 --- a/src/index/diskann.cpp +++ b/src/index/diskann.cpp @@ -27,6 +27,10 @@ #include #include +#include "data_cell/flatten_datacell.h" +#include "impl/odescent_graph_builder.h" +#include "io/memory_io_parameter.h" +#include "quantization/fp32_quantizer_parameter.h" #include "vsag/constants.h" #include "vsag/errors.h" #include "vsag/expected.hpp" @@ -145,7 +149,8 @@ DiskANN::DiskANN(DiskannParameters& diskann_params, const IndexCommonParam& inde use_opq_(diskann_params.use_opq), use_bsa_(diskann_params.use_bsa), use_async_io_(diskann_params.use_async_io), - index_common_param_(index_common_param) { + diskann_params_(diskann_params), + common_param_(index_common_param) { if (not use_async_io_) { pool_ = index_common_param_.thread_pool_; } @@ -214,7 +219,31 @@ DiskANN::build(const DatasetPtr& base) { fmt::format("base.num_elements({}) must be greater than 1", 
data_num)); std::vector failed_locs; - { + if (diskann_params_.graph_type == DISKANN_GRAPH_TYPE_ODESCENT) { + SlowTaskTimer t("odescent build full (graph)"); + FlattenDataCellParamPtr flatten_param = + std::make_shared(); + flatten_param->quantizer_parameter_ = std::make_shared(); + flatten_param->io_parameter_ = std::make_shared(); + vsag::FlattenInterfacePtr flatten_interface_ptr = + vsag::FlattenInterface::MakeInstance(flatten_param, this->common_param_); + flatten_interface_ptr->Train(vectors, data_num); + flatten_interface_ptr->BatchInsertVector(vectors, data_num); + vsag::ODescent graph(2 * R_, + diskann_params_.alpha, + diskann_params_.turn, + diskann_params_.sample_rate, + flatten_interface_ptr, + common_param_.allocator_.get(), + common_param_.thread_pool_.get()); + graph.Build(); + graph.SaveGraph(graph_stream_); + int data_num_int32 = data_num; + int data_dim_int32 = data_dim; + tag_stream_.write((char*)&data_num_int32, sizeof(data_num_int32)); + tag_stream_.write((char*)&data_dim_int32, sizeof(data_dim_int32)); + tag_stream_.write((char*)ids, data_num * sizeof(ids)); + } else if (diskann_params_.graph_type == DISKANN_GRAPH_TYPE_VAMANA) { SlowTaskTimer t("diskann build full (graph)"); // build graph build_index_ = std::make_shared>( diff --git a/src/index/diskann.h b/src/index/diskann.h index 08cac3e6..290cf192 100644 --- a/src/index/diskann.h +++ b/src/index/diskann.h @@ -243,6 +243,9 @@ class DiskANN : public Index { mutable std::shared_mutex rw_mutex_; + IndexCommonParam common_param_; + DiskannParameters diskann_params_; + private: // Request Statistics mutable std::mutex stats_mutex_; std::shared_ptr pool_; diff --git a/src/index/diskann_zparameters.cpp b/src/index/diskann_zparameters.cpp index 0f89bebc..7ed9e48b 100644 --- a/src/index/diskann_zparameters.cpp +++ b/src/index/diskann_zparameters.cpp @@ -58,16 +58,6 @@ DiskannParameters::FromJson(JsonType& diskann_param_obj, IndexCommonParam index_ CHECK_ARGUMENT((5 <= obj.max_degree) and 
(obj.max_degree <= 128), fmt::format("max_degree({}) must in range[5, 128]", obj.max_degree)); - // set obj.ef_construction - CHECK_ARGUMENT( - diskann_param_obj.contains(DISKANN_PARAMETER_L), - fmt::format("parameters[{}] must contains {}", INDEX_DISKANN, DISKANN_PARAMETER_L)); - obj.ef_construction = diskann_param_obj[DISKANN_PARAMETER_L]; - CHECK_ARGUMENT((obj.max_degree <= obj.ef_construction) and (obj.ef_construction <= 1000), - fmt::format("ef_construction({}) must in range[$max_degree({}), 64]", - obj.ef_construction, - obj.max_degree)); - // set obj.pq_dims CHECK_ARGUMENT( diskann_param_obj.contains(DISKANN_PARAMETER_DISK_PQ_DIMS), @@ -105,6 +95,53 @@ DiskannParameters::FromJson(JsonType& diskann_param_obj, IndexCommonParam index_ obj.use_async_io = diskann_param_obj[DISKANN_PARAMETER_USE_ASYNC_IO]; } + // set obj.graph_type + if (diskann_param_obj.contains(DISKANN_PARAMETER_GRAPH_TYPE)) { + obj.graph_type = diskann_param_obj[DISKANN_PARAMETER_GRAPH_TYPE]; + } + + if (obj.graph_type == DISKANN_GRAPH_TYPE_VAMANA) { + // set obj.ef_construction + CHECK_ARGUMENT( + diskann_param_obj.contains(DISKANN_PARAMETER_L), + fmt::format("parameters[{}] must contains {}", INDEX_DISKANN, DISKANN_PARAMETER_L)); + obj.ef_construction = diskann_param_obj[DISKANN_PARAMETER_L]; + CHECK_ARGUMENT((obj.max_degree <= obj.ef_construction) and (obj.ef_construction <= 1000), + fmt::format("ef_construction({}) must in range[$max_degree({}), 64]", + obj.ef_construction, + obj.max_degree)); + } else if (obj.graph_type == DISKANN_GRAPH_TYPE_ODESCENT) { + // set obj.alpha + if (diskann_param_obj.contains(DISKANN_PARAMETER_ALPHA)) { + obj.alpha = diskann_param_obj[DISKANN_PARAMETER_ALPHA]; + CHECK_ARGUMENT( + (obj.alpha >= 1.0 && obj.alpha <= 2.0), + fmt::format( + "{} must in range[1.0, 2.0], now is {}", DISKANN_PARAMETER_ALPHA, obj.alpha)); + } + // set obj.turn + if (diskann_param_obj.contains(DISKANN_PARAMETER_GRAPH_ITER_TURN)) { + obj.turn = 
diskann_param_obj[DISKANN_PARAMETER_GRAPH_ITER_TURN]; + CHECK_ARGUMENT((obj.turn > 0), + fmt::format("{} must be greater than 0, now is {}", + DISKANN_PARAMETER_GRAPH_ITER_TURN, + obj.turn)); + } + // set obj.sample_rate + if (diskann_param_obj.contains(DISKANN_PARAMETER_NEIGHBOR_SAMPLE_RATE)) { + obj.sample_rate = diskann_param_obj[DISKANN_PARAMETER_NEIGHBOR_SAMPLE_RATE]; + CHECK_ARGUMENT((obj.sample_rate > 0.05 && obj.sample_rate < 0.5), + fmt::format("{} must in range[0.05, 0.5], now is {}", + DISKANN_PARAMETER_NEIGHBOR_SAMPLE_RATE, + obj.sample_rate)); + } + } else { + throw std::invalid_argument(fmt::format("parameters[{}] must in [{}, {}], now is {}", + DISKANN_PARAMETER_GRAPH_TYPE, + DISKANN_GRAPH_TYPE_VAMANA, + DISKANN_GRAPH_TYPE_ODESCENT, + obj.graph_type)); + } return obj; } diff --git a/src/index/diskann_zparameters.h b/src/index/diskann_zparameters.h index 2e5f76a4..4aa3b556 100644 --- a/src/index/diskann_zparameters.h +++ b/src/index/diskann_zparameters.h @@ -44,6 +44,12 @@ struct DiskannParameters { bool use_bsa = false; bool use_async_io = false; + // use new construction method + std::string graph_type = "vamana"; + float alpha = 1.2; + int64_t turn = 40; + float sample_rate = 0.3; + private: DiskannParameters() = default; }; diff --git a/tests/test_brute_force.cpp b/tests/test_brute_force.cpp index 6cc5c3c0..cb2aaccd 100644 --- a/tests/test_brute_force.cpp +++ b/tests/test_brute_force.cpp @@ -38,7 +38,7 @@ class BruteForceTestIndex : public fixtures::TestIndex { constexpr static uint64_t base_count = 3000; const std::vector> test_cases = { - {"sq8", 0.96}, {"fp32", 0.999999}, {"sq8_uniform", 0.95}}; + {"sq8", 0.95}, {"fp32", 0.999999}, {"sq8_uniform", 0.94}}; }; TestDatasetPool BruteForceTestIndex::pool{}; diff --git a/tests/test_diskann.cpp b/tests/test_diskann.cpp index e7854d6d..42d36fb3 100644 --- a/tests/test_diskann.cpp +++ b/tests/test_diskann.cpp @@ -64,6 +64,50 @@ TEST_CASE_METHOD(fixtures::DiskANNTestIndex, "diskann build test", 
"[ft][index][ } } +TEST_CASE_METHOD(fixtures::DiskANNTestIndex, "diskann build and search", "[ft][index][diskann]") { + vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG); + auto dims = fixtures::get_common_used_dims(1); + auto metric_type = GENERATE("l2", "ip"); + auto graph_type = GENERATE("vamana", "odescent"); + const std::string name = "diskann"; + + constexpr auto build_parameter_json = R"( + {{ + "dtype": "float32", + "metric_type": "{}", + "dim": {}, + "diskann": {{ + "max_degree": 16, + "ef_construction": 200, + "graph_type": "{}", + "graph_iter_turn": 30, + "neighbor_sample_rate": 0.3, + "alpha": 1.2, + "pq_dims": 32, + "pq_sample_rate": 1 + }} + }} + )"; + constexpr auto search_param = R"( + { + "diskann": { + "ef_search": 200, + "io_limit": 200, + "beam_search": 4, + "use_reorder": true + } + } + )"; + auto count = 1000; + for (auto dim : dims) { + auto param = fmt::format(build_parameter_json, metric_type, dim, graph_type); + auto index = TestFactory(name, param, true); + auto dataset = pool.GetDatasetAndCreate(dim, count, metric_type); + TestBuildIndex(index, dataset, true); + TestKnnSearch(index, dataset, search_param, 0.99, true); + } +} + TEST_CASE("DiskAnn Float Recall", "[ft][diskann]") { int dim = 128; // Dimension of the elements int max_elements = 1000; // Maximum number of elements, should be known beforehand diff --git a/tests/test_hgraph.cpp b/tests/test_hgraph.cpp index c9b69823..c037b550 100644 --- a/tests/test_hgraph.cpp +++ b/tests/test_hgraph.cpp @@ -45,7 +45,7 @@ class HgraphTestIndex : public fixtures::TestIndex { }})"; const std::vector> test_cases = { - {"sq8_uniform,fp32", 0.98}, {"sq8", 0.96}, {"fp32", 0.99}, {"sq8_uniform", 0.95}}; + {"sq8_uniform,fp32", 0.98}, {"sq8", 0.95}, {"fp32", 0.99}, {"sq8_uniform", 0.95}}; }; TestDatasetPool HgraphTestIndex::pool{}; diff --git a/tools/test_performance.cpp b/tools/test_performance.cpp index d547408f..d72fdc5a 100644 --- a/tools/test_performance.cpp +++ 
b/tools/test_performance.cpp @@ -99,7 +99,8 @@ class PerfTools { const std::string& build_parameters) { spdlog::debug("index_name: " + index_name); spdlog::debug("build_parameters: " + build_parameters); - auto index = Factory::CreateIndex(index_name, build_parameters).value(); + Engine e; + auto index = e.CreateIndex(index_name, build_parameters).value(); spdlog::debug("dataset_path: " + dataset_path); auto eval_dataset = EvalDataset::Load(dataset_path); From 34ecc8a7b414ca704eca6ac502c708686e381578 Mon Sep 17 00:00:00 2001 From: ShawnShawnYou <58975154+ShawnShawnYou@users.noreply.github.com> Date: Fri, 17 Jan 2025 18:29:27 +0800 Subject: [PATCH 6/8] support update in conjugate graph (#317) Signed-off-by: zhongxiaoyao.zxy --- src/impl/conjugate_graph.cpp | 34 +++++++++++++ src/impl/conjugate_graph.h | 3 ++ src/impl/conjugate_graph_test.cpp | 30 +++++++++++- src/index/hnsw.cpp | 8 ++++ tests/test_index.cpp | 18 +++++-- tests/test_index_old.cpp | 19 ++++++-- tests/test_multi_thread.cpp | 80 ++++++++++++++++++------------- 7 files changed, 148 insertions(+), 44 deletions(-) diff --git a/src/impl/conjugate_graph.cpp b/src/impl/conjugate_graph.cpp index 51260f85..01f9d1f1 100644 --- a/src/impl/conjugate_graph.cpp +++ b/src/impl/conjugate_graph.cpp @@ -217,4 +217,38 @@ ConjugateGraph::is_empty() const { return (this->memory_usage_ == sizeof(this->memory_usage_) + FOOTER_SIZE); } +tl::expected +ConjugateGraph::UpdateId(int64_t old_tag_id, int64_t new_tag_id) { + if (old_tag_id == new_tag_id) { + return true; + } + + // 1. update key + bool updated = false; + auto it_old_key = conjugate_graph_.find(old_tag_id); + if (it_old_key != conjugate_graph_.end()) { + auto it_new_key = conjugate_graph_.find(new_tag_id); + if (it_new_key != conjugate_graph_.end()) { + // both two id exists in graph, note that this situation should be filtered out before use this function. 
+ return false; + } else { + conjugate_graph_[new_tag_id] = std::move(it_old_key->second); + } + conjugate_graph_.erase(it_old_key); + updated = true; + } + + // 2. update neighbors + for (auto& [key, neighbors] : conjugate_graph_) { + auto it_old_neighbor = neighbors.find(old_tag_id); + if (it_old_neighbor != neighbors.end()) { + neighbors.erase(it_old_neighbor); + neighbors.insert(new_tag_id); + updated = true; + } + } + + return updated; +} + } // namespace vsag diff --git a/src/impl/conjugate_graph.h b/src/impl/conjugate_graph.h index a068e810..03e0e80a 100644 --- a/src/impl/conjugate_graph.h +++ b/src/impl/conjugate_graph.h @@ -41,6 +41,9 @@ class ConjugateGraph { EnhanceResult(std::priority_queue>& results, const std::function& distance_of_tag) const; + tl::expected + UpdateId(int64_t old_id, int64_t new_id); + public: tl::expected Serialize() const; diff --git a/src/impl/conjugate_graph_test.cpp b/src/impl/conjugate_graph_test.cpp index 95a666ae..69aba141 100644 --- a/src/impl/conjugate_graph_test.cpp +++ b/src/impl/conjugate_graph_test.cpp @@ -279,4 +279,32 @@ TEST_CASE("serialize and deserialize with stream", "[ut][conjugate_graph]") { REQUIRE(conjugate_graph->GetMemoryUsage() == 4 + vsag::FOOTER_SIZE); re_in_stream.close(); } -} \ No newline at end of file +} + +TEST_CASE("update id", "[ut][conjugate_graph]") { + std::shared_ptr conjugate_graph = + std::make_shared(); + + REQUIRE(conjugate_graph->AddNeighbor(0, 1) == true); + REQUIRE(conjugate_graph->AddNeighbor(0, 2) == true); + REQUIRE(conjugate_graph->AddNeighbor(1, 0) == true); + REQUIRE(conjugate_graph->AddNeighbor(4, 0) == true); + + // update key + REQUIRE(conjugate_graph->UpdateId(1, 1) == true); // succ case: 1 -> 1 + REQUIRE(conjugate_graph->UpdateId(5, 4) == false); // old id don't exist + REQUIRE(conjugate_graph->UpdateId(0, 4) == false); // old id and new id both exists + REQUIRE(conjugate_graph->UpdateId(4, 5) == true); // succ case: 4 -> 5 + REQUIRE(conjugate_graph->AddNeighbor(5, 0) == 
false); // valid of succ case + + // update value + REQUIRE(conjugate_graph->UpdateId(2, 3) == true); // succ case: 2 -> 3 + REQUIRE(conjugate_graph->AddNeighbor(0, 3) == false); // neighbor exists + + // update both key and value + REQUIRE(conjugate_graph->UpdateId(0, -1) == true); // succ case: 0 -> -1 + REQUIRE(conjugate_graph->AddNeighbor(-1, 1) == false); + REQUIRE(conjugate_graph->AddNeighbor(-1, 3) == false); + REQUIRE(conjugate_graph->AddNeighbor(1, -1) == false); + REQUIRE(conjugate_graph->AddNeighbor(5, -1) == false); +} diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index 98824afa..35d789d2 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -617,9 +617,17 @@ HNSW::update_id(int64_t old_id, int64_t new_id) { } try { + if (old_id == new_id) { + return true; + } + // note that the validation of old_id is handled within updateLabel. std::reinterpret_pointer_cast(alg_hnsw_)->updateLabel(old_id, new_id); + if (use_conjugate_graph_) { + std::unique_lock lock(rw_mutex_); + conjugate_graph_->UpdateId(old_id, new_id); + } } catch (const std::runtime_error& e) { #ifndef ENABLE_TESTS logger::warn( diff --git a/tests/test_index.cpp b/tests/test_index.cpp index ad79a474..ec78e7d5 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -116,16 +116,26 @@ TestIndex::TestUpdateId(const IndexPtr& index, REQUIRE(failed_old_res.has_value()); REQUIRE(not failed_old_res.value()); - // new id is used - auto failed_new_res = index->UpdateId(update_id_map[ids[i]], update_id_map[ids[i]]); - REQUIRE(failed_new_res.has_value()); - REQUIRE(not failed_new_res.value()); + // same id + auto succ_same_res = index->UpdateId(update_id_map[ids[i]], update_id_map[ids[i]]); + REQUIRE(succ_same_res.has_value()); + REQUIRE(succ_same_res.value()); } else { if (result.value()->GetIds()[0] == update_id_map[ids[i]]) { correct_num[round] += 1; } } } + + for (int i = 0; i < num_vectors; i++) { + if (round == 0) { + // new id is used + auto failed_new_res = + 
index->UpdateId(update_id_map[ids[i]], update_id_map[ids[num_vectors - i - 1]]); + REQUIRE(failed_new_res.has_value()); + REQUIRE(not failed_new_res.value()); + } + } } REQUIRE(correct_num[0] == correct_num[1]); diff --git a/tests/test_index_old.cpp b/tests/test_index_old.cpp index 7165fa5e..5bfc0ee3 100644 --- a/tests/test_index_old.cpp +++ b/tests/test_index_old.cpp @@ -1051,7 +1051,7 @@ TEST_CASE("build index with generated_build_parameters", "[ft][index]") { REQUIRE(recall > 0.95); } -TEST_CASE("int8 + freshhnsw + feedback", "[ft][index][hnsw]") { +TEST_CASE("int8 + freshhnsw + feedback + update", "[ft][index][hnsw]") { auto logger = vsag::Options::Instance().logger(); logger->SetLevel(vsag::Logger::Level::kDEBUG); @@ -1068,7 +1068,7 @@ TEST_CASE("int8 + freshhnsw + feedback", "[ft][index][hnsw]") { "dim": {}, "hnsw": {{ "max_degree": 16, - "ef_construction": 200, + "ef_construction": 100, "use_conjugate_graph": true }} }} @@ -1082,8 +1082,10 @@ TEST_CASE("int8 + freshhnsw + feedback", "[ft][index][hnsw]") { // generate dataset std::vector base_ids(num_base); + std::vector update_ids(num_base); for (int64_t i = 0; i < num_base; ++i) { base_ids[i] = i; + update_ids[i] = i + 2 * num_base; } auto base_vectors = fixtures::generate_int8_codes(num_base, dim); auto base = vsag::Dataset::Make(); @@ -1121,7 +1123,7 @@ TEST_CASE("int8 + freshhnsw + feedback", "[ft][index][hnsw]") { constexpr auto search_parameters_json = R"( {{ "hnsw": {{ - "ef_search": 50, + "ef_search": 10, "use_conjugate_graph_search": {} }} }} @@ -1146,10 +1148,16 @@ TEST_CASE("int8 + freshhnsw + feedback", "[ft][index][hnsw]") { REQUIRE(*index->Feedback(query, k, search_parameters) == 0); } - if (local_optimum == global_optimum) { + if (local_optimum == global_optimum or local_optimum == update_ids[global_optimum]) { correct++; } } + + if (round == 0) { + for (int i = 0; i < num_base; i++) { + REQUIRE(*index->UpdateId(base_ids[i], update_ids[i]) == true); + } + } recall[round] = correct / (1.0 * 
num_query); logger->Debug(fmt::format(R"(Recall: {:.4f})", recall[round])); } @@ -1157,6 +1165,7 @@ TEST_CASE("int8 + freshhnsw + feedback", "[ft][index][hnsw]") { logger->Debug("====summary===="); logger->Debug(fmt::format(R"(Error fix: {})", error_fix)); + REQUIRE(recall[0] < recall[1]); REQUIRE(fixtures::time_t(recall[1]) == fixtures::time_t(1.0f)); } @@ -1230,7 +1239,7 @@ TEST_CASE("hnsw + feedback with global optimum id", "[ft][index][hnsw]") { constexpr auto search_parameters_json = R"( {{ "hnsw": {{ - "ef_search": 100, + "ef_search": 10, "use_conjugate_graph_search": {} }} }} diff --git a/tests/test_multi_thread.cpp b/tests/test_multi_thread.cpp index b9351502..c12c254f 100644 --- a/tests/test_multi_thread.cpp +++ b/tests/test_multi_thread.cpp @@ -245,17 +245,10 @@ TEST_CASE("multi-threading read-write test", "[ft][hnsw]") { })); // update id - update_id_results.push_back( - pool.enqueue([&ids, &data, &index, dim, i, max_elements]() -> bool { - auto dataset = vsag::Dataset::Make(); - dataset->Dim(dim) - ->NumElements(1) - ->Ids(ids.get() + i) - ->Float32Vectors(data.get() + i * dim) - ->Owner(false); - auto res = index->UpdateId(ids[i], ids[i] + 2 * max_elements); - return res.has_value(); - })); + update_id_results.push_back(pool.enqueue([&ids, &index, i, max_elements]() -> bool { + auto res = index->UpdateId(ids[i], ids[i] + 2 * max_elements); + return res.has_value(); + })); // update vector update_vec_results.push_back(pool.enqueue([&ids, &data, &index, dim, i]() -> bool { @@ -270,13 +263,12 @@ TEST_CASE("multi-threading read-write test", "[ft][hnsw]") { })); // search - search_results.push_back( - pool.enqueue([&index, &ids, dim, &data, i, &str_parameters]() -> bool { - auto query = vsag::Dataset::Make(); - query->NumElements(1)->Dim(dim)->Float32Vectors(data.get() + i * dim)->Owner(false); - auto result = index->KnnSearch(query, 2, str_parameters); - return result.has_value(); - })); + search_results.push_back(pool.enqueue([&index, dim, &data, i, 
&str_parameters]() -> bool { + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(data.get() + i * dim)->Owner(false); + auto result = index->KnnSearch(query, 2, str_parameters); + return result.has_value(); + })); } for (int i = 0; i < max_elements; ++i) { @@ -293,12 +285,12 @@ TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hn // avoid too much slow task logs vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kWARN); - int thread_num = 32; - int dim = 256; - int max_elements = 10000; - int max_degree = 32; - int ef_construction = 200; - int ef_search = 100; + int thread_num = 16; + int dim = 32; + int max_elements = 1000; + int max_degree = 16; + int ef_construction = 50; + int ef_search = 10; int k = 10; nlohmann::json hnsw_parameters{{"max_degree", max_degree}, {"ef_construction", ef_construction}, @@ -324,8 +316,9 @@ TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hn std::string str_parameters = parameters.dump(); std::vector> insert_results; - std::vector> feedback_results; - std::vector> pretrain_results; + std::vector> feedback_results; + std::vector> pretrain_results; + std::vector> update_id_results; std::vector> search_results; for (int64_t i = 0; i < max_elements / 2; ++i) { @@ -356,20 +349,30 @@ TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hn return add_res.value().size(); })); + // update id + update_id_results.push_back(pool.enqueue([&ids, &index, i, max_elements]() -> bool { + auto res = index->UpdateId(ids[i], ids[i] + 2 * max_elements); + return res.has_value(); + })); + // feedback feedback_results.push_back( - pool.enqueue([&index, &data, i, dim, k, str_parameters]() -> uint64_t { + pool.enqueue([&index, &data, i, dim, k, str_parameters]() -> bool { auto query = vsag::Dataset::Make(); query->Dim(dim)->NumElements(1)->Int8Vectors(data.get() + i * dim)->Owner(false); auto feedback_res = 
index->Feedback(query, k, str_parameters); - return feedback_res.value(); + return feedback_res.has_value(); })); // pretrain - pretrain_results.push_back(pool.enqueue([&index, &ids, i, k, str_parameters]() -> uint32_t { - auto pretrain_res = index->Pretrain({ids[i]}, k, str_parameters); - return pretrain_res.value(); - })); + pretrain_results.push_back( + pool.enqueue([&index, &ids, i, k, str_parameters, max_elements]() -> bool { + auto pretrain_res = index->Pretrain({ids[i]}, k, str_parameters); + if (not pretrain_res.has_value()) { + pretrain_res = index->Pretrain({ids[i] + 2 * max_elements}, k, str_parameters); + } + return pretrain_res.has_value(); + })); // search search_results.push_back(pool.enqueue([&index, &data, i, dim, k, str_parameters]() -> bool { @@ -380,12 +383,21 @@ TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hn })); } + uint32_t succ_feedback = 0, succ_pretrain = 0; for (int64_t i = 0; i < max_elements; ++i) { REQUIRE(insert_results[i].get() == 0); if (i < max_elements / 2) { - REQUIRE(pretrain_results[i].get() >= 0); - REQUIRE(feedback_results[i].get() >= 0); + if (feedback_results[i].get()) { + succ_feedback++; + } + if (pretrain_results[i].get()) { + succ_pretrain++; + } + REQUIRE(update_id_results[i].get() == true); REQUIRE(search_results[i].get() >= 0); } } + + REQUIRE(succ_feedback > 0); + REQUIRE(succ_pretrain > 0); } From 0f368e7c3f62503d4ce0212a67af73041718f624 Mon Sep 17 00:00:00 2001 From: inabao <37021995+inabao@users.noreply.github.com> Date: Fri, 17 Jan 2025 19:26:24 +0800 Subject: [PATCH 7/8] add examples for hnsw, diskann, enhance graph and range search (#333) Signed-off-by: jinjiabao.jjb --- examples/cpp/todo_examples/101_index_hnsw.cpp | 91 +++++++++- .../cpp/todo_examples/102_index_diskann.cpp | 109 +++++++++++- .../todo_examples/104_index_fresh_hnsw.cpp | 94 ++++++++++- .../302_feature_range_search.cpp | 91 +++++++++- .../304_feature_enhance_graph.cpp | 155 +++++++++++++++++- 
examples/cpp/todo_examples/CMakeLists.txt | 15 ++ src/engine.cpp | 6 +- src/index/hnsw_zparameters.cpp | 21 ++- src/parameter_generator.cpp | 8 + tests/test_index_old.cpp | 4 +- 10 files changed, 577 insertions(+), 17 deletions(-) diff --git a/examples/cpp/todo_examples/101_index_hnsw.cpp b/examples/cpp/todo_examples/101_index_hnsw.cpp index 0d435f8c..344316c2 100644 --- a/examples/cpp/todo_examples/101_index_hnsw.cpp +++ b/examples/cpp/todo_examples/101_index_hnsw.cpp @@ -11,4 +11,93 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. + +#include + +#include + +int +main(int argc, char** argv) { + /******************* Prepare Base Dataset *****************/ + int64_t num_vectors = 1000; + int64_t dim = 128; + auto ids = new int64_t[num_vectors]; + auto vectors = new float[dim * num_vectors]; + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + for (int64_t i = 0; i < num_vectors; ++i) { + ids[i] = i; + } + for (int64_t i = 0; i < dim * num_vectors; ++i) { + vectors[i] = distrib_real(rng); + } + auto base = vsag::Dataset::Make(); + // Transfer the ownership of the data (ids, vectors) to the base. + base->NumElements(num_vectors)->Dim(dim)->Ids(ids)->Float32Vectors(vectors); + + /******************* Create HNSW Index *****************/ + // hnsw_build_parameters is the configuration for building an HNSW index. + // The "dtype" specifies the data type, which supports float32 and int8. + // The "metric_type" indicates the distance metric type (e.g., cosine, inner product, and L2). + // The "dim" represents the dimensionality of the vectors, indicating the number of features for each data point. 
+ // The "hnsw" section contains parameters specific to HNSW: + // - "max_degree": The maximum number of connections for each node in the graph. + // - "ef_construction": The size used for nearest neighbor search during graph construction, which affects both speed and the quality of the graph. + auto hnsw_build_paramesters = R"( + { + "dtype": "float32", + "metric_type": "l2", + "dim": 128, + "hnsw": { + "max_degree": 16, + "ef_construction": 100 + } + } + )"; + auto index = vsag::Factory::CreateIndex("hnsw", hnsw_build_paramesters).value(); + + /******************* Build HNSW Index *****************/ + if (auto build_result = index->Build(base); build_result.has_value()) { + std::cout << "After Build(), Index HNSW contains: " << index->GetNumElements() << std::endl; + } else { + std::cerr << "Failed to build index: " << build_result.error().message << std::endl; + exit(-1); + } + + /******************* KnnSearch For HNSW Index *****************/ + auto query_vector = new float[dim]; + for (int64_t i = 0; i < dim; ++i) { + query_vector[i] = distrib_real(rng); + } + + // hnsw_search_parameters is the configuration for searching in an HNSW index. + // The "hnsw" section contains parameters specific to the search operation: + // - "ef_search": The size of the dynamic list used for nearest neighbor search, which influences both recall and search speed. 
+ auto hnsw_search_parameters = R"( + { + "hnsw": { + "ef_search": 100 + } + } + )"; + int64_t topk = 10; + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(query_vector)->Owner(true); + auto knn_result = index->KnnSearch(query, topk, hnsw_search_parameters); + + /******************* Print Search Result *****************/ + if (knn_result.has_value()) { + auto result = knn_result.value(); + std::cout << "results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + } + } else { + std::cerr << "Search Error: " << knn_result.error().message << std::endl; + } + + return 0; +} diff --git a/examples/cpp/todo_examples/102_index_diskann.cpp b/examples/cpp/todo_examples/102_index_diskann.cpp index 0d435f8c..4d13966a 100644 --- a/examples/cpp/todo_examples/102_index_diskann.cpp +++ b/examples/cpp/todo_examples/102_index_diskann.cpp @@ -11,4 +11,111 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. + +#include + +#include + +int +main(int argc, char** argv) { + /******************* Prepare Base Dataset *****************/ + int64_t num_vectors = 1000; + int64_t dim = 128; + auto ids = new int64_t[num_vectors]; + auto vectors = new float[dim * num_vectors]; + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + for (int64_t i = 0; i < num_vectors; ++i) { + ids[i] = i; + } + for (int64_t i = 0; i < dim * num_vectors; ++i) { + vectors[i] = distrib_real(rng); + } + auto base = vsag::Dataset::Make(); + // Transfer the ownership of the data (ids, vectors) to the base. 
+ base->NumElements(num_vectors)->Dim(dim)->Ids(ids)->Float32Vectors(vectors); + + /******************* Create DiskANN Index *****************/ + // diskann_build_paramesters is the configuration for building a DiskANN index. + // The "dtype" specifies the data type, "metric_type" indicates the distance metric, + // and "dim" represents the dimensionality of the feature vectors. + // The "diskann" section contains parameters specific to DiskANN: + // - "max_degree": Maximum degree of the graph + // - "ef_construction": Construction phase efficiency factor + // - "pq_sample_rate": PQ sampling rate + // - "pq_dims": PQ dimensionality + // - "use_pq_search": Indicates whether to cache the graph in memory and use PQ vectors for retrieval (optional) + // - "use_async_io": Specifies whether to use asynchronous I/O (optional) + // - "use_bsa": Determines whether to use the BSA method for lossless filtering during the reordering phase (optional) + // Other parameters are mandatory. + auto diskann_build_paramesters = R"( + { + "dtype": "float32", + "metric_type": "l2", + "dim": 128, + "diskann": { + "max_degree": 16, + "ef_construction": 200, + "pq_sample_rate": 0.5, + "pq_dims": 9, + "use_pq_search": true, + "use_async_io": true, + "use_bsa": true + } + } + )"; + auto index = vsag::Factory::CreateIndex("diskann", diskann_build_paramesters).value(); + + /******************* Build DiskANN Index *****************/ + if (auto build_result = index->Build(base); build_result.has_value()) { + std::cout << "After Build(), Index DiskANN contains: " << index->GetNumElements() + << std::endl; + } else { + std::cerr << "Failed to build index: " << build_result.error().message << std::endl; + exit(-1); + } + + /******************* KnnSearch For DiskANN Index *****************/ + auto query_vector = new float[dim]; + for (int64_t i = 0; i < dim; ++i) { + query_vector[i] = distrib_real(rng); + } + + // diskann_search_parameters is the configuration for searching in a DiskANN index. 
+ // The "diskann" section contains parameters specific to the search operation: + // - "ef_search": The search efficiency factor, which influences accuracy and speed. + // - "beam_search": The number of beams to use during the search process, balancing exploration and exploitation. + // - "io_limit": The maximum number of I/O operations allowed during the search. + // - "use_reorder": Indicates whether to perform reordering of results for better accuracy (optional). + + auto diskann_search_parameters = R"( + { + "diskann": { + "ef_search": 100, + "beam_search": 4, + "io_limit": 50, + "use_reorder": true + } + } + )"; + int64_t topk = 10; + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(query_vector)->Owner(true); + auto knn_result = index->KnnSearch(query, topk, diskann_search_parameters); + + /******************* Print Search Result *****************/ + if (knn_result.has_value()) { + auto result = knn_result.value(); + std::cout << "results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + } + } else { + std::cerr << "Search Error: " << knn_result.error().message << std::endl; + } + + return 0; +} diff --git a/examples/cpp/todo_examples/104_index_fresh_hnsw.cpp b/examples/cpp/todo_examples/104_index_fresh_hnsw.cpp index 0d435f8c..45543903 100644 --- a/examples/cpp/todo_examples/104_index_fresh_hnsw.cpp +++ b/examples/cpp/todo_examples/104_index_fresh_hnsw.cpp @@ -11,4 +11,96 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. 
+ +#include + +#include + +int +main(int argc, char** argv) { + /******************* Prepare Base Dataset *****************/ + int64_t num_vectors = 1000; + int64_t dim = 128; + auto ids = new int64_t[num_vectors]; + auto vectors = new float[dim * num_vectors]; + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + for (int64_t i = 0; i < num_vectors; ++i) { + ids[i] = i; + } + for (int64_t i = 0; i < dim * num_vectors; ++i) { + vectors[i] = distrib_real(rng); + } + auto base = vsag::Dataset::Make(); + // Transfer the ownership of the data (ids, vectors) to the base. + base->NumElements(num_vectors)->Dim(dim)->Ids(ids)->Float32Vectors(vectors); + + /******************* Create FreshHnsw Index *****************/ + // fresh_hnsw_build_parameters is the configuration for building a FreshHNSW index. + // The "dtype" specifies the data type, which supports float32 and int8. + // The "metric_type" indicates the distance metric type (e.g., cosine, inner product, and L2). + // The "dim" represents the dimensionality of the vectors, indicating the number of features for each data point. + // The "fresh_hnsw" section contains parameters specific to FreshHNSW: + // - "max_degree": The maximum number of connections for each node in the graph. + // - "ef_construction": The size used for nearest neighbor search during graph construction, which affects both speed and the quality of the graph. + auto fresh_hnsw_build_paramesters = R"( + { + "dtype": "float32", + "metric_type": "l2", + "dim": 128, + "fresh_hnsw": { + "max_degree": 16, + "ef_construction": 100 + } + } + )"; + // The difference between HNSW and FreshHNSW is that FreshHNSW supports actual deletions, while HNSW only supports marked deletions. However, FreshHNSW incurs double the graph storage cost of HNSW due to the need to store reverse edges. 
+ auto index = vsag::Factory::CreateIndex("fresh_hnsw", fresh_hnsw_build_paramesters).value(); + + /******************* Build FreshHnsw Index *****************/ + if (auto build_result = index->Build(base); build_result.has_value()) { + std::cout << "After Build(), Index FreshHnsw contains: " << index->GetNumElements() + << std::endl; + } else { + std::cerr << "Failed to build index: " << build_result.error().message << std::endl; + exit(-1); + } + + /******************* KnnSearch For FreshHnsw Index *****************/ + auto query_vector = new float[dim]; + for (int64_t i = 0; i < dim; ++i) { + query_vector[i] = distrib_real(rng); + } + + // fresh_hnsw_search_parameters is the configuration for searching in a FreshHNSW index. + // The "fresh_hnsw" section contains parameters specific to the search operation: + // - "ef_search": The size of the dynamic list used for nearest neighbor search, which influences both recall and search speed. + + auto fresh_hnsw_search_parameters = R"( + { + "fresh_hnsw": { + "ef_search": 100 + } + } + )"; + int64_t topk = 10; + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(query_vector)->Owner(true); + auto knn_result = index->KnnSearch(query, topk, fresh_hnsw_search_parameters); + + /******************* Print Search Result *****************/ + if (knn_result.has_value()) { + auto result = knn_result.value(); + std::cout << "results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + } + } else { + std::cerr << "Search Error: " << knn_result.error().message << std::endl; + } + + return 0; +} diff --git a/examples/cpp/todo_examples/302_feature_range_search.cpp b/examples/cpp/todo_examples/302_feature_range_search.cpp index 0d435f8c..b46429d9 100644 --- a/examples/cpp/todo_examples/302_feature_range_search.cpp +++ b/examples/cpp/todo_examples/302_feature_range_search.cpp @@ -11,4 +11,93 @@ // 
distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. + +#include + +#include + +int +main(int argc, char** argv) { + /******************* Prepare Base Dataset *****************/ + int64_t num_vectors = 1000; + int64_t dim = 128; + auto ids = new int64_t[num_vectors]; + auto vectors = new float[dim * num_vectors]; + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + for (int64_t i = 0; i < num_vectors; ++i) { + ids[i] = i; + } + for (int64_t i = 0; i < dim * num_vectors; ++i) { + vectors[i] = distrib_real(rng); + } + auto base = vsag::Dataset::Make(); + // Transfer the ownership of the data (ids, vectors) to the base. + base->NumElements(num_vectors)->Dim(dim)->Ids(ids)->Float32Vectors(vectors); + + /******************* Create Hnsw Index *****************/ + auto hnsw_build_paramesters = R"( + { + "dtype": "float32", + "metric_type": "l2", + "dim": 128, + "hnsw": { + "max_degree": 16, + "ef_construction": 100 + } + } + )"; + auto index = vsag::Factory::CreateIndex("hnsw", hnsw_build_paramesters).value(); + + /******************* Build Hnsw Index *****************/ + if (auto build_result = index->Build(base); build_result.has_value()) { + std::cout << "After Build(), Index Hnsw contains: " << index->GetNumElements() << std::endl; + } else { + std::cerr << "Failed to build index: " << build_result.error().message << std::endl; + exit(-1); + } + + /******************* Prepare Query *****************/ + auto query_vector = new float[dim]; + for (int64_t i = 0; i < dim; ++i) { + query_vector[i] = distrib_real(rng); + } + int64_t topk = 10; + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(query_vector)->Owner(true); + + auto 
hnsw_search_parameters = R"( + { + "hnsw": { + "ef_search": 100 + } + } + )"; + + /******************* Get Threshold *****************/ + auto result = index->KnnSearch(query, topk, hnsw_search_parameters); + if (not result.has_value()) { + std::cerr << "Search Error: " << result.error().message << std::endl; + } + float threshold = result.value()->GetDistances()[5]; + + /******************* RangeSearch *****************/ + auto range_result = index->RangeSearch(query, threshold, hnsw_search_parameters); + if (not range_result.has_value()) { + std::cerr << "Search Error: " << range_result.error().message << std::endl; + } + auto final_result = range_result.value(); + + /******************* Print Search Result *****************/ + std::cout << "threshold:" << threshold << std::endl; + std::cout << "results: " << std::endl; + for (int64_t i = 0; i < final_result->GetDim(); ++i) { + std::cout << final_result->GetIds()[i] << ": " << final_result->GetDistances()[i] + << std::endl; + } + + return 0; +} diff --git a/examples/cpp/todo_examples/304_feature_enhance_graph.cpp b/examples/cpp/todo_examples/304_feature_enhance_graph.cpp index 0d435f8c..68e7cd26 100644 --- a/examples/cpp/todo_examples/304_feature_enhance_graph.cpp +++ b/examples/cpp/todo_examples/304_feature_enhance_graph.cpp @@ -11,4 +11,157 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. 
+ +#include + +#include + +int +main(int argc, char** argv) { + /******************* Prepare Base Dataset *****************/ + int dim = 128; + int base_elements = 2000; + int query_elements = 1000; + int ef_search = 10; + int64_t k = 10; + + auto base = vsag::Dataset::Make(); + std::shared_ptr base_ids(new int64_t[base_elements]); + std::shared_ptr base_data(new float[dim * base_elements]); + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distribution_real(-1, 1); + for (int i = 0; i < base_elements; i++) { + base_ids[i] = i; + + for (int d = 0; d < dim; d++) { + base_data[d + i * dim] = distribution_real(rng); + } + } + base->Dim(dim) + ->NumElements(base_elements) + ->Ids(base_ids.get()) + ->Float32Vectors(base_data.get()) + ->Owner(false); + + /******************* Build Hnsw Index *****************/ + // When you want to use EnhanceGraph, the use_conjugate_graph must be set to true + auto hnsw_build_paramesters = R"( + { + "dtype": "float32", + "metric_type": "l2", + "dim": 128, + "hnsw": { + "max_degree": 16, + "ef_construction": 100, + "use_conjugate_graph": true + } + } + )"; + std::shared_ptr hnsw; + if (auto index = vsag::Factory::CreateIndex("hnsw", hnsw_build_paramesters); + index.has_value()) { + hnsw = index.value(); + } else { + std::cout << "Create HNSW Error" << std::endl; + } + + if (const auto build_result = hnsw->Build(base); build_result.has_value()) { + std::cout << "After Build(), Index constains: " << hnsw->GetNumElements() << std::endl; + } else { + std::cerr << "Failed to build index: " << build_result.error().message << std::endl; + exit(-1); + } + + /******************* Search Hnsw Index without Conjugate Graph *****************/ + // record the failed ids + std::set> failed_queries; + // use_conjugate_graph_search indicates whether to use information from the conjugate_graph to enhance the search results. 
+ auto before_enhance_parameters = R"( + { + "hnsw": { + "ef_search": 10, + "use_conjugate_graph_search": false + } + } + )"; + { + int correct = 0; + std::cout << "====Search Stage====" << std::endl; + + for (int i = 0; i < query_elements; i++) { + auto query = vsag::Dataset::Make(); + query->Dim(dim) + ->Float32Vectors(base_data.get() + i * dim) + ->NumElements(1) + ->Owner(false); + + auto result = hnsw->KnnSearch(query, k, before_enhance_parameters); + int64_t global_optimum = i; // global optimum is itself + if (result.has_value()) { + int64_t local_optimum = result.value()->GetIds()[0]; + if (local_optimum == global_optimum) { + correct++; + } else { + failed_queries.emplace(i, global_optimum); + } + } else { + std::cerr << "Search Error: " << result.error().message << std::endl; + } + } + std::cout << "Recall: " << correct / (1.0 * query_elements) << std::endl; + } + + /******************* Enhance Phase *****************/ + // Feed each failed query, together with its known correct id, back to the index via Feedback(). + { + int error_fixed = 0; + std::cout << "====Feedback Stage====" << std::endl; + for (auto item : failed_queries) { + auto query = vsag::Dataset::Make(); + query->Dim(dim) + ->Float32Vectors(base_data.get() + item.first * dim) + ->NumElements(1) + ->Owner(false); + error_fixed += *hnsw->Feedback(query, 1, before_enhance_parameters, item.second); + } + std::cout << "Fixed queries num: " << error_fixed << std::endl; + } + + /******************* Search Hnsw Index with Conjugate Graph *****************/ + auto after_enhance_parameters = R"( + { + "hnsw": { + "ef_search": 10, + "use_conjugate_graph_search": true + } + } + )"; + { + int correct = 0; + std::cout << "====Enhanced Search Stage====" << std::endl; + + for (int i = 0; i < query_elements; i++) { + auto query = vsag::Dataset::Make(); + query->Dim(dim) + ->Float32Vectors(base_data.get() + i * dim) + ->NumElements(1) + ->Owner(false); + + auto result = hnsw->KnnSearch(query, k, after_enhance_parameters); + int64_t global_optimum = i; // global optimum is itself + if 
(result.has_value()) { + int64_t local_optimum = result.value()->GetIds()[0]; + if (local_optimum == global_optimum) { + correct++; + } + } else { + std::cerr << "Search Error: " << result.error().message << std::endl; + } + } + std::cout << "Enhanced Recall: " << correct / (1.0 * query_elements) << std::endl; + } + + return 0; +} diff --git a/examples/cpp/todo_examples/CMakeLists.txt b/examples/cpp/todo_examples/CMakeLists.txt index a8a7c07f..1c1ac00c 100644 --- a/examples/cpp/todo_examples/CMakeLists.txt +++ b/examples/cpp/todo_examples/CMakeLists.txt @@ -1,5 +1,20 @@ +add_executable(101_index_hnsw 101_index_hnsw.cpp) +target_link_libraries(101_index_hnsw vsag) + +add_executable(102_index_diskann 102_index_diskann.cpp) +target_link_libraries(102_index_diskann vsag) + add_executable (103_index_hgraph 103_index_hgraph.cpp) target_link_libraries (103_index_hgraph vsag_static) +add_executable(104_index_fresh_hnsw 104_index_fresh_hnsw.cpp) +target_link_libraries(104_index_fresh_hnsw vsag) + add_executable (105_index_brute_force 105_index_brute_force.cpp) target_link_libraries (105_index_brute_force vsag_static) + +add_executable(302_feature_range_search 302_feature_range_search.cpp) +target_link_libraries(302_feature_range_search vsag) + +add_executable(304_feature_enhance_graph 304_feature_enhance_graph.cpp) +target_link_libraries(304_feature_enhance_graph vsag) diff --git a/src/engine.cpp b/src/engine.cpp index a076a6d3..145defbf 100644 --- a/src/engine.cpp +++ b/src/engine.cpp @@ -78,9 +78,9 @@ Engine::CreateIndex(const std::string& origin_name, const std::string& parameter return index; } else if (name == INDEX_FRESH_HNSW) { // read parameters from json, throw exception if not exists - CHECK_ARGUMENT(parsed_params.contains(INDEX_HNSW), - fmt::format("parameters must contains {}", INDEX_HNSW)); - auto& hnsw_param_obj = parsed_params[INDEX_HNSW]; + CHECK_ARGUMENT(parsed_params.contains(INDEX_FRESH_HNSW), + fmt::format("parameters must contains {}", 
INDEX_FRESH_HNSW)); + auto& hnsw_param_obj = parsed_params[INDEX_FRESH_HNSW]; auto hnsw_params = FreshHnswParameters::FromJson(hnsw_param_obj, index_common_params); logger::debug("created a fresh-hnsw index"); auto index = std::make_shared(hnsw_params, index_common_params); diff --git a/src/index/hnsw_zparameters.cpp b/src/index/hnsw_zparameters.cpp index c768dab4..022eb116 100644 --- a/src/index/hnsw_zparameters.cpp +++ b/src/index/hnsw_zparameters.cpp @@ -87,19 +87,26 @@ HnswSearchParameters::FromJson(const std::string& json_string) { HnswSearchParameters obj; // set obj.ef_search - CHECK_ARGUMENT(params.contains(INDEX_HNSW), - fmt::format("parameters must contains {}", INDEX_HNSW)); + std::string index_name; + if (params.contains(INDEX_HNSW)) { + index_name = INDEX_HNSW; + } else if (params.contains(INDEX_FRESH_HNSW)) { + index_name = INDEX_FRESH_HNSW; + } else { + throw std::invalid_argument( + fmt::format("parameters must contains {}/{}", INDEX_HNSW, INDEX_FRESH_HNSW)); + } CHECK_ARGUMENT( - params[INDEX_HNSW].contains(HNSW_PARAMETER_EF_RUNTIME), - fmt::format("parameters[{}] must contains {}", INDEX_HNSW, HNSW_PARAMETER_EF_RUNTIME)); - obj.ef_search = params[INDEX_HNSW][HNSW_PARAMETER_EF_RUNTIME]; + params[index_name].contains(HNSW_PARAMETER_EF_RUNTIME), + fmt::format("parameters[{}] must contains {}", index_name, HNSW_PARAMETER_EF_RUNTIME)); + obj.ef_search = params[index_name][HNSW_PARAMETER_EF_RUNTIME]; CHECK_ARGUMENT((1 <= obj.ef_search) and (obj.ef_search <= 1000), fmt::format("ef_search({}) must in range[1, 1000]", obj.ef_search)); // set obj.use_conjugate_graph search - if (params[INDEX_HNSW].contains(PARAMETER_USE_CONJUGATE_GRAPH_SEARCH)) { - obj.use_conjugate_graph_search = params[INDEX_HNSW][PARAMETER_USE_CONJUGATE_GRAPH_SEARCH]; + if (params[index_name].contains(PARAMETER_USE_CONJUGATE_GRAPH_SEARCH)) { + obj.use_conjugate_graph_search = params[index_name][PARAMETER_USE_CONJUGATE_GRAPH_SEARCH]; } else { obj.use_conjugate_graph_search = true; } diff 
--git a/src/parameter_generator.cpp b/src/parameter_generator.cpp index ae9a8911..b0014854 100644 --- a/src/parameter_generator.cpp +++ b/src/parameter_generator.cpp @@ -52,6 +52,11 @@ parameter_string(const std::string& metric_type, "ef_construction": {}, "use_conjugate_graph": {} }}, + "fresh_hnsw": {{ + "max_degree": {}, + "ef_construction": {}, + "use_conjugate_graph": {} + }}, "diskann": {{ "max_degree": {}, "ef_construction": {}, @@ -65,6 +70,9 @@ parameter_string(const std::string& metric_type, hnsw_max_degree, hnsw_ef_construction, use_conjugate_graph, + hnsw_max_degree, + hnsw_ef_construction, + use_conjugate_graph, diskann_max_degree, diskann_ef_construction, diskann_pq_dims, diff --git a/tests/test_index_old.cpp b/tests/test_index_old.cpp index 5bfc0ee3..8abf82cf 100644 --- a/tests/test_index_old.cpp +++ b/tests/test_index_old.cpp @@ -1066,7 +1066,7 @@ TEST_CASE("int8 + freshhnsw + feedback + update", "[ft][index][hnsw]") { "dtype": "int8", "metric_type": "{}", "dim": {}, - "hnsw": {{ + "fresh_hnsw": {{ "max_degree": 16, "ef_construction": 100, "use_conjugate_graph": true @@ -1122,7 +1122,7 @@ TEST_CASE("int8 + freshhnsw + feedback + update", "[ft][index][hnsw]") { use_conjugate_graph_search = (round != 0); constexpr auto search_parameters_json = R"( {{ - "hnsw": {{ + "fresh_hnsw": {{ "ef_search": 10, "use_conjugate_graph_search": {} }} From 892cbe0e36e268c46f36e5e8d7a2b10330c509eb Mon Sep 17 00:00:00 2001 From: LHT129 Date: Sun, 19 Jan 2025 10:20:15 +0800 Subject: [PATCH 8/8] introduce filter & remove example (#340) Signed-off-by: LHT129 --- .../cpp/todo_examples/301_feature_filter.cpp | 112 +++++++++++++++++- .../cpp/todo_examples/305_feature_remove.cpp | 98 ++++++++++++++- examples/cpp/todo_examples/CMakeLists.txt | 6 + 3 files changed, 214 insertions(+), 2 deletions(-) diff --git a/examples/cpp/todo_examples/301_feature_filter.cpp b/examples/cpp/todo_examples/301_feature_filter.cpp index 0d435f8c..03636053 100644 --- 
a/examples/cpp/todo_examples/301_feature_filter.cpp +++ b/examples/cpp/todo_examples/301_feature_filter.cpp @@ -11,4 +11,114 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. + +#include + +#include + +int +main(int argc, char** argv) { + vsag::init(); + + /******************* Prepare Base Dataset *****************/ + int64_t num_vectors = 10000; + int64_t dim = 128; + std::vector ids(num_vectors); + std::vector datas(num_vectors * dim); + std::mt19937 rng(47); + std::uniform_real_distribution distrib_real; + for (int64_t i = 0; i < num_vectors; ++i) { + ids[i] = i; + } + for (int64_t i = 0; i < dim * num_vectors; ++i) { + datas[i] = distrib_real(rng); + } + auto base = vsag::Dataset::Make(); + base->NumElements(num_vectors) + ->Dim(dim) + ->Ids(ids.data()) + ->Float32Vectors(datas.data()) + ->Owner(false); + + /******************* Create HNSW Index *****************/ + auto hnsw_build_parameters = R"( + { + "dtype": "float32", + "metric_type": "l2", + "dim": 128, + "hnsw": { + "max_degree": 16, + "ef_construction": 100 + } + } + )"; + auto index = vsag::Factory::CreateIndex("hnsw", hnsw_build_parameters).value(); + + /******************* Build HNSW Index *****************/ + if (auto build_result = index->Build(base); build_result.has_value()) { + std::cout << "After Build(), Index Hnsw contains: " << index->GetNumElements() << std::endl; + } else { + std::cerr << "Failed to build index: internalError" << build_result.error().message + << std::endl; + exit(-1); + } + + /******************* Prepare Query *****************/ + std::vector query_vector(dim); + for (int64_t i = 0; i < dim; ++i) { + query_vector[i] = distrib_real(rng); + } + auto query = vsag::Dataset::Make(); + 
query->NumElements(1)->Dim(dim)->Float32Vectors(query_vector.data())->Owner(false); + + /******************* Prepare Bitset Filter *****************/ + auto filter_bitset = vsag::Bitset::Make(); + for (int64_t i = 0; i < num_vectors; ++i) { + auto id = base->GetIds()[i]; + if (id % 2 == 0) { + filter_bitset->Set(id); + } + } + + /******************* Prepare Filter Function *****************/ + std::function filter_func = [](int64_t id) { return id % 2 == 0; }; + + /******************* HNSW Filter Search With Bitset *****************/ + auto hnsw_search_parameters = R"( + { + "hnsw": { + "ef_search": 100 + } + } + )"; + int64_t topk = 10; + auto search_result = index->KnnSearch(query, topk, hnsw_search_parameters, filter_bitset); + if (not search_result.has_value()) { + std::cerr << "Failed to search index with filter" << search_result.error().message + << std::endl; + exit(-1); + } + auto result = search_result.value(); + + // print the results with the filter applied: the returned ids are odd, not even. + std::cout << "bitset filter results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + } + + /******************* HNSW Filter Search With filter function *****************/ + search_result = index->KnnSearch(query, topk, hnsw_search_parameters, filter_func); + if (not search_result.has_value()) { + std::cerr << "Failed to search index with filter" << search_result.error().message + << std::endl; + exit(-1); + } + result = search_result.value(); + + // print the results with the filter applied: the returned ids are odd, not even. 
+ std::cout << "function filter results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + } +} diff --git a/examples/cpp/todo_examples/305_feature_remove.cpp b/examples/cpp/todo_examples/305_feature_remove.cpp index 0d435f8c..02da8e5c 100644 --- a/examples/cpp/todo_examples/305_feature_remove.cpp +++ b/examples/cpp/todo_examples/305_feature_remove.cpp @@ -11,4 +11,100 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. + +#include + +#include + +int +main(int argc, char** argv) { + vsag::init(); + /******************* Prepare Base Dataset *****************/ + int64_t num_vectors = 10000; + int64_t dim = 128; + std::vector ids(num_vectors); + std::vector datas(num_vectors * dim); + std::mt19937 rng(47); + std::uniform_real_distribution distrib_real; + for (int64_t i = 0; i < num_vectors; ++i) { + ids[i] = i; + } + for (int64_t i = 0; i < dim * num_vectors; ++i) { + datas[i] = distrib_real(rng); + } + auto base = vsag::Dataset::Make(); + base->NumElements(num_vectors) + ->Dim(dim) + ->Ids(ids.data()) + ->Float32Vectors(datas.data()) + ->Owner(false); + + /******************* Create HNSW Index *****************/ + auto hnsw_build_paramesters = R"( + { + "dtype": "float32", + "metric_type": "l2", + "dim": 128, + "hnsw": { + "max_degree": 16, + "ef_construction": 100 + } + } + )"; + auto index = vsag::Factory::CreateIndex("hnsw", hnsw_build_paramesters).value(); + + /******************* Build HNSW Index *****************/ + if (auto build_result = index->Build(base); build_result.has_value()) { + std::cout << "After Build(), Index Hnsw contains: " << index->GetNumElements() << std::endl; + } else { + 
std::cerr << "Failed to build index: " << build_result.error().message << std::endl; + exit(-1); + } + + /******************* Prepare Query *****************/ + std::vector query_vector(dim); + for (int64_t i = 0; i < dim; ++i) { + query_vector[i] = distrib_real(rng); + } + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(query_vector.data())->Owner(false); + + /******************* HNSW Origin KnnSearch *****************/ + auto hnsw_search_parameters = R"( + { + "hnsw": { + "ef_search": 100 + } + } + )"; + int64_t topk = 10; + auto search_result = index->KnnSearch(query, topk, hnsw_search_parameters); + if (not search_result.has_value()) { + std::cerr << "Failed to search index" << search_result.error().message << std::endl; + exit(-1); + } + auto result = search_result.value(); + + std::cout << "origin results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + } + + /******************* HNSW Remove Some result ids *****************/ + for (int64_t i = 0; i < 5; ++i) { + index->Remove(result->GetIds()[i]); + } + + /******************* HNSW KnnSearch After Remove *****************/ + search_result = index->KnnSearch(query, topk, hnsw_search_parameters); + if (not search_result.has_value()) { + std::cerr << "Failed to search index" << search_result.error().message << std::endl; + exit(-1); + } + result = search_result.value(); + std::cout << "after delete results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + } +} diff --git a/examples/cpp/todo_examples/CMakeLists.txt b/examples/cpp/todo_examples/CMakeLists.txt index 1c1ac00c..3d1e3b3b 100644 --- a/examples/cpp/todo_examples/CMakeLists.txt +++ b/examples/cpp/todo_examples/CMakeLists.txt @@ -13,8 +13,14 @@ target_link_libraries(104_index_fresh_hnsw vsag) add_executable 
(105_index_brute_force 105_index_brute_force.cpp) target_link_libraries (105_index_brute_force vsag_static) +add_executable (301_feature_filter 301_feature_filter.cpp) +target_link_libraries (301_feature_filter vsag_static) + add_executable(302_feature_range_search 302_feature_range_search.cpp) target_link_libraries(302_feature_range_search vsag) add_executable(304_feature_enhance_graph 304_feature_enhance_graph.cpp) target_link_libraries(304_feature_enhance_graph vsag) + +add_executable (305_feature_remove 305_feature_remove.cpp) +target_link_libraries (305_feature_remove vsag_static)