-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a batch interface for getDistanceByLabel #309
Changes from all commits
0165bcf
ff71e34
d32e60a
b08807d
e8f3a92
b720efc
47a2fea
01d8d31
c0c2cbb
5610733
9059ef1
be4fbd2
f4a330a
a41a874
cd24b67
a519cd4
338d71b
9756f58
26b329e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -171,6 +171,31 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) { | |
return dist; | ||
} | ||
|
||
tl::expected<vsag::DatasetPtr, vsag::Error> | ||
HierarchicalNSW::getBatchDistanceByLabel(int64_t count, | ||
const int64_t *vids, | ||
const void* data_point) { | ||
std::shared_lock lock_table(label_lookup_lock_); | ||
int64_t valid_cnt = 0; | ||
auto result = vsag::Dataset::Make(); | ||
result->Owner(true, allocator_); | ||
auto *distances = (float *)allocator_->Allocate(sizeof(float) * count); | ||
result->Distances(distances); | ||
for (int i = 0; i < count; i++) { | ||
auto search = label_lookup_.find(vids[i]); | ||
if (search == label_lookup_.end()) { | ||
distances[i] = -1; | ||
} else { | ||
InnerIdType internal_id = search->second; | ||
float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we have cal distance by id interface, why not use it
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I need to calculate distances in batches. Calling getDistanceByLabel requires a shared lock for each LabelType, which may have a certain impact on performance in large batch scenarios. |
||
distances[i] = dist; | ||
valid_cnt++; | ||
} | ||
} | ||
result->NumElements(valid_cnt); | ||
return std::move(result); | ||
} | ||
|
||
bool | ||
HierarchicalNSW::isValidLabel(LabelType label) { | ||
std::shared_lock lock_table(label_lookup_lock_); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -262,6 +262,31 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> { | |
return dist; | ||
} | ||
|
||
tl::expected<vsag::DatasetPtr, vsag::Error> | ||
getBatchDistanceByLabel(int64_t count, | ||
const int64_t *vids, | ||
const void* data_point) override { | ||
std::unique_lock<std::mutex> lock_table(label_lookup_lock); | ||
int64_t valid_cnt = 0; | ||
auto result = vsag::Dataset::Make(); | ||
result->Owner(true, allocator_); | ||
auto *distances = (float *)allocator_->Allocate(sizeof(float) * count); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
result->Distances(distances); | ||
for (int i = 0; i < count; i++) { | ||
auto search = label_lookup_.find(vids[i]); | ||
if (search == label_lookup_.end()) { | ||
distances[i] = -1; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. explicitly specify at the interface that -1 indicates an invalid distance There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Already marked |
||
} else { | ||
InnerIdType internal_id = search->second; | ||
float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); | ||
distances[i] = dist; | ||
valid_cnt++; | ||
} | ||
} | ||
result->NumElements(valid_cnt); | ||
return std::move(result); | ||
} | ||
|
||
bool | ||
isValidLabel(LabelType label) override { | ||
std::unique_lock<std::mutex> lock_table(label_lookup_lock); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is no need to keep this interface anymore; it is just a special form of batch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This involves a lot of test files. Can we leave it unchanged for now?