Add a batch interface for getDistanceByLabel #337

Open
wants to merge 1 commit into
base: main

Carrot-77
Add a batch interface for getDistanceByLabel

  • param count: the number of vids
  • param vids: the unique identifiers of the vectors in the index whose distances are to be calculated
  • param vector: the embedding of the query
  • param distances: the distances between the query and the vectors with the given IDs
  • return: the result is the valid distances of the input vids

virtual tl::expected<int64_t, Error> CalcBatchDistanceById(int64_t count, const int64_t *vids, const float *vector, float *&distances)
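
The semantics described above can be sketched with a toy index. This is only an illustration under stated assumptions, not the implementation in this PR: `ToyIndex`, `SquaredL2`, and `ToyCalcBatchDistanceById` are hypothetical names, the metric is assumed to be squared L2, and the result is returned as a plain `std::vector<float>` rather than through the library's own types.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for an index: maps an external id to its stored vector.
using ToyIndex = std::unordered_map<int64_t, std::vector<float>>;

// Squared L2 distance between two vectors of length dim (assumed metric).
inline float
SquaredL2(const float* a, const float* b, int64_t dim) {
    float sum = 0.0F;
    for (int64_t i = 0; i < dim; ++i) {
        const float diff = a[i] - b[i];
        sum += diff * diff;
    }
    return sum;
}

// Computes one distance per input id; distances[i] belongs to ids[i].
// Ids that are missing from the index keep the sentinel value -1.
inline std::vector<float>
ToyCalcBatchDistanceById(const ToyIndex& index,
                         const float* query,
                         int64_t dim,
                         const int64_t* ids,
                         int64_t count) {
    std::vector<float> distances(static_cast<std::size_t>(count), -1.0F);
    for (int64_t i = 0; i < count; ++i) {
        const auto it = index.find(ids[i]);
        if (it == index.end()) {
            continue;  // id not found: leave -1 in place
        }
        assert(static_cast<int64_t>(it->second.size()) == dim);
        distances[static_cast<std::size_t>(i)] = SquaredL2(query, it->second.data(), dim);
    }
    return distances;
}
```

Calling it with ids `{2, 99, 1}` against an index that holds only ids 1 and 2 yields three entries, with `-1` in the middle slot for the unknown id.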

@Carrot-77
Author

source: #309

codecov bot commented Jan 16, 2025

Codecov Report

Attention: Patch coverage is 50.00000% with 21 lines in your changes missing coverage. Please review.

@@            Coverage Diff             @@
##             main     #337      +/-   ##
==========================================
- Coverage   90.94%   87.49%   -3.45%     
==========================================
  Files         133      134       +1     
  Lines        8459     8639     +180     
==========================================
- Hits         7693     7559     -134     
- Misses        766     1080     +314     
Flag Coverage Δ
cpp 87.49% <50.00%> (-3.45%) ⬇️

Components Coverage Δ
common 88.57% <ø> (-7.18%) ⬇️
datacell 90.18% <ø> (-1.60%) ⬇️
index 90.72% <52.50%> (-0.29%) ⬇️
simd 69.28% <ø> (-12.35%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 2b9ddac...42f054e.

* @return result is valid distance of input vids. '-1' indicates an invalid distance.
*/
virtual tl::expected<DatasetPtr, Error>
CalcBatchDistanceById(int64_t count,
Collaborator

Recommend maintaining consistency with the param style of the CalcDistanceById function

CalcBatchDistanceById(const float* vector, const int64_t *ids, int64_t len) const

Author

Thoughts, edited

@@ -66,6 +69,11 @@ class AlgorithmInterface {
virtual float
getDistanceByLabel(LabelType label, const void* data_point) = 0;

virtual tl::expected<vsag::DatasetPtr, vsag::Error>
getBatchDistanceByLabel(int64_t count,
Collaborator

ditto

Author

edited

valid_cnt++;
}
}
result->NumElements(valid_cnt);
Collaborator

NumElements(count) ?

Author

Although count ids are passed in, not every id is found in the index. valid_cnt is intended to express the number of ids actually found.

Collaborator

I think NumElements should indicate how many records are in the result DataSet, while the validity of each record is flagged by distances[i] = -1. Otherwise, there would be no way to tell how many rows the result DataSet contains.
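
The convention argued for here can be sketched as follows: the row count always equals the number of input ids, so row i stays aligned with ids[i], and the caller derives the number of hits by scanning for the sentinel. `ToyBatchResult`, `CountValid`, and `FoundIds` are hypothetical names used only for illustration, and the sketch assumes a valid distance is never exactly -1.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result container, only loosely modelled on the discussion above.
struct ToyBatchResult {
    int64_t num_elements = 0;      // number of rows, equal to the number of input ids
    std::vector<float> distances;  // distances[i] pairs with ids[i]; -1 means "not found"
};

// Number of ids that were actually found (what valid_cnt tracked).
inline int64_t
CountValid(const ToyBatchResult& result) {
    assert(result.num_elements == static_cast<int64_t>(result.distances.size()));
    int64_t valid = 0;
    for (const float d : result.distances) {
        if (d != -1.0F) {
            ++valid;
        }
    }
    return valid;
}

// Because rows are positionally aligned, the found ids can be recovered
// from the original input array.
inline std::vector<int64_t>
FoundIds(const ToyBatchResult& result, const int64_t* ids) {
    std::vector<int64_t> found;
    for (int64_t i = 0; i < result.num_elements; ++i) {
        if (result.distances[static_cast<std::size_t>(i)] != -1.0F) {
            found.push_back(ids[i]);
        }
    }
    return found;
}
```

If the row count were set to the number of hits instead, the caller could no longer tell which input id each distance belongs to without extra bookkeeping.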

Author

edited

distances[i] = -1;
} else {
InnerIdType internal_id = search->second;
float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_);
Collaborator

Does data_point need to be normalized?

Author

edited
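
As a rough sketch of why the question matters: if the index is assumed to use a cosine-style metric and to store unit-length vectors (an assumption, the thread does not say which metric is in use), the query has to be normalized the same way before the distance function is applied. `NormalizeCopy` and `InnerProductDistance` are hypothetical helpers, not functions from this codebase.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Returns a unit-length copy of v; a zero vector is returned unchanged.
inline std::vector<float>
NormalizeCopy(const float* v, int64_t dim) {
    assert(dim > 0);
    float norm_sq = 0.0F;
    for (int64_t i = 0; i < dim; ++i) {
        norm_sq += v[i] * v[i];
    }
    std::vector<float> out(v, v + dim);
    if (norm_sq > 0.0F) {
        const float inv_norm = 1.0F / std::sqrt(norm_sq);
        for (float& x : out) {
            x *= inv_norm;
        }
    }
    return out;
}

// 1 - <a, b>; equals the cosine distance only when both inputs are unit length.
inline float
InnerProductDistance(const float* a, const float* b, int64_t dim) {
    float dot = 0.0F;
    for (int64_t i = 0; i < dim; ++i) {
        dot += a[i] * b[i];
    }
    return 1.0F - dot;
}
```

Under these assumptions, passing the raw query `{3, 4}` against the stored unit vector `{1, 0}` would give `1 - 3 = -2`, while the normalized query gives approximately `0.4`.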

src/index/hnsw.h Outdated
@@ -145,6 +145,13 @@ class HNSW : public Index {
SAFE_CALL(return alg_hnsw_->getDistanceByLabel(id, vector));
};

virtual tl::expected<DatasetPtr, Error>
CalcBatchDistanceById(int64_t count,
Collaborator

ditto

Author

edited

Collaborator

LHT129 left a comment

LGTM

inabao added the invalid and version/0.14 labels and removed the version/0.13 and invalid labels on Jan 22, 2025
5 participants