-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a batch interface for getDistanceByLabel #309
Conversation
…o batch_cal_dis_from_id
@@ -249,6 +249,23 @@ class Index { | |||
throw std::runtime_error("Index doesn't support get distance by id"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is no need to keep this interface anymore; it is just a special form of batch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This involves a lot of test files. Can we leave it unchanged for now?
include/vsag/index.h
Outdated
CalcBatchDistanceById(int64_t count, | ||
const int64_t *vids, | ||
const float* vector, | ||
float *&distances) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can return a Dataset as the result set, and let the Dataset automatically manage the allocated memory.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified
@@ -81,6 +81,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> { | |||
void* dist_func_param_{nullptr}; | |||
|
|||
mutable std::mutex label_lookup_lock; // lock for label_lookup_ | |||
mutable std::shared_mutex shared_label_lookup_lock; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why here add a new lock?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified
for (int i = 0; i < count; i++) { | ||
auto search = label_lookup_.find(vids[i]); | ||
if (search == label_lookup_.end()) { | ||
distances[i] = -1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
explicitly specify at the interface that -1 indicates an invalid distance
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Already marked
src/algorithm/hnswlib/hnswalg.cpp
Outdated
valid_cnt++; | ||
} | ||
} | ||
result->NumElements(valid_cnt)->Owner(true, allocator_); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's best to directly include the distances in the dataset to avoid potential memory leaks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified
} | ||
} | ||
result->NumElements(valid_cnt)->Owner(true, allocator_); | ||
result->Distances(distances); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
distances[i] = -1; | ||
} else { | ||
InnerIdType internal_id = search->second; | ||
float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we have cal distance by id interface, why not use it
float
HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I need to calculate distances in batches. Calling getDistanceByLabel requires a shared lock for each LabelType, which may have a certain impact on performance in large batch scenarios.
std::unique_lock<std::mutex> lock_table(label_lookup_lock); | ||
int64_t valid_cnt = 0; | ||
auto result = vsag::Dataset::Make(); | ||
auto *distances = (float *)allocator_->Allocate(sizeof(float) * count); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
Signed-off-by: zourunxin.zrx <[email protected]>
Signed-off-by: wxy407827 <[email protected]>
Signed-off-by: wxy407827 <[email protected]>
Signed-off-by: wxy407827 <[email protected]>
Signed-off-by: wxy407827 <[email protected]>
Signed-off-by: jinjiabao.jjb <[email protected]>
Signed-off-by: zourunxin.zrx <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Forgot to add the signature, modified to: #337 |
Add a batch interface for getDistanceByLabel
param count is the count of vids
param vids is the unique identifier of the vector to be calculated in the index.
param vector is the embedding of query
param distances is the distances between the query and the vector of the given ID
return result is valid distance of input vids.
virtual tl::expected<int64_t, Error>CalcBatchDistanceById(int64_t count, const int64_t vids, const float vector, float *&distances)