diff --git a/include/vsag/index.h b/include/vsag/index.h
index 96812420..ebab6df5 100644
--- a/include/vsag/index.h
+++ b/include/vsag/index.h
@@ -249,6 +249,23 @@ class Index {
         throw std::runtime_error("Index doesn't support get distance by id");
     };
 
+    /**
+     * @brief Calculate the distances between the query and a batch of vectors given by ID.
+     *
+     * @param count is the number of IDs in vids
+     * @param vids is the array of unique identifiers of the vectors in the index
+     * @param vector is the embedding of query
+     * @return a dataset whose distances hold one entry per input ID, in input order;
+     *         '-1' marks an ID that has no valid distance.
+     * @throws std::runtime_error from this default implementation (operation unsupported)
+     */
+    virtual tl::expected<DatasetPtr, Error>
+    CalcBatchDistanceById(int64_t count,
+                          const int64_t* vids,
+                          const float* vector) const {
+        throw std::runtime_error("Index doesn't support get distance by id");
+    };
+
     /**
      * @brief Checks if the specified feature is supported by the index.
      *
diff --git a/src/algorithm/hnswlib/algorithm_interface.h b/src/algorithm/hnswlib/algorithm_interface.h
index 6d6e5b7e..da781531 100644
--- a/src/algorithm/hnswlib/algorithm_interface.h
+++ b/src/algorithm/hnswlib/algorithm_interface.h
@@ -24,6 +24,9 @@
 #include "space_interface.h"
 #include "stream_reader.h"
 #include "typing.h"
+#include "vsag/dataset.h"
+#include "vsag/expected.hpp"
+#include "vsag/errors.h"
 
 namespace hnswlib {
 
@@ -66,6 +69,13 @@ class AlgorithmInterface {
     virtual float
     getDistanceByLabel(LabelType label, const void* data_point) = 0;
 
+    // Batch form of getDistanceByLabel: distances from data_point to each of the `count`
+    // vectors labelled by vids.
+    virtual tl::expected<vsag::DatasetPtr, vsag::Error>
+    getBatchDistanceByLabel(int64_t count,
+                            const int64_t* vids,
+                            const void* data_point) = 0;
+
     virtual const float*
     getDataByLabel(LabelType label) const = 0;
 
diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp
index 578951a2..795dc8e5 100644
--- a/src/algorithm/hnswlib/hnswalg.cpp
+++ b/src/algorithm/hnswlib/hnswalg.cpp
@@ -171,6 +171,37 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) {
     return dist;
 }
 
+// Fills a `count`-sized distance buffer: slot i holds the distance from data_point to the
+// vector labelled vids[i], or -1 when that label is not in label_lookup_.
+tl::expected<vsag::DatasetPtr, vsag::Error>
+HierarchicalNSW::getBatchDistanceByLabel(int64_t count,
+                                         const int64_t* vids,
+                                         const void* data_point) {
+    // Reader lock: the label table is only looked up here, never modified.
+    std::shared_lock lock_table(label_lookup_lock_);
+    auto result = vsag::Dataset::Make();
+    result->Owner(true, allocator_);
+    auto* distances = static_cast<float*>(allocator_->Allocate(sizeof(float) * count));
+    result->Distances(distances);
+    int64_t valid_cnt = 0;
+    // 64-bit index: `count` is int64_t, so an `int` counter would overflow past INT_MAX.
+    for (int64_t i = 0; i < count; ++i) {
+        auto search = label_lookup_.find(vids[i]);
+        if (search == label_lookup_.end()) {
+            // Unknown label: keep the slot so positions still line up with vids.
+            distances[i] = -1;
+            continue;
+        }
+        distances[i] =
+            fstdistfunc_(data_point, getDataByInternalId(search->second), dist_func_param_);
+        ++valid_cnt;
+    }
+    // NOTE(review): the buffer holds `count` entries but only the number of resolved
+    // labels is reported here -- confirm callers iterate by `count`, not NumElements().
+    result->NumElements(valid_cnt);
+    return std::move(result);
+}
+
 bool
 HierarchicalNSW::isValidLabel(LabelType label) {
     std::shared_lock lock_table(label_lookup_lock_);
diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h
index 440bd6d6..2b641fe2 100644
--- a/src/algorithm/hnswlib/hnswalg.h
+++ b/src/algorithm/hnswlib/hnswalg.h
@@ -38,6 +38,7 @@
 #include "algorithm_interface.h"
 #include "block_manager.h"
 #include "visited_list_pool.h"
+#include "vsag/dataset.h"
 namespace hnswlib {
 using InnerIdType = vsag::InnerIdType;
 using linklistsizeint = unsigned int;
@@ -146,6 +147,11 @@ class HierarchicalNSW : public AlgorithmInterface {
     float
     getDistanceByLabel(LabelType label, const void* data_point) override;
 
+    tl::expected<vsag::DatasetPtr, vsag::Error>
+    getBatchDistanceByLabel(int64_t count,
+                            const int64_t* vids,
+                            const void* data_point) override;
+
     bool
     isValidLabel(LabelType label) override;
 
diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h
index 0e51d551..0d2afa54 100644
--- a/src/algorithm/hnswlib/hnswalg_static.h
+++ b/src/algorithm/hnswlib/hnswalg_static.h
@@ -262,6 +262,37 @@ class StaticHierarchicalNSW : public AlgorithmInterface {
         return dist;
     }
 
+    // Fills a `count`-sized distance buffer: slot i holds the distance from data_point to
+    // the vector labelled vids[i], or -1 when that label is not in label_lookup_.
+    tl::expected<vsag::DatasetPtr, vsag::Error>
+    getBatchDistanceByLabel(int64_t count,
+                            const int64_t* vids,
+                            const void* data_point) override {
+        // Hold the label-table lock for the whole batch, as isValidLabel does per lookup.
+        std::unique_lock lock_table(label_lookup_lock);
+        auto result = vsag::Dataset::Make();
+        result->Owner(true, allocator_);
+        auto* distances = static_cast<float*>(allocator_->Allocate(sizeof(float) * count));
+        result->Distances(distances);
+        int64_t valid_cnt = 0;
+        // 64-bit index: `count` is int64_t, so an `int` counter would overflow past INT_MAX.
+        for (int64_t i = 0; i < count; ++i) {
+            auto search = label_lookup_.find(vids[i]);
+            if (search == label_lookup_.end()) {
+                // Unknown label: keep the slot so positions still line up with vids.
+                distances[i] = -1;
+                continue;
+            }
+            distances[i] =
+                fstdistfunc_(data_point, getDataByInternalId(search->second), dist_func_param_);
+            ++valid_cnt;
+        }
+        // NOTE(review): the buffer holds `count` entries but only the number of resolved
+        // labels is reported here -- confirm callers iterate by `count`, not NumElements().
+        result->NumElements(valid_cnt);
+        return std::move(result);
+    }
+
     bool
     isValidLabel(LabelType label) override {
         std::unique_lock lock_table(label_lookup_lock);
diff --git a/src/index/hnsw.h b/src/index/hnsw.h
index 1a170ea8..6653ae45 100644
--- a/src/index/hnsw.h
+++ b/src/index/hnsw.h
@@ -145,6 +145,14 @@ class HNSW : public Index {
         SAFE_CALL(return alg_hnsw_->getDistanceByLabel(id, vector));
     };
 
+    // Batch variant of CalcDistanceById; forwards to alg_hnsw_->getBatchDistanceByLabel.
+    virtual tl::expected<DatasetPtr, Error>
+    CalcBatchDistanceById(int64_t count,
+                          const int64_t* vids,
+                          const float* vector) const override {
+        SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(count, vids, vector));
+    };
+
     [[nodiscard]] bool
     CheckFeature(IndexFeature feature) const override;
 