From 42f054e63ada9d82bab04f1a44c0850364482de0 Mon Sep 17 00:00:00 2001 From: "zourunxin.zrx" Date: Wed, 22 Jan 2025 20:56:35 +0800 Subject: [PATCH] Add a batch interface for getDistanceByLabel Signed-off-by: zourunxin.zrx --- include/vsag/index.h | 14 ++++++++ src/algorithm/hnswlib/algorithm_interface.h | 6 ++++ src/algorithm/hnswlib/hnswalg.cpp | 28 +++++++++++++++ src/algorithm/hnswlib/hnswalg.h | 4 +++ src/algorithm/hnswlib/hnswalg_static.h | 24 +++++++++++++ src/index/hnsw.h | 5 +++ tests/test_hnsw_new.cpp | 40 +++++++++++++++++++++ tests/test_index.cpp | 24 +++++++++++++ tests/test_index.h | 5 +++ 9 files changed, 150 insertions(+) diff --git a/include/vsag/index.h b/include/vsag/index.h index 128eab54..6bdc315f 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -249,6 +249,20 @@ class Index { throw std::runtime_error("Index doesn't support get distance by id"); }; + /** + * @brief Calculate the distances between the query and each vector in a batch of IDs. + * + * @param vectors is the embedding of the query + * @param ids is the array of unique identifiers of the vectors to be calculated in the index. + * @param count is the number of entries in ids + * @return result holds the distances for the input ids. '-1' indicates an invalid distance. + * @throws std::runtime_error from this default implementation, which does not support the call. + */ + virtual tl::expected + CalcBatchDistanceById(const float* vectors, const int64_t* ids, int64_t count) const { + throw std::runtime_error("Index doesn't support get distance by id"); + }; + /** * @brief Checks if the specified feature is supported by the index. 
* diff --git a/src/algorithm/hnswlib/algorithm_interface.h b/src/algorithm/hnswlib/algorithm_interface.h index 0bd2ae49..ba710109 100644 --- a/src/algorithm/hnswlib/algorithm_interface.h +++ b/src/algorithm/hnswlib/algorithm_interface.h @@ -24,6 +24,9 @@ #include "space_interface.h" #include "stream_reader.h" #include "typing.h" +#include "vsag/dataset.h" +#include "vsag/errors.h" +#include "vsag/expected.hpp" namespace hnswlib { @@ -66,6 +69,9 @@ class AlgorithmInterface { virtual float getDistanceByLabel(LabelType label, const void* data_point) = 0; + virtual tl::expected + getBatchDistanceByLabel(const int64_t* ids, const void* data_point, int64_t count) = 0; + virtual const float* getDataByLabel(LabelType label) const = 0; diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index cadee13b..4ea2de3c 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -171,6 +171,34 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) { return dist; } +/* Batch variant of getDistanceByLabel: distances[i] is the distance from data_point to the vector labelled ids[i], or -1 when that label is not in label_lookup_. */ +/* The shared lock on label_lookup_lock_ is held across the whole loop; the result reports NumElements == count (one slot per requested id). */ +tl::expected +HierarchicalNSW::getBatchDistanceByLabel(const int64_t* ids, + const void* data_point, + int64_t count) { + std::shared_lock lock_table(label_lookup_lock_); + auto result = vsag::Dataset::Make(); + result->Owner(true, allocator_); + auto* distances = static_cast<float*>(allocator_->Allocate(sizeof(float) * count)); + result->Distances(distances); + std::shared_ptr normalize_query; + normalizeVector(data_point, normalize_query); + for (int64_t i = 0; i < count; i++) { + auto search = label_lookup_.find(ids[i]); + if (search == label_lookup_.end()) { + distances[i] = -1; + } else { + InnerIdType internal_id = search->second; + float dist = + fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); + distances[i] = dist; + } + } + result->NumElements(count); + return std::move(result); +} + bool HierarchicalNSW::isValidLabel(LabelType label) { std::shared_lock 
lock_table(label_lookup_lock_); diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index 1228c5ed..77196c18 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -38,6 +38,7 @@ #include "algorithm_interface.h" #include "block_manager.h" #include "visited_list_pool.h" +#include "vsag/dataset.h" namespace hnswlib { using InnerIdType = vsag::InnerIdType; using linklistsizeint = unsigned int; @@ -146,6 +147,9 @@ class HierarchicalNSW : public AlgorithmInterface { float getDistanceByLabel(LabelType label, const void* data_point) override; + tl::expected + getBatchDistanceByLabel(const int64_t* ids, const void* data_point, int64_t count) override; + bool isValidLabel(LabelType label) override; diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index 570e9373..014159c3 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -262,6 +262,30 @@ class StaticHierarchicalNSW : public AlgorithmInterface { return dist; } + /* Batch variant of getDistanceByLabel: distances[i] is the distance from data_point to the vector labelled ids[i], or -1 when that label is not in label_lookup_. */ + /* NumElements is set to count (not the number of hits) because the buffer is positional and holds one slot per requested id. */ + tl::expected + getBatchDistanceByLabel(const int64_t* ids, const void* data_point, int64_t count) override { + std::unique_lock lock_table(label_lookup_lock); + auto result = vsag::Dataset::Make(); + result->Owner(true, allocator_); + auto* distances = static_cast<float*>(allocator_->Allocate(sizeof(float) * count)); + result->Distances(distances); + for (int64_t i = 0; i < count; i++) { + auto search = label_lookup_.find(ids[i]); + if (search == label_lookup_.end()) { + distances[i] = -1; + } else { + InnerIdType internal_id = search->second; + float dist = + fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); + distances[i] = dist; + } + } + result->NumElements(count); + return std::move(result); + } + void copyDataByLabel(LabelType label, void* data_point) override { std::unique_lock lock_table(label_lookup_lock); diff --git a/src/index/hnsw.h b/src/index/hnsw.h 
index 831416af..8f88f108 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -148,6 +148,11 @@ class HNSW : public Index { SAFE_CALL(return alg_hnsw_->getDistanceByLabel(id, vector)); }; + virtual tl::expected + CalcBatchDistanceById(const float* vectors, const int64_t* ids, int64_t count) const override { + SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(ids, vectors, count)); + }; + [[nodiscard]] bool CheckFeature(IndexFeature feature) const override; diff --git a/tests/test_hnsw_new.cpp b/tests/test_hnsw_new.cpp index 196faad6..d0b4f58d 100644 --- a/tests/test_hnsw_new.cpp +++ b/tests/test_hnsw_new.cpp @@ -316,6 +316,46 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Update Id", "[ft][hn } } +/* For each dim and each metric (l2, ip, cosine): build an HNSW index, then run the batch distance-by-id check on it. */ +TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Batch Calc Dis Id", "[ft][hnsw]") { + auto origin_size = vsag::Options::Instance().block_size_limit(); + auto size = GENERATE(1024 * 1024 * 2); + auto metric_type = GENERATE("l2", "ip", "cosine"); + const std::string name = "hnsw"; + for (auto& dim : dims) { + vsag::Options::Instance().set_block_size_limit(size); + auto param = GenerateHNSWBuildParametersString(metric_type, dim); + auto index = TestFactory(name, param, true); + auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); + TestBuildIndex(index, dataset, true); + TestBatchCalcDistanceById(index, dataset); + vsag::Options::Instance().set_block_size_limit(origin_size); + } +} + +/* Static-HNSW variant (l2 only): each dim is rounded up to a multiple of 4, written back into dims through the reference loop variable. */ +TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, + "static HNSW Batch Calc Dis Id", + "[ft][hnsw]") { + auto origin_size = vsag::Options::Instance().block_size_limit(); + auto size = GENERATE(1024 * 1024 * 2); + auto metric_type = GENERATE("l2"); + auto use_static = GENERATE(true); + const std::string name = "hnsw"; + for (auto& dim : dims) { + if (dim % 4 != 0) { + dim = ((dim / 4) + 1) * 4; + } + 
vsag::Options::Instance().set_block_size_limit(size); + auto param = GenerateHNSWBuildParametersString(metric_type, dim, use_static); + auto index = TestFactory(name, param, true); + auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); + TestBuildIndex(index, dataset, true); + TestBatchCalcDistanceById(index, dataset); + vsag::Options::Instance().set_block_size_limit(origin_size); + } +} + TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Update Vector", "[ft][hnsw]") { auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); diff --git a/tests/test_index.cpp b/tests/test_index.cpp index b6eaa67f..8f5feaa4 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -420,6 +420,30 @@ TestIndex::TestCalcDistanceById(const IndexPtr& index, const TestDatasetPtr& dat } } +void +TestIndex::TestBatchCalcDistanceById(const IndexPtr& index, + const TestDatasetPtr& dataset, + float error) { + auto queries = dataset->query_; + auto query_count = queries->GetNumElements(); + auto dim = queries->GetDim(); + auto gts = dataset->ground_truth_; + auto gt_topK = dataset->top_k; + for (auto i = 0; i < query_count; ++i) { + auto query = vsag::Dataset::Make(); + query->NumElements(1) + ->Dim(dim) + ->Float32Vectors(queries->GetFloat32Vectors() + i * dim) + ->Owner(false); + auto result = index->CalcBatchDistanceById( + query->GetFloat32Vectors(), gts->GetIds() + (i * gt_topK), gt_topK); + for (auto j = 0; j < gt_topK; ++j) { + REQUIRE(std::abs(gts->GetDistances()[i * gt_topK + j] - + result.value()->GetDistances()[j]) < error); + } + } +} + void TestIndex::TestSerializeFile(const IndexPtr& index_from, const IndexPtr& index_to, diff --git a/tests/test_index.h b/tests/test_index.h index 24cc159d..bedb0de4 100644 --- a/tests/test_index.h +++ b/tests/test_index.h @@ -110,6 +110,11 @@ class TestIndex { static void TestCalcDistanceById(const IndexPtr& index, const TestDatasetPtr& dataset, float error = 1e-5); + 
static void + TestBatchCalcDistanceById(const IndexPtr& index, + const TestDatasetPtr& dataset, + float error = 1e-5); + static void TestSerializeFile(const IndexPtr& index_from, const IndexPtr& index_to,