Skip to content

Commit

Permalink
[pysvs] Allow 1-dimensional queries for search. (#17)
Browse files Browse the repository at this point in the history
Previously, it was assumed that all queries were given in a
two-dimensional numpy array. This relaxes that requirement to treat
one-dimensional arguments as a single query.
  • Loading branch information
Mark Hildebrand authored Dec 7, 2023
1 parent f010178 commit b7e5488
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 5 deletions.
38 changes: 38 additions & 0 deletions bindings/python/src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ template <typename T> std::span<const T> as_span(const py_contiguous_array_t<T>&
return std::span<const T>(array.data(), array.size());
}

/// Tag type used to select the ``data_view`` overload that also accepts numpy vectors.
struct AllowVectorsTag {};

/// A property to pass to ``data_view`` to interpret a numpy vector as a 2D array with
/// the size of the first dimension equal to one.
inline constexpr AllowVectorsTag allow_vectors{};

///
/// Create a read-only data view over a numpy array.
///
Expand All @@ -86,6 +92,38 @@ data_view(const pybind11::array_t<Eltype, pybind11::array::c_style>& data) {
);
}

///
/// Build a read-only data view that aliases a numpy vector or matrix.
///
/// A one-dimensional array is presented as a view whose first size argument is 1 and
/// whose second is the vector length; a two-dimensional array is presented using its
/// own two extents.
///
/// @tparam Eltype The element type of the array.
///
/// @param data The numpy array to alias. The view is constructed from the array's
///     data pointer.
/// @param property Tag opting in to the promotion of numpy vectors to matrices.
///
/// Throws (via ``ANNEXCEPTION``) if ``data`` is neither one- nor two-dimensional.
///
template <typename Eltype>
svs::data::ConstSimpleDataView<Eltype> data_view(
    const pybind11::array_t<Eltype, pybind11::array::c_style>& data,
    AllowVectorsTag SVS_UNUSED(property)
) {
    using view_type = svs::data::ConstSimpleDataView<Eltype>;
    const size_t rank = data.ndim();

    // Matrix: forward both extents unchanged.
    if (rank == 2) {
        const Eltype* first = data.template unchecked<2>().data(0, 0);
        return view_type(first, data.shape(0), data.shape(1));
    }

    // Anything that is not a vector at this point is unsupported.
    if (rank != 1) {
        throw ANNEXCEPTION("This function can only accept numpy vectors or matrices.");
    }

    // Vector: interpret it as a batch containing exactly one query.
    // The `pybind11::array::c_style` requirement means the underlying data is
    // contiguous, so a view can be constructed directly from its pointer.
    const Eltype* first = data.template unchecked<1>().data(0);
    return view_type(first, 1, data.shape(0));
}

///
/// Create a read-write MatrixView over a numpy array.
///
Expand Down
15 changes: 11 additions & 4 deletions bindings/python/src/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ pybind11::tuple py_search(
pybind11::array_t<QueryType, pybind11::array::c_style> queries,
size_t n_neighbors
) {
const size_t n_queries = queries.shape(0);
const auto query_data = data_view(queries);
const auto query_data = data_view(queries, allow_vectors);
size_t n_queries = query_data.size();
auto result_idx = numpy_matrix<size_t>(n_queries, n_neighbors);
auto result_dists = numpy_matrix<float>(n_queries, n_neighbors);
svs::QueryResultView<size_t> q_result(
Expand All @@ -54,8 +54,12 @@ void add_search_specialization(pybind11::class_<Manager>& py_manager) {
Perform a search to return the `n_neighbors` approximate nearest neighbors to the query.
Args:
queries: Numpy Matrix representing the query batch. Individual queries are assumed to
the rows of the matrix. Returned results will have a position-wise correspondence
queries: Numpy Vector or Matrix representing the queries.
If the argument is a vector, it will be treated as a single query.
If the argument is a matrix, individual queries are assumed to the rows of the
matrix. Returned results will have a position-wise correspondence
with the queries. That is, the `N`-th row of the returned IDs and distances will
correspond to the `N`-th row in the query matrix.
Expand All @@ -64,6 +68,9 @@ Perform a search to return the `n_neighbors` approximate nearest neighbors to th
Returns:
A tuple `(I, D)` where `I` contains the `n_neighbors` approximate (or exact) nearest
neighbors to the queries and `D` contains the approximate distances.
Note: This form is returned regardless of whether the given query was a vector or a
matrix.
)"
);
}
Expand Down
51 changes: 50 additions & 1 deletion bindings/python/tests/test_vamana.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import os
import warnings

import numpy as np

from tempfile import TemporaryDirectory

import pysvs
Expand Down Expand Up @@ -79,12 +81,50 @@ def _setup(self, loader: pysvs.VectorDataLoader):
}),
]

# Ensure that passing 1-dimensional queries works and produces the same results as
# query batches.
def _test_single_query(
    self,
    vamana: pysvs.Vamana,
    queries
):
    """Check that searching one query at a time matches a batched search.

    Args:
        vamana: The index under test. Only its ``search(queries, n_neighbors)``
            method is used.
        queries: Two-dimensional numpy array of queries, one query per row.

    The batched result is compared against the row-wise concatenation of the
    results obtained by passing each row as a 1-dimensional array. Finally, a
    3-dimensional input is expected to be rejected with an exception whose
    message mentions "only accept numpy vectors or matrices".
    """
    num_neighbors = 10
    I_full, D_full = vamana.search(queries, num_neighbors)

    I_single = []
    D_single = []
    for i in range(queries.shape[0]):
        query = queries[i, :]
        self.assertEqual(query.ndim, 1)
        I, D = vamana.search(query, num_neighbors)

        # A single query must still come back as a one-row matrix.
        self.assertEqual(I.ndim, 2)
        self.assertEqual(D.ndim, 2)
        self.assertEqual(I.shape, (1, num_neighbors))
        self.assertEqual(D.shape, (1, num_neighbors))

        I_single.append(I)
        D_single.append(D)

    I_single_concat = np.concatenate(I_single, axis = 0)
    D_single_concat = np.concatenate(D_single, axis = 0)
    self.assertTrue(np.array_equal(I_full, I_single_concat))
    self.assertTrue(np.array_equal(D_full, D_single_concat))

    # Throw an error on 3-dimensional inputs.
    queries_3d = queries[:, :, np.newaxis]
    with self.assertRaises(Exception) as context:
        vamana.search(queries_3d, num_neighbors)

    self.assertIn("only accept numpy vectors or matrices", str(context.exception))

def _test_basic_inner(
self,
vamana: pysvs.Vamana,
recall_dict,
num_threads: int,
skip_thread_test: bool = False,
test_single_query: bool = False,
):
# Make sure that the number of threads is propagated correctly.
self.assertEqual(vamana.num_threads, num_threads)
Expand Down Expand Up @@ -129,6 +169,9 @@ def _test_basic_inner(
if not DEBUG:
self.assertTrue(isapprox(recall, expected_recall, epsilon = 0.0005))

if test_single_query:
self._test_single_query(vamana, queries)

# Disable visited set.
self.visited_set_enabled = False

Expand Down Expand Up @@ -158,6 +201,7 @@ def _test_basic(self, loader, recall_dict):
self._test_basic_inner(vamana, recall_dict, num_threads)

# Test saving and reloading.
is_first = True
with TemporaryDirectory() as tempdir:
configdir = os.path.join(tempdir, "config")
graphdir = os.path.join(tempdir, "graph")
Expand All @@ -179,8 +223,13 @@ def _test_basic(self, loader, recall_dict):

reloaded.num_threads = num_threads
self._test_basic_inner(
reloaded, recall_dict, num_threads, skip_thread_test = True
reloaded,
recall_dict,
num_threads,
skip_thread_test = True,
test_single_query = is_first,
)
is_first = False

def test_basic(self):
# Load the index from files.
Expand Down

0 comments on commit b7e5488

Please sign in to comment.