Skip to content

Commit

Permalink
support basic searcher
Browse files Browse the repository at this point in the history
Signed-off-by: zhongxiaoyao.zxy <[email protected]>
  • Loading branch information
ShawnShawnYou committed Jan 20, 2025
1 parent e87defc commit 0d51c38
Show file tree
Hide file tree
Showing 11 changed files with 675 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/algorithm/hnswlib/hnswalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1548,4 +1548,10 @@ HierarchicalNSW::checkIntegrity() {
std::cout << "integrity ok, checked " << connections_checked << " connections\n";
}

template MaxHeap
HierarchicalNSW::searchBaseLayerST<false, false>(InnerIdType ep_id,
const void* data_point,
size_t ef,
vsag::BaseFilterFunctor* isIdAllowed) const;

} // namespace hnswlib
11 changes: 11 additions & 0 deletions src/algorithm/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,17 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
bool
isValidLabel(LabelType label) override;

size_t
getMaxDegree() {
return maxM0_;
};

linklistsizeint*
get_linklist0(InnerIdType internal_id) const {
// only for test now
return (linklistsizeint*)(data_level0_memory_->GetElementPtr(internal_id, offsetLevel0_));
}

inline LabelType
getExternalLabel(InnerIdType internal_id) const {
std::shared_lock lock(points_locks_[internal_id]);
Expand Down
13 changes: 13 additions & 0 deletions src/allocator_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,17 @@ class AllocatorWrapper {

Allocator* allocator_{};
};

template <typename T, typename U>
bool
operator==(const AllocatorWrapper<T>&, const AllocatorWrapper<U>&) noexcept {
return true;
}

template <typename T, typename U>
bool
operator!=(const AllocatorWrapper<T>& a, const AllocatorWrapper<U>& b) noexcept {
return !(a == b);
}

} // namespace vsag
55 changes: 55 additions & 0 deletions src/data_cell/adapter_graph_datacell.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "algorithm/hnswlib/hnswalg.h"
#include "algorithm/hnswlib/space_l2.h"

namespace vsag {
class AdaptGraphDataCell {
public:
AdaptGraphDataCell(std::shared_ptr<hnswlib::HierarchicalNSW> alg_hnsw) : alg_hnsw_(alg_hnsw){};

void
GetNeighbors(InnerIdType id, Vector<InnerIdType>& neighbor_ids) {
int* data = (int*)alg_hnsw_->get_linklist0(id);
uint32_t size = alg_hnsw_->getListCount((hnswlib::linklistsizeint*)data);
neighbor_ids.resize(size);
for (uint32_t i = 0; i < size; i++) {
neighbor_ids[i] = *(data + i + 1);
}
}

uint32_t
GetNeighborSize(InnerIdType id) {
int* data = (int*)alg_hnsw_->get_linklist0(id);
return alg_hnsw_->getListCount((hnswlib::linklistsizeint*)data);
}

void
Prefetch(InnerIdType id, InnerIdType neighbor_i) {
int* data = (int*)alg_hnsw_->get_linklist0(id);
_mm_prefetch(data + neighbor_i + 1, _MM_HINT_T0);
}

uint32_t
MaximumDegree() {
return alg_hnsw_->getMaxDegree();
}

private:
std::shared_ptr<hnswlib::HierarchicalNSW> alg_hnsw_;
};
} // namespace vsag
68 changes: 68 additions & 0 deletions src/data_cell/adapter_graph_datacell_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "adapter_graph_datacell.h"

#include "catch2/catch_template_test_macros.hpp"
#include "fixtures.h"
#include "fmt/format-inl.h"
#include "graph_interface_test.h"
#include "io/io_headers.h"
#include "safe_allocator.h"

using namespace vsag;

TEST_CASE("basic usage for graph data cell (adapter of hnsw)", "[ut][GraphDataCell]") {
uint32_t M = 32;
uint32_t data_size = 1000;
uint32_t ef_construction = 100;
uint64_t DEFAULT_MAX_ELEMENT = 1;
uint64_t dim = 960;
auto vectors = fixtures::generate_vectors(data_size, dim);
std::vector<int64_t> ids(data_size);
std::iota(ids.begin(), ids.end(), 0);

auto allocator = SafeAllocator::FactoryDefaultAllocator();
auto space = std::make_shared<hnswlib::L2Space>(dim);
auto io = std::make_shared<MemoryIO>(allocator.get());
auto alg_hnsw =
std::make_shared<hnswlib::HierarchicalNSW>(space.get(),
DEFAULT_MAX_ELEMENT,
allocator.get(),
M / 2,
ef_construction,
Options::Instance().block_size_limit());
alg_hnsw->init_memory_space();
for (int64_t i = 0; i < data_size; ++i) {
auto successful_insert =
alg_hnsw->addPoint((const void*)(vectors.data() + i * dim), ids[i]);
REQUIRE(successful_insert == true);
}

auto graph_data_cell = std::make_shared<AdaptGraphDataCell>(alg_hnsw);

for (uint32_t i = 0; i < data_size; i++) {
auto neighbor_size = graph_data_cell->GetNeighborSize(i);
Vector<InnerIdType> neighbor_ids(neighbor_size, allocator.get());
graph_data_cell->GetNeighbors(i, neighbor_ids);

int* data = (int*)alg_hnsw->get_linklist0(i);
REQUIRE(neighbor_size == alg_hnsw->getListCount((hnswlib::linklistsizeint*)data));

for (uint32_t j = 0; j < neighbor_size; j++) {
REQUIRE(neighbor_ids[j] == *(data + j + 1));
}
}
}
16 changes: 16 additions & 0 deletions src/data_cell/flatten_datacell.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ class FlattenDataCell : public FlattenInterface {
io_->Prefetch(id * code_size_);
};

bool
Decode(const uint8_t* codes, DataType* data) override {
return this->quantizer_->DecodeOne(codes, data);
}

[[nodiscard]] std::string
GetQuantizerName() override;

Expand Down Expand Up @@ -226,7 +231,18 @@ FlattenDataCell<QuantTmpl, IOTmpl>::query(float* result_dists,
const std::shared_ptr<Computer<QuantTmpl>>& computer,
const InnerIdType* idx,
InnerIdType id_count) {
for (uint32_t i = 0; i < this->prefetch_neighbor_codes_num_ and i < id_count; i++) {
this->io_->Prefetch(static_cast<uint64_t>(idx[i]) * static_cast<uint64_t>(code_size_),
this->prefetch_cache_line_);
}

for (int64_t i = 0; i < id_count; ++i) {
if (i + this->prefetch_neighbor_codes_num_ < id_count) {
this->io_->Prefetch(static_cast<uint64_t>(idx[i + this->prefetch_neighbor_codes_num_]) *
static_cast<uint64_t>(code_size_),
this->prefetch_cache_line_);
}

bool release = false;
const auto* codes = this->GetCodesById(idx[i], release);
computer->ComputeDist(codes, result_dists + i);
Expand Down
15 changes: 15 additions & 0 deletions src/data_cell/flatten_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
#include <nlohmann/json.hpp>
#include <string>

#include "impl/runtime_parameter.h"
#include "index/index_common_param.h"
#include "quantization/computer.h"
#include "stream_reader.h"
#include "stream_writer.h"
#include "typing.h"
#include "vsag/constants.h"

namespace vsag {
class FlattenInterface;
Expand Down Expand Up @@ -83,6 +85,11 @@ class FlattenInterface {
return false;
}

virtual bool
Decode(const uint8_t* codes, DataType* vector) {
return false;
}

[[nodiscard]] virtual InnerIdType
TotalCount() const {
return this->total_count_;
Expand All @@ -102,10 +109,18 @@ class FlattenInterface {
StreamReader::ReadObj(reader, this->code_size_);
}

virtual void
SetRuntimeParameters(const UnorderedMap<std::string, ParamValue>& new_params) {
// TODO(ZXY): implement
return;
}

public:
InnerIdType total_count_{0};
InnerIdType max_capacity_{1000000};
uint32_t code_size_{0};
uint32_t prefetch_neighbor_codes_num_{1};
uint32_t prefetch_cache_line_{1};
};

} // namespace vsag
Loading

0 comments on commit 0d51c38

Please sign in to comment.