Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a batch interface for getDistanceByLabel #337

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/vsag/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,20 @@ class Index {
throw std::runtime_error("Index doesn't support get distance by id");
};

/**
* @brief Calculate the distance between the query and the vector of the given ID for batch.
*
* @param vectors is the embedding of query
* @param ids is the unique identifier of the vector to be calculated in the index.
* @param count is the count of ids
* @param distances is the distances between the query and the vector of the given ID
* @return result is valid distance of input ids. '-1' indicates an invalid distance.
*/
virtual tl::expected<DatasetPtr, Error>
CalcBatchDistanceById(const float* vectors, const int64_t* ids, int64_t count) const {
throw std::runtime_error("Index doesn't support get distance by id");
};

/**
* @brief Checks if the specified feature is supported by the index.
*
Expand Down
6 changes: 6 additions & 0 deletions src/algorithm/hnswlib/algorithm_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
#include "space_interface.h"
#include "stream_reader.h"
#include "typing.h"
#include "vsag/dataset.h"
#include "vsag/errors.h"
#include "vsag/expected.hpp"

namespace hnswlib {

Expand Down Expand Up @@ -66,6 +69,9 @@ class AlgorithmInterface {
virtual float
getDistanceByLabel(LabelType label, const void* data_point) = 0;

virtual tl::expected<vsag::DatasetPtr, vsag::Error>
getBatchDistanceByLabel(const int64_t* ids, const void* data_point, int64_t count) = 0;

virtual const float*
getDataByLabel(LabelType label) const = 0;

Expand Down
28 changes: 28 additions & 0 deletions src/algorithm/hnswlib/hnswalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,34 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) {
return dist;
}

tl::expected<vsag::DatasetPtr, vsag::Error>
HierarchicalNSW::getBatchDistanceByLabel(const int64_t* ids,
const void* data_point,
int64_t count) {
std::shared_lock lock_table(label_lookup_lock_);
int64_t valid_cnt = 0;
auto result = vsag::Dataset::Make();
result->Owner(true, allocator_);
auto* distances = (float*)allocator_->Allocate(sizeof(float) * count);
result->Distances(distances);
std::shared_ptr<float[]> normalize_query;
normalizeVector(data_point, normalize_query);
for (int i = 0; i < count; i++) {
auto search = label_lookup_.find(ids[i]);
if (search == label_lookup_.end()) {
distances[i] = -1;
} else {
InnerIdType internal_id = search->second;
float dist =
fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_);
distances[i] = dist;
valid_cnt++;
}
}
result->NumElements(count);
return std::move(result);
}

bool
HierarchicalNSW::isValidLabel(LabelType label) {
std::shared_lock lock_table(label_lookup_lock_);
Expand Down
4 changes: 4 additions & 0 deletions src/algorithm/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "algorithm_interface.h"
#include "block_manager.h"
#include "visited_list_pool.h"
#include "vsag/dataset.h"
namespace hnswlib {
using InnerIdType = vsag::InnerIdType;
using linklistsizeint = unsigned int;
Expand Down Expand Up @@ -146,6 +147,9 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
float
getDistanceByLabel(LabelType label, const void* data_point) override;

tl::expected<vsag::DatasetPtr, vsag::Error>
getBatchDistanceByLabel(const int64_t* ids, const void* data_point, int64_t count) override;

bool
isValidLabel(LabelType label) override;

Expand Down
24 changes: 24 additions & 0 deletions src/algorithm/hnswlib/hnswalg_static.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,30 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> {
return dist;
}

tl::expected<vsag::DatasetPtr, vsag::Error>
getBatchDistanceByLabel(const int64_t* ids, const void* data_point, int64_t count) override {
std::unique_lock<std::mutex> lock_table(label_lookup_lock);
int64_t valid_cnt = 0;
auto result = vsag::Dataset::Make();
result->Owner(true, allocator_);
auto* distances = (float*)allocator_->Allocate(sizeof(float) * count);
result->Distances(distances);
for (int i = 0; i < count; i++) {
auto search = label_lookup_.find(ids[i]);
if (search == label_lookup_.end()) {
distances[i] = -1;
} else {
InnerIdType internal_id = search->second;
float dist =
fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_);
distances[i] = dist;
valid_cnt++;
}
}
result->NumElements(valid_cnt);
return std::move(result);
}

void
copyDataByLabel(LabelType label, void* data_point) override {
std::unique_lock lock_table(label_lookup_lock);
Expand Down
5 changes: 5 additions & 0 deletions src/index/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ class HNSW : public Index {
SAFE_CALL(return alg_hnsw_->getDistanceByLabel(id, vector));
};

virtual tl::expected<DatasetPtr, Error>
CalcBatchDistanceById(const float* vectors, const int64_t* ids, int64_t count) const override {
SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(ids, vectors, count));
};

[[nodiscard]] bool
CheckFeature(IndexFeature feature) const override;

Expand Down
40 changes: 40 additions & 0 deletions tests/test_hnsw_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,46 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Update Id", "[ft][hn
}
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Batch Calc Dis Id", "[ft][hnsw]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
const std::string name = "hnsw";
auto search_param = fmt::format(search_param_tmp, 100);
for (auto& dim : dims) {
vsag::Options::Instance().set_block_size_limit(size);
auto param = GenerateHNSWBuildParametersString(metric_type, dim);
auto index = TestFactory(name, param, true);
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
TestBuildIndex(index, dataset, true);
TestBatchCalcDistanceById(index, dataset);
vsag::Options::Instance().set_block_size_limit(origin_size);
}
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex,
"static HNSW Batch Calc Dis Id",
"[ft][hnsw]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2");
auto use_static = GENERATE(true);
const std::string name = "hnsw";
auto search_param = fmt::format(search_param_tmp, 100);
for (auto& dim : dims) {
if (dim % 4 != 0) {
dim = ((dim / 4) + 1) * 4;
}
vsag::Options::Instance().set_block_size_limit(size);
auto param = GenerateHNSWBuildParametersString(metric_type, dim, use_static);
auto index = TestFactory(name, param, true);
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
TestBuildIndex(index, dataset, true);
TestBatchCalcDistanceById(index, dataset);
vsag::Options::Instance().set_block_size_limit(origin_size);
}
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Update Vector", "[ft][hnsw]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
Expand Down
24 changes: 24 additions & 0 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,30 @@ TestIndex::TestCalcDistanceById(const IndexPtr& index, const TestDatasetPtr& dat
}
}

void
TestIndex::TestBatchCalcDistanceById(const IndexPtr& index,
const TestDatasetPtr& dataset,
float error) {
auto queries = dataset->query_;
auto query_count = queries->GetNumElements();
auto dim = queries->GetDim();
auto gts = dataset->ground_truth_;
auto gt_topK = dataset->top_k;
for (auto i = 0; i < query_count; ++i) {
auto query = vsag::Dataset::Make();
query->NumElements(1)
->Dim(dim)
->Float32Vectors(queries->GetFloat32Vectors() + i * dim)
->Owner(false);
auto result = index->CalcBatchDistanceById(
query->GetFloat32Vectors(), gts->GetIds() + (i * gt_topK), gt_topK);
for (auto j = 0; j < gt_topK; ++j) {
REQUIRE(std::abs(gts->GetDistances()[i * gt_topK + j] -
result.value()->GetDistances()[j]) < error);
}
}
}

void
TestIndex::TestSerializeFile(const IndexPtr& index_from,
const IndexPtr& index_to,
Expand Down
5 changes: 5 additions & 0 deletions tests/test_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ class TestIndex {
static void
TestCalcDistanceById(const IndexPtr& index, const TestDatasetPtr& dataset, float error = 1e-5);

static void
TestBatchCalcDistanceById(const IndexPtr& index,
const TestDatasetPtr& dataset,
float error = 1e-5);

static void
TestSerializeFile(const IndexPtr& index_from,
const IndexPtr& index_to,
Expand Down