-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a batch interface for getDistanceByLabel #337
base: main
Are you sure you want to change the base?
Conversation
source: #309 |
Codecov ReportAttention: Patch coverage is @@ Coverage Diff @@
## main #337 +/- ##
==========================================
- Coverage 90.94% 87.49% -3.45%
==========================================
Files 133 134 +1
Lines 8459 8639 +180
==========================================
- Hits 7693 7559 -134
- Misses 766 1080 +314
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report in Codecov by Sentry.
|
include/vsag/index.h
Outdated
* @return result is valid distance of input vids. '-1' indicates an invalid distance. | ||
*/ | ||
virtual tl::expected<DatasetPtr, Error> | ||
CalcBatchDistanceById(int64_t count, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Recommend maintaining consistency with the param style of the CalcDistanceById
function
CalcBatchDistanceById(const float* vector, const int64_t *ids, int64_t len) const
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thoughts, edited
@@ -66,6 +69,11 @@ class AlgorithmInterface { | |||
virtual float | |||
getDistanceByLabel(LabelType label, const void* data_point) = 0; | |||
|
|||
virtual tl::expected<vsag::DatasetPtr, vsag::Error> | |||
getBatchDistanceByLabel(int64_t count, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
edited
src/algorithm/hnswlib/hnswalg.cpp
Outdated
valid_cnt++; | ||
} | ||
} | ||
result->NumElements(valid_cnt); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NumElements(count)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although the input parameter is count ids, not every id has found a result. valid_cnt is intended to express the number of ids actually found.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that NumElements
is used to indicate how many records will be in the result DataSet, and which data is valid is indicated by the flag distances[i] = -1
. Otherwise, there would be no way to specify how many rows of results are in the result DataSet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
edited
src/algorithm/hnswlib/hnswalg.cpp
Outdated
distances[i] = -1; | ||
} else { | ||
InnerIdType internal_id = search->second; | ||
float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
data_point
need normalize ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
edited
src/index/hnsw.h
Outdated
@@ -145,6 +145,13 @@ class HNSW : public Index { | |||
SAFE_CALL(return alg_hnsw_->getDistanceByLabel(id, vector)); | |||
}; | |||
|
|||
virtual tl::expected<DatasetPtr, Error> | |||
CalcBatchDistanceById(int64_t count, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
edited
959db4a
to
aaab6e1
Compare
aaab6e1
to
c715656
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Signed-off-by: zourunxin.zrx <[email protected]>
9fcab4f
to
42f054e
Compare
Add a batch interface for getDistanceByLabel
virtual tl::expected<int64_t, Error>CalcBatchDistanceById(int64_t count, const int64_t vids, const float vector, float *&distances)