Skip to content

Commit

Permalink
add bitmap filter
Browse files Browse the repository at this point in the history
Signed-off-by: zourunxin.zrx <[email protected]>
  • Loading branch information
zourunxin.zrx committed Jan 21, 2025
1 parent 892cbe0 commit ff95558
Show file tree
Hide file tree
Showing 17 changed files with 418 additions and 27 deletions.
46 changes: 45 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -186,4 +186,48 @@ data/
/python/pyvsag/_version.py
/coverage-report/
/coverage/
*.gcov
*.gcov

.ccls-cache/
CPackConfig.cmake
CPackSourceConfig.cmake
boost/
examples/cpp/Makefile
examples/cpp/todo_examples/Makefile
instructions_test_avx.cpp
instructions_test_avx2.cpp
instructions_test_avx512.cpp
instructions_test_sse.cpp
mockimpl/Makefile
openblas/
spdlog/
src/Makefile
src/io/Makefile
src/quantization/Makefile
src/simd/Makefile
tests/Makefile
examples/cpp/custom_logger
examples/cpp/custom_memory_allocator
examples/cpp/example_conjugate_graph
examples/cpp/example_diskann
examples/cpp/example_engine
examples/cpp/example_fresh_hnsw
examples/cpp/example_hnsw
examples/cpp/simple_hgraph_sq8
examples/cpp/simple_hnsw
examples/cpp/simple_pyramid
examples/cpp/todo_examples/101_index_hnsw
examples/cpp/todo_examples/102_index_diskann
examples/cpp/todo_examples/103_index_hgraph
examples/cpp/todo_examples/104_index_fresh_hnsw
examples/cpp/todo_examples/105_index_brute_force
examples/cpp/todo_examples/301_feature_filter
examples/cpp/todo_examples/302_feature_range_search
examples/cpp/todo_examples/304_feature_enhance_graph
examples/cpp/todo_examples/305_feature_remove
examples/cpp/vsag_ext_example
mockimpl/example_diskann_mockimpl
mockimpl/example_hnsw_mockimpl
mockimpl/tests_mockimpl
tests/functests
tests/unittests
5 changes: 3 additions & 2 deletions include/vsag/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,13 @@ class Index {
* - num_elements: 1
* - ids, distances: length is (num_elements * k)
*/

// Filtered k-nearest-neighbour search.
//
// query      - dataset holding the query vector(s).
// k          - number of neighbours requested per query.
// parameters - implementation-specific search parameters, passed as a string.
// filter     - predicate evaluated on an int64_t id; the HNSW base-layer
//              search treats ids for which it returns false as invalid results.
// totalValid - defaults to 0 so existing four-argument callers keep compiling.
//              NOTE(review): presumably the number of ids accepted by `filter`;
//              the HNSW base-layer search derives its pruning ratio for
//              filtered-out candidates from it (0 disables pruning there).
//              Confirm the intended contract with the author.
//
// NOTE(review): default arguments on virtual functions are bound statically,
// so every override must repeat the same default value (= 0).
virtual tl::expected<DatasetPtr, Error>
KnnSearch(const DatasetPtr& query,
          int64_t k,
          const std::string& parameters,
          const std::function<bool(int64_t)>& filter,
          const int64_t totalValid = 0) const = 0;
/**
* @brief Performing single range search on index
*
Expand Down
3 changes: 2 additions & 1 deletion mockimpl/vsag/simpleflat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ tl::expected<DatasetPtr, Error>
SimpleFlat::KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const std::function<bool(int64_t)>& filter) const {
const std::function<bool(int64_t)>& filter,
const int64_t totalValid) const {
int64_t dim = query->GetDim();
k = std::min(k, GetNumElements());
int64_t num_elements = query->GetNumElements();
Expand Down
3 changes: 2 additions & 1 deletion mockimpl/vsag/simpleflat.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class SimpleFlat : public Index {
KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const std::function<bool(int64_t)>& filter) const override;
const std::function<bool(int64_t)>& filter,
const int64_t totalValid = 0) const override;

tl::expected<DatasetPtr, Error>
RangeSearch(const DatasetPtr& query,
Expand Down
3 changes: 2 additions & 1 deletion src/algorithm/hnswlib/algorithm_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class AlgorithmInterface {
searchKnn(const void*,
size_t,
size_t,
vsag::BaseFilterFunctor* isIdAllowed = nullptr) const = 0;
vsag::BaseFilterFunctor* isIdAllowed = nullptr,
const int64_t totalValid = 0) const = 0;

virtual std::priority_queue<std::pair<dist_t, LabelType>>
searchRange(const void*,
Expand Down
9 changes: 3 additions & 6 deletions src/algorithm/hnswlib/block_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,9 @@ BlockManager::~BlockManager() {

// Returns a pointer to byte `offset` inside the element stored at `index`.
//
// Elements are size_data_per_element_ bytes each and are laid out across the
// entries of blocks_, each covering block_size_ bytes; the element's byte
// position is split into a block number and a position inside that block.
//
// Throws std::out_of_range when `index` is not below max_elements_, so an
// invalid id can never be turned into an out-of-bounds access on blocks_.
char*
BlockManager::GetElementPtr(size_t index, size_t offset) {
    if (index >= max_elements_) {
        throw std::out_of_range("Index is out of range:" + std::to_string(index));
    }

    // Compute the byte position once and derive both coordinates from it.
    const size_t byte_position = index * size_data_per_element_;
    const size_t block_index = byte_position / block_size_;
    const size_t offset_in_block = byte_position % block_size_;
    return blocks_[block_index] + offset_in_block + offset;
}

Expand Down
106 changes: 103 additions & 3 deletions src/algorithm/hnswlib/hnswalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
// limitations under the License.

#include "hnswalg.h"
#include "neighbor.h"

#include <memory>
namespace hnswlib {
Expand Down Expand Up @@ -371,6 +372,104 @@ HierarchicalNSW::searchBaseLayer(InnerIdType ep_id, const void* data_point, int
return top_candidates;
}

// Filter-aware best-first search on the base layer (level 0).
//
// Template parameters:
//   has_deletions   - when true, nodes marked deleted are treated as invalid.
//   collect_metrics - when true, metric_hops_ and metric_distance_computations_
//                     are updated for every expanded node.
//
// Parameters:
//   ep_id       - internal id of the node the search starts from.
//   data_point  - query vector handed to fstdistfunc_.
//   ef          - capacity given to the candidate set.
//   isIdAllowed - optional filter on external labels; nullptr accepts every id.
//   totalValid  - used together with max_elements_ to derive the pruning ratio
//                 kAlpha = 1 - totalValid / max_elements_. With totalValid == 0
//                 kAlpha is 1 and no invalid candidate is ever pruned.
//                 NOTE(review): presumably the number of ids accepted by the
//                 filter - confirm with the caller.
//   k           - maximum number of entries placed in the returned heap.
//
// Invalid candidates (deleted or rejected by the filter) are not dropped
// outright: each one adds kAlpha to an accumulator and is skipped only while
// the accumulator stays below 1, so a share of them is still inserted into the
// candidate set with status kInvalid.
//
// Returns a MaxHeap with at most min(k, retset.size()) entries, each carrying
// the candidate's own distance and its internal id.
template <bool has_deletions, bool collect_metrics>
MaxHeap
HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id,
                                   const void* data_point,
                                   size_t ef,
                                   vsag::BaseFilterFunctor* isIdAllowed,
                                   const int64_t totalValid,
                                   const size_t k) const {
    VisitedListPtr vl = visited_list_pool_->getFreeVisitedList();
    vl_type* visited_array = vl->mass;
    vl_type visited_array_tag = vl->curV;

    MaxHeap top_candidates(allocator_);  // result queue
    NeighborSetDoublePopList retset(ef);

    // True when a node must not be reported: marked deleted, or rejected by the filter.
    const auto is_filtered_out = [&](InnerIdType id) {
        return (has_deletions && isMarkedDeleted(id)) ||
               (isIdAllowed && !(*isIdAllowed)(getExternalLabel(id)));
    };

    const float ep_dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
    int ep_status = Neighbor::kValid;
    if (is_filtered_out(ep_id)) {
        ep_status = Neighbor::kInvalid;
    }
    retset.insert(Neighbor(ep_id, ep_dist, ep_status));
    visited_array[ep_id] = visited_array_tag;

    const int64_t total_cnt = max_elements_;
    const float kAlpha = 1.0 - ((float)totalValid / (float)total_cnt);
    float accumulative_alpha = 1;

    while (retset.has_next()) {
        auto [u, d, s] = retset.pop();  // id, distance, status

        InnerIdType current_node_id = u;
        auto link_data = getLinklistAtLevelWithLock(current_node_id, 0);
        int* data = (int*)link_data.get();
        const size_t size = getListCount((linklistsizeint*)data);
        if constexpr (collect_metrics) {
            metric_hops_++;
            metric_distance_computations_ += size;
        }

#ifdef USE_SSE
        // Warm the cache for the first neighbor. Skipped for an empty list so
        // that no slot beyond the stored neighbor count is interpreted as an id.
        if (size > 0) {
            _mm_prefetch((char*)(visited_array + *(data + 1)), _MM_HINT_T0);
            _mm_prefetch((char*)(visited_array + *(data + 1) + 64), _MM_HINT_T0);
            _mm_prefetch(data_level0_memory_->GetElementPtr((*(data + 1)), offset_data_),
                         _MM_HINT_T0);
            _mm_prefetch((char*)(data + 2), _MM_HINT_T0);
        }
#endif

        for (size_t j = 1; j <= size; j++) {
            const int candidate_id = *(data + j);
#ifdef USE_SSE
            // Prefetch the following neighbor. The slot index is clamped to
            // [1, size]; the previous `size - 2` form wrapped around for lists
            // shorter than two entries and read past the stored neighbors.
            const size_t next_slot = std::min(j + 1, size);
            _mm_prefetch((char*)(visited_array + *(data + next_slot)), _MM_HINT_T0);
            _mm_prefetch(data_level0_memory_->GetElementPtr((*(data + next_slot)), offset_data_),
                         _MM_HINT_T0);
#endif
            if (visited_array[candidate_id] == visited_array_tag) {
                continue;
            }
            visited_array[candidate_id] = visited_array_tag;

            int status = Neighbor::kValid;
            if (is_filtered_out(candidate_id)) {
                status = Neighbor::kInvalid;
                accumulative_alpha += kAlpha;
                if (accumulative_alpha < 1.0f) {
                    continue;  // pruned: skip the distance computation entirely
                }
                accumulative_alpha -= 1.0f;
            }

            char* candidate_data = (getDataByInternalId(candidate_id));
            const float candidate_dist = fstdistfunc_(data_point, candidate_data, dist_func_param_);
            Neighbor nn(candidate_id, candidate_dist, status);
            retset.insert(nn, nullptr);
            // todo: _mm_prefetch for accepted candidates (USE_SSE)
        }
    }
    // TODO: when fewer than k results were found, brute-force the remainder:
    // if (retset.size() < k) {
    //     return searchKnnBF(query_data, k, bitset);
    // }

    const size_t len = std::min(k, retset.size());
    for (size_t i = 0; i < len; ++i) {
        // Each result carries its own distance; the entry point's distance must
        // not be reused for every entry.
        [[maybe_unused]] const auto [result_id, result_dist, result_status] = retset[i];
        top_candidates.emplace(result_dist, (LabelType)result_id);
    }

    visited_list_pool_->releaseVisitedList(vl);
    return top_candidates;
}

template <bool has_deletions, bool collect_metrics>
MaxHeap
HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id,
Expand Down Expand Up @@ -1363,7 +1462,8 @@ std::priority_queue<std::pair<float, LabelType>>
HierarchicalNSW::searchKnn(const void* query_data,
size_t k,
uint64_t ef,
vsag::BaseFilterFunctor* isIdAllowed) const {
vsag::BaseFilterFunctor* isIdAllowed,
const int64_t totalValid) const {
std::shared_lock resize_lock(resize_mutex_);
std::priority_queue<std::pair<float, LabelType>> result;
if (cur_element_count_ == 0)
Expand Down Expand Up @@ -1410,10 +1510,10 @@ HierarchicalNSW::searchKnn(const void* query_data,

if (num_deleted_ == 0) {
top_candidates =
searchBaseLayerST<false, true>(currObj, query_data, std::max(ef, k), isIdAllowed);
searchBaseLayerST<false, true>(currObj, query_data, std::max(ef, k), isIdAllowed, totalValid, k);
} else {
top_candidates =
searchBaseLayerST<true, true>(currObj, query_data, std::max(ef, k), isIdAllowed);
searchBaseLayerST<true, true>(currObj, query_data, std::max(ef, k), isIdAllowed, totalValid, k);
}

while (top_candidates.size() > k) {
Expand Down
11 changes: 10 additions & 1 deletion src/algorithm/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ class HierarchicalNSW : public AlgorithmInterface<float> {

// Filter-aware base-layer search used by searchKnn.
//   ep_id       - internal id of the node the search starts from.
//   data_point  - query vector.
//   ef          - capacity given to the candidate set.
//   isIdAllowed - optional filter on external labels; may be nullptr.
//   totalValid  - combined with max_elements_ to derive the ratio at which
//                 filtered-out candidates are pruned.
//                 NOTE(review): presumably the number of ids accepted by the
//                 filter - confirm with the caller.
//   k           - upper bound on the number of entries in the returned heap.
// has_deletions makes nodes marked deleted count as invalid; collect_metrics
// enables the hop / distance-computation counters.
template <bool has_deletions, bool collect_metrics = false>
MaxHeap
searchBaseLayerST(InnerIdType ep_id,
const void* data_point,
size_t ef,
vsag::BaseFilterFunctor* isIdAllowed,
const int64_t totalValid,
const size_t k) const;
template <bool has_deletions, bool collect_metrics = false>
MaxHeap
searchBaseLayerST(InnerIdType ep_id,
const void* data_point,
float radius,
Expand Down Expand Up @@ -384,7 +392,8 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
searchKnn(const void* query_data,
size_t k,
uint64_t ef,
vsag::BaseFilterFunctor* isIdAllowed = nullptr) const override;
vsag::BaseFilterFunctor* isIdAllowed = nullptr,
const int64_t totalValid = 0) const override;

std::priority_queue<std::pair<float, LabelType>>
searchRange(const void* query_data,
Expand Down
3 changes: 2 additions & 1 deletion src/algorithm/hnswlib/hnswalg_static.h
Original file line number Diff line number Diff line change
Expand Up @@ -1425,7 +1425,8 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> {
searchKnn(const void* query_data,
size_t k,
uint64_t ef,
vsag::BaseFilterFunctor* isIdAllowed = nullptr) const override {
vsag::BaseFilterFunctor* isIdAllowed = nullptr,
const int64_t totalValid = 0) const override {
std::priority_queue<std::pair<float, LabelType>> result;
if (cur_element_count_ == 0)
return result;
Expand Down
Loading

0 comments on commit ff95558

Please sign in to comment.