Skip to content

Commit

Permalink
Expose Vamana build parameters for static and dynamic indices (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
ibhati authored Dec 19, 2024
1 parent 5250fbd commit f4f7bc1
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 36 deletions.
5 changes: 5 additions & 0 deletions bindings/python/src/vamana.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,11 @@ void wrap(py::module& m) {
.def_readwrite(
"max_candidate_pool_size",
&svs::index::vamana::VamanaBuildParameters::max_candidate_pool_size
)
.def_readwrite("prune_to", &svs::index::vamana::VamanaBuildParameters::prune_to)
.def_readwrite(
"use_full_search_history",
&svs::index::vamana::VamanaBuildParameters::use_full_search_history
);

///
Expand Down
42 changes: 19 additions & 23 deletions include/svs/index/vamana/dynamic_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,42 +293,38 @@ class MutableVamanaIndex {

///// Accessors

///
/// @brief Get the alpha value used for pruning while mutating the graph.
///
/// @see set_alpha, get_construction_window_size, set_construction_window_size
///
float get_alpha() const { return alpha_; }

///
/// @brief Set the alpha value used for pruning while mutating the graph.
///
/// @see get_alpha, get_construction_window_size, set_construction_window_size
///
void set_alpha(float alpha) { alpha_ = alpha; }

/// @brief Get the ``graph_max_degree`` used while mutating the graph.
size_t get_graph_max_degree() const { return graph_.max_degree(); }

/// @brief Get the max candidate pool size used while mutating the graph.
size_t get_max_candidates() const { return max_candidates_; }
/// @brief Set the max candidate pool size to be used while mutating the graph.
void set_max_candidates(size_t max_candidates) { max_candidates_ = max_candidates; }
/// @brief Get the prune_to value used while mutating the graph.
size_t get_prune_to() const { return prune_to_; }
/// @brief Set the prune_to value to be used while mutating the graph.
void set_prune_to(size_t prune_to) { prune_to_ = prune_to; }

void set_full_search_history(bool enable) { use_full_search_history_ = enable; }
bool get_full_search_history() const { return use_full_search_history_; }

///
/// @brief Get the window size used the mutating the graph.
///
/// @see set_construction_window_size, get_alpha, set_alpha
///
/// @brief Get the window size used while mutating the graph.
size_t get_construction_window_size() const { return construction_window_size_; }

///
/// @brief Set the window size used the mutating the graph.
///
/// @see get_construction_window_size, get_alpha, set_alpha
///
/// @brief Set the window size to be used while mutating the graph.
void set_construction_window_size(size_t window_size) {
construction_window_size_ = window_size;
}

/// @brief Return whether the full search history is being used while mutating
/// the graph.
bool get_full_search_history() const { return use_full_search_history_; }
/// @brief Enable using the full search history for candidate generation while
/// mutating the graph.
void set_full_search_history(bool enable) { use_full_search_history_ = enable; }


///// Index translation.

///
Expand Down
30 changes: 19 additions & 11 deletions include/svs/index/vamana/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -718,35 +718,43 @@ class VamanaIndex {

///// Parameter manipulation.

/// @brief Return the value of ``alpha`` used during graph construction.
/// @brief Get the value of ``alpha`` used during graph construction.
float get_alpha() const { return build_parameters_.alpha; }
/// @brief Set the value of ``alpha`` to be used during graph construction.
void set_alpha(float alpha) { build_parameters_.alpha = alpha; }

/// @brief Return the max candidate pool size that was used for graph construction.
/// @brief Get the ``graph_max_degree`` that was used for graph construction.
size_t get_graph_max_degree() const { return graph_.max_degree(); }

/// @brief Get the max candidate pool size that was used for graph construction.
size_t get_max_candidates() const { return build_parameters_.max_candidate_pool_size; }
/// @brief Set the max candidate pool size to be used for graph construction.
void set_max_candidates(size_t max_candidates) {
build_parameters_.max_candidate_pool_size = max_candidates;
}

/// @brief Return the search window size that was used for graph construction.
/// @brief Get the prune_to value that was used for graph construction.
size_t get_prune_to() const { return build_parameters_.prune_to; }
/// @brief Set the prune_to value to be used for graph construction.
void set_prune_to(size_t prune_to) { build_parameters_.prune_to = prune_to; }

/// @brief Get the search window size that was used for graph construction.
size_t get_construction_window_size() const { return build_parameters_.window_size; }
/// @brief Set the search window size to be used for graph construction.
void set_construction_window_size(size_t construction_window_size) {
build_parameters_.window_size = construction_window_size;
}

///
/// @brief Enable using the full search history for candidate generation while
/// building.
///
void set_full_search_history(bool enable) {
build_parameters_.use_full_search_history = enable;
}

/// @brief Return whether the full search history is being used for index
/// construction.
bool get_full_search_history() const {
return build_parameters_.use_full_search_history;
}
/// @brief Enable using the full search history for candidate generation while
/// building.
void set_full_search_history(bool enable) {
build_parameters_.use_full_search_history = enable;
}

///// Saving

Expand Down
20 changes: 20 additions & 0 deletions include/svs/orchestrators/dynamic_vamana.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,35 @@ class DynamicVamana : public manager::IndexManager<DynamicVamanaInterface> {
}

// Accessors
/// @copydoc svs::index::vamana::MutableVamanaIndex::get_alpha
float get_alpha() const { return impl_->get_alpha(); }
void set_alpha(size_t alpha) { impl_->set_alpha(alpha); }

/// @copydoc svs::index::vamana::MutableVamanaIndex::get_graph_max_degree
size_t get_graph_max_degree() const { return impl_->get_graph_max_degree(); }

/// @copydoc svs::index::vamana::MutableVamanaIndex::set_construction_window_size
size_t get_construction_window_size() const {
return impl_->get_construction_window_size();
}
void set_construction_window_size(size_t window_size) {
impl_->set_construction_window_size(window_size);
}

/// @copydoc svs::index::vamana::MutableVamanaIndex::get_max_candidates
size_t get_max_candidates() const { return impl_->get_max_candidates(); }
void set_max_candidates(size_t max_candidates) {
impl_->set_max_candidates(max_candidates);
}

/// @copydoc svs::index::vamana::MutableVamanaIndex::get_prune_to
size_t get_prune_to() const { return impl_->get_prune_to(); }
void set_prune_to(size_t prune_to) { impl_->set_prune_to(prune_to); }

/// @copydoc svs::index::vamana::MutableVamanaIndex::get_full_search_history
bool get_full_search_history() const { return impl_->get_full_search_history(); }
void set_full_search_history(bool enable) { impl_->set_full_search_history(enable); }

// Backend String
std::string experimental_backend_string() const {
return impl_->experimental_backend_string();
Expand Down
33 changes: 32 additions & 1 deletion include/svs/orchestrators/vamana.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,20 @@ class VamanaInterface {
virtual void set_alpha(float alpha) = 0;
virtual float get_alpha() const = 0;

virtual size_t get_graph_max_degree() const = 0;

virtual void set_construction_window_size(size_t window_size) = 0;
virtual size_t get_construction_window_size() const = 0;

virtual void set_max_candidates(size_t max_candidates) = 0;
virtual size_t get_max_candidates() const = 0;

virtual void set_prune_to(size_t prune_to) = 0;
virtual size_t get_prune_to() const = 0;

virtual void set_full_search_history(bool enable) = 0;
virtual bool get_full_search_history() const = 0;

///// Backend Information Interface
virtual std::string experimental_backend_string() const = 0;

Expand Down Expand Up @@ -110,6 +118,8 @@ class VamanaImpl : public manager::ManagerImpl<QueryTypes, Impl, IFace> {
void set_alpha(float alpha) override { impl().set_alpha(alpha); }
float get_alpha() const override { return impl().get_alpha(); }

size_t get_graph_max_degree() const override { return impl().get_graph_max_degree(); }

void set_construction_window_size(size_t window_size) override {
impl().set_construction_window_size(window_size);
}
Expand All @@ -122,6 +132,16 @@ class VamanaImpl : public manager::ManagerImpl<QueryTypes, Impl, IFace> {
}
size_t get_max_candidates() const override { return impl().get_max_candidates(); }

void set_prune_to(size_t prune_to) override { impl().set_prune_to(prune_to); }
size_t get_prune_to() const override { return impl().get_prune_to(); }

void set_full_search_history(bool enable) override {
impl().set_full_search_history(enable);
}
bool get_full_search_history() const override {
return impl().get_full_search_history();
}

///// Backend Information Interface
std::string experimental_backend_string() const override {
return std::string{typename_impl.begin(), typename_impl.end() - 1};
Expand Down Expand Up @@ -276,7 +296,10 @@ class Vamana : public manager::IndexManager<VamanaInterface> {
float get_alpha() const { return impl_->get_alpha(); }
void set_alpha(float alpha) { impl_->set_alpha(alpha); }

/// @copydoc svs::index::vamana::VamanaIndex::set_alpha
/// @copydoc svs::index::vamana::VamanaIndex::get_graph_max_degree
size_t get_graph_max_degree() const { return impl_->get_graph_max_degree(); }

/// @copydoc svs::index::vamana::VamanaIndex::set_construction_window_size
size_t get_construction_window_size() const {
return impl_->get_construction_window_size();
}
Expand All @@ -290,6 +313,14 @@ class Vamana : public manager::IndexManager<VamanaInterface> {
impl_->set_max_candidates(max_candidates);
}

/// @copydoc svs::index::vamana::VamanaIndex::get_prune_to
size_t get_prune_to() const { return impl_->get_prune_to(); }
void set_prune_to(size_t prune_to) { impl_->set_prune_to(prune_to); }

/// @copydoc svs::index::vamana::VamanaIndex::get_full_search_history
bool get_full_search_history() const { return impl_->get_full_search_history(); }
void set_full_search_history(bool enable) { impl_->set_full_search_history(enable); }

bool visited_set_enabled() const {
return get_search_parameters().search_buffer_visited_set_;
}
Expand Down
7 changes: 7 additions & 0 deletions tests/integration/vamana/index_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") {
index.set_construction_window_size(456);
index.set_max_candidates(1001);

auto max_degree = index.get_graph_max_degree();
index.set_prune_to(max_degree - 2);
index.set_full_search_history(false);

auto config_dir = temp_dir / "config";
auto graph_dir = temp_dir / "graph";
auto data_dir = temp_dir / "data";
Expand All @@ -299,6 +303,9 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") {
CATCH_REQUIRE(index.get_alpha() == 1.2f);
CATCH_REQUIRE(index.get_construction_window_size() == 456);
CATCH_REQUIRE(index.get_max_candidates() == 1001);
CATCH_REQUIRE(index.get_graph_max_degree() == max_degree);
CATCH_REQUIRE(index.get_prune_to() == max_degree - 2);
CATCH_REQUIRE(index.get_full_search_history() == false);

index.set_num_threads(2);
run_tests(index, queries, groundtruth, expected_results.config_and_recall_);
Expand Down
19 changes: 18 additions & 1 deletion tests/svs/index/vamana/dynamic_index_2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,20 +305,34 @@ CATCH_TEST_CASE("Testing Graph Index", "[graph_index][dynamic_index]") {
double build_time = svs::lib::time_difference(tic);
index.debug_check_invariants(false);

// Verify that we can get and set alpha and the construction window size.
// Verify that we can get and set build parameters.
CATCH_REQUIRE(index.get_alpha() == alpha);
index.set_alpha(1.0);
CATCH_REQUIRE(index.get_alpha() == 1.0);
index.set_alpha(alpha);
CATCH_REQUIRE(index.get_alpha() == alpha);

CATCH_REQUIRE(index.get_graph_max_degree() == max_degree);

const size_t expected_construction_window = 2 * max_degree;
CATCH_REQUIRE(index.get_construction_window_size() == expected_construction_window);
index.set_construction_window_size(10);
CATCH_REQUIRE(index.get_construction_window_size() == 10);
index.set_construction_window_size(expected_construction_window);
CATCH_REQUIRE(index.get_construction_window_size() == expected_construction_window);

CATCH_REQUIRE(index.get_max_candidates() == 1000);
index.set_max_candidates(750);
CATCH_REQUIRE(index.get_max_candidates() == 750);

CATCH_REQUIRE(index.get_prune_to() == max_degree - 4);
index.set_prune_to(max_degree - 2);
CATCH_REQUIRE(index.get_prune_to() == max_degree - 2);

CATCH_REQUIRE(index.get_full_search_history() == true);
index.set_full_search_history(false);
CATCH_REQUIRE(index.get_full_search_history() == false);

reference.configure_extra_checks(true);
CATCH_REQUIRE(reference.extra_checks_enabled());

Expand Down Expand Up @@ -348,10 +362,13 @@ CATCH_TEST_CASE("Testing Graph Index", "[graph_index][dynamic_index]") {

// Make sure parameters were saved across the saving.
CATCH_REQUIRE(index.get_alpha() == reloaded.get_alpha());
CATCH_REQUIRE(index.get_graph_max_degree() == reloaded.get_graph_max_degree());
CATCH_REQUIRE(index.get_max_candidates() == reloaded.get_max_candidates());
CATCH_REQUIRE(
index.get_construction_window_size() == reloaded.get_construction_window_size()
);
CATCH_REQUIRE(index.get_prune_to() == reloaded.get_prune_to());
CATCH_REQUIRE(index.get_full_search_history() == reloaded.get_full_search_history());
CATCH_REQUIRE(index.size() == reloaded.size());
// ID's preserved across runs.
index.on_ids([&](size_t e) { CATCH_REQUIRE(reloaded.has_id(e)); });
Expand Down

0 comments on commit f4f7bc1

Please sign in to comment.