-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a batch interface for getDistanceByLabel #309
Changes from 5 commits
0165bcf
ff71e34
d32e60a
b08807d
e8f3a92
b720efc
47a2fea
01d8d31
c0c2cbb
5610733
9059ef1
be4fbd2
f4a330a
a41a874
cd24b67
a519cd4
338d71b
9756f58
26b329e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -249,6 +249,23 @@ class Index { | |
throw std::runtime_error("Index doesn't support get distance by id"); | ||
}; | ||
|
||
/** | ||
* @brief Calculate the distances between the query and a batch of vectors identified by their IDs. | ||
* | ||
* @param count is the number of IDs in vids | ||
* @param vids is the array of unique identifiers of the vectors in the index to calculate distances for | ||
* @param vector is the embedding of the query | ||
* @param distances receives the distances between the query and each vector in vids; an entry of -1 indicates an invalid distance | ||
* @return result is the number of input vids for which a valid distance was computed. | ||
*/ | ||
virtual tl::expected<int64_t, Error> | ||
CalcBatchDistanceById(int64_t count, | ||
const int64_t *vids, | ||
const float* vector, | ||
float *&distances) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You can return a Dataset as the result set and let the Dataset automatically manage the allocated memory. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Modified |
||
throw std::runtime_error("Index doesn't support get distance by id"); | ||
}; | ||
|
||
/** | ||
* @brief Checks if the specified feature is supported by the index. | ||
* | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -171,6 +171,28 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) { | |
return dist; | ||
} | ||
|
||
int64_t | ||
HierarchicalNSW::getBatchDistanceByLabel(int64_t count, | ||
const int64_t *vids, | ||
const void* data_point, | ||
float *&distances) { | ||
std::shared_lock lock_table(label_lookup_lock_); | ||
int64_t ret_cnt = 0; | ||
distances = (float *)allocator_->Allocate(sizeof(float) * count); | ||
for (int i = 0; i < count; i++) { | ||
auto search = label_lookup_.find(vids[i]); | ||
if (search == label_lookup_.end()) { | ||
distances[i] = -1; | ||
} else { | ||
InnerIdType internal_id = search->second; | ||
float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We already have an interface for calculating distance by ID; why not use it?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I need to calculate distances in batches. Calling getDistanceByLabel requires acquiring a shared lock for each LabelType, which may noticeably affect performance in large-batch scenarios. |
||
distances[i] = dist; | ||
ret_cnt++; | ||
} | ||
} | ||
return ret_cnt; | ||
} | ||
|
||
bool | ||
HierarchicalNSW::isValidLabel(LabelType label) { | ||
std::shared_lock lock_table(label_lookup_lock_); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,6 +81,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> { | |
void* dist_func_param_{nullptr}; | ||
|
||
mutable std::mutex label_lookup_lock; // lock for label_lookup_ | ||
mutable std::shared_mutex shared_label_lookup_lock; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why add a new lock here? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Modified |
||
std::unordered_map<LabelType, tableint> label_lookup_; | ||
|
||
std::default_random_engine level_generator_; | ||
|
@@ -262,6 +263,28 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> { | |
return dist; | ||
} | ||
|
||
int64_t | ||
getBatchDistanceByLabel(int64_t count, | ||
const int64_t *vids, | ||
const void* data_point, | ||
float *&distances) override { | ||
std::shared_lock<std::shared_mutex> lock_table(shared_label_lookup_lock); | ||
int64_t ret_cnt = 0; | ||
distances = (float *)allocator_->Allocate(sizeof(float) * count); | ||
for (int i = 0; i < count; i++) { | ||
auto search = label_lookup_.find(vids[i]); | ||
if (search == label_lookup_.end()) { | ||
distances[i] = -1; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Explicitly specify in the interface that -1 indicates an invalid distance. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Already marked |
||
} else { | ||
InnerIdType internal_id = search->second; | ||
float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); | ||
distances[i] = dist; | ||
ret_cnt++; | ||
} | ||
} | ||
return ret_cnt; | ||
} | ||
|
||
bool | ||
isValidLabel(LabelType label) override { | ||
std::unique_lock<std::mutex> lock_table(label_lookup_lock); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no need to keep this interface anymore; it is just a special case of the batch interface.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This involves a lot of test files. Can we leave it unchanged for now?