Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a batch interface for getDistanceByLabel #309

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions include/vsag/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,22 @@ class Index {
throw std::runtime_error("Index doesn't support get distance by id");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no need to keep this interface anymore; it is just a special case of the batch interface.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This involves a lot of test files. Can we leave it unchanged for now?

};

/**
 * @brief Calculate the distances between the query and a batch of indexed vectors given by ID.
 *
 * @param count number of entries in @p vids
 * @param vids identifiers of the indexed vectors to compare against (array of @p count values)
 * @param vector embedding of the query
 * @return a dataset carrying the computed distances. '-1' indicates an invalid distance
 *         (in the HNSW implementations: the ID was not found in the index).
 * @throws std::runtime_error always, in this default implementation; index types that
 *         support the operation override it.
 */
virtual tl::expected<DatasetPtr, Error>
CalcBatchDistanceById(int64_t count, const int64_t* vids, const float* vector) const {
    throw std::runtime_error("Index doesn't support get distance by id");
}

/**
* @brief Checks if the specified feature is supported by the index.
*
Expand Down
8 changes: 8 additions & 0 deletions src/algorithm/hnswlib/algorithm_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
#include "space_interface.h"
#include "stream_reader.h"
#include "typing.h"
#include "vsag/dataset.h"
#include "vsag/expected.hpp"
#include "vsag/errors.h"

namespace hnswlib {

Expand Down Expand Up @@ -69,6 +72,11 @@ class AlgorithmInterface {
// Single-label distance query: presumably the distance between data_point and the
// vector stored under label (the implementations are only partially visible here).
virtual float
getDistanceByLabel(LabelType label, const void* data_point) = 0;

// Batch distance query: one call for `count` labels in `vids` against one query
// `data_point`. In the HNSW implementations the returned dataset carries a distance
// buffer with one slot per entry of `vids`, and -1 marks a label that was not found.
// NOTE(review): takes int64_t labels while the single-label form takes LabelType —
// confirm the two are the same type.
virtual tl::expected<vsag::DatasetPtr, vsag::Error>
getBatchDistanceByLabel(int64_t count,
const int64_t *vids,
const void *data_point) = 0;

// Presumably returns a pointer to the vector stored under label — TODO confirm
// ownership/lifetime of the returned pointer against the implementations.
virtual const float*
getDataByLabel(LabelType label) const = 0;

Expand Down
24 changes: 24 additions & 0 deletions src/algorithm/hnswlib/hnswalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,30 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) {
return dist;
}

// Computes the distance from data_point to the stored vector of each label in
// vids[0..count), holding one shared lock on label_lookup_lock_ for the whole batch
// instead of one lock acquisition per label.
//
// Returns a dataset whose Distances buffer has `count` floats allocated from
// allocator_. Slot i holds the distance for vids[i], or -1 when vids[i] is not present
// in label_lookup_. NumElements is set to the number of labels that were found.
//
// NOTE(review): NumElements is the number of valid entries, not the buffer length;
// callers must index the buffer by the `count` they passed in — confirm this is intended.
tl::expected<vsag::DatasetPtr, vsag::Error>
HierarchicalNSW::getBatchDistanceByLabel(int64_t count,
const int64_t *vids,
const void* data_point) {
// One shared (reader) lock covers every lookup in the batch.
std::shared_lock lock_table(label_lookup_lock_);
int64_t valid_cnt = 0;
auto result = vsag::Dataset::Make();
// NOTE(review): the buffer is attached to `result` only after the loop; if anything in
// the loop throws, it is presumably never released. The allocation result and
// count <= 0 are not checked here — confirm callers validate.
auto *distances = (float *)allocator_->Allocate(sizeof(float) * count);
// NOTE(review): the index is `int` while count is int64_t; a count above INT_MAX would
// overflow the index — confirm callers never pass such counts.
for (int i = 0; i < count; i++) {
auto search = label_lookup_.find(vids[i]);
if (search == label_lookup_.end()) {
// Unknown label: -1 is the sentinel for an invalid distance.
distances[i] = -1;
} else {
InnerIdType internal_id = search->second;
float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have cal distance by id interface, why not use it

float
HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) 

Copy link
Author

@Carrot-77 Carrot-77 Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to calculate distances in batches. Calling getDistanceByLabel requires a shared lock for each LabelType, which may have a certain impact on performance in large batch scenarios.

distances[i] = dist;
valid_cnt++;
}
}
// Owner(true, allocator_) presumably makes the dataset responsible for releasing its
// buffers through allocator_ — TODO confirm against vsag::Dataset.
result->NumElements(valid_cnt)->Owner(true, allocator_);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's best to directly include the distances in the dataset to avoid potential memory leaks

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified

result->Distances(distances);
return std::move(result);
}

bool
HierarchicalNSW::isValidLabel(LabelType label) {
std::shared_lock lock_table(label_lookup_lock_);
Expand Down
6 changes: 6 additions & 0 deletions src/algorithm/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "algorithm_interface.h"
#include "block_manager.h"
#include "visited_list_pool.h"
#include "vsag/dataset.h"
namespace hnswlib {
using InnerIdType = vsag::InnerIdType;
using linklistsizeint = unsigned int;
Expand Down Expand Up @@ -146,6 +147,11 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
// Single-label distance query (override of the AlgorithmInterface declaration).
float
getDistanceByLabel(LabelType label, const void* data_point) override;

// Batch distance query for `count` labels in `vids` against one query `data_point`.
// The definition in hnswalg.cpp takes a single shared lock for the whole batch and
// writes -1 into the slot of every label that is not found.
tl::expected<vsag::DatasetPtr, vsag::Error>
getBatchDistanceByLabel(int64_t count,
const int64_t *vids,
const void* data_point) override;

// Presumably reports whether label is currently present in the index — the definition
// is only partially visible here; TODO confirm.
bool
isValidLabel(LabelType label) override;

Expand Down
24 changes: 24 additions & 0 deletions src/algorithm/hnswlib/hnswalg_static.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,30 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> {
return dist;
}

// Computes the distance from data_point to the stored vector of each label in
// vids[0..count), holding label_lookup_lock (a std::mutex, so exclusive) for the whole
// batch.
//
// Returns a dataset whose Distances buffer has `count` floats allocated from
// allocator_. Slot i holds the distance for vids[i], or -1 when vids[i] is not present
// in label_lookup_. NumElements is set to the number of labels that were found.
//
// NOTE(review): NumElements is the number of valid entries, not the buffer length;
// callers must index the buffer by the `count` they passed in — confirm this is intended.
tl::expected<vsag::DatasetPtr, vsag::Error>
getBatchDistanceByLabel(int64_t count,
const int64_t *vids,
const void* data_point) override {
// A single lock acquisition covers every lookup in the batch.
std::unique_lock<std::mutex> lock_table(label_lookup_lock);
int64_t valid_cnt = 0;
auto result = vsag::Dataset::Make();
// NOTE(review): the buffer is attached to `result` only after the loop; if anything in
// the loop throws, it is presumably never released. The allocation result and
// count <= 0 are not checked here — confirm callers validate.
auto *distances = (float *)allocator_->Allocate(sizeof(float) * count);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

// NOTE(review): the index is `int` while count is int64_t; a count above INT_MAX would
// overflow the index — confirm callers never pass such counts.
for (int i = 0; i < count; i++) {
auto search = label_lookup_.find(vids[i]);
if (search == label_lookup_.end()) {
// Unknown label: -1 is the sentinel for an invalid distance.
distances[i] = -1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

explicitly specify at the interface that -1 indicates an invalid distance

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already marked

} else {
InnerIdType internal_id = search->second;
float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_);
distances[i] = dist;
valid_cnt++;
}
}
// Owner(true, allocator_) presumably makes the dataset responsible for releasing its
// buffers through allocator_ — TODO confirm against vsag::Dataset.
result->NumElements(valid_cnt)->Owner(true, allocator_);
result->Distances(distances);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

return std::move(result);
}

bool
isValidLabel(LabelType label) override {
std::unique_lock<std::mutex> lock_table(label_lookup_lock);
Expand Down
7 changes: 7 additions & 0 deletions src/index/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ class HNSW : public Index {
SAFE_CALL(return alg_hnsw_->getDistanceByLabel(id, vector));
};

virtual tl::expected<DatasetPtr, Error>
CalcBatchDistanceById(int64_t count,
const int64_t *vids,
const float* vector) const override {
SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(count, vids, vector));
};

[[nodiscard]] bool
CheckFeature(IndexFeature feature) const override;

Expand Down