Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Support vector deletion in ANN IVF #1831

Merged
merged 12 commits into from
Nov 6, 2023
4 changes: 3 additions & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ INSTALL_TARGET=install
BUILD_REPORT_METRICS=""
BUILD_REPORT_INCL_CACHE_STATS=OFF

TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"

CACHE_ARGS=""
Expand Down Expand Up @@ -324,6 +324,8 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then
$CMAKE_TARGET == *"DISTANCE_TEST"* || \
$CMAKE_TARGET == *"MATRIX_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_TEST"* || \
$CMAKE_TARGET == *"SPARSE_DIST_TEST" || \
$CMAKE_TARGET == *"SPARSE_NEIGHBORS_TEST"* || \
Expand Down
2 changes: 2 additions & 0 deletions cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,12 @@ if(BUILD_PRIMS_BENCH)
bench/prims/neighbors/knn/brute_force_float_int64_t.cu
bench/prims/neighbors/knn/brute_force_float_uint32_t.cu
bench/prims/neighbors/knn/cagra_float_uint32_t.cu
bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu
bench/prims/neighbors/knn/ivf_flat_float_int64_t.cu
bench/prims/neighbors/knn/ivf_flat_int8_t_int64_t.cu
bench/prims/neighbors/knn/ivf_flat_uint8_t_int64_t.cu
bench/prims/neighbors/knn/ivf_pq_float_int64_t.cu
bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu
bench/prims/neighbors/knn/ivf_pq_int8_t_int64_t.cu
bench/prims/neighbors/knn/ivf_pq_uint8_t_int64_t.cu
bench/prims/neighbors/refine_float_int64_t.cu
Expand Down
119 changes: 116 additions & 3 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,21 @@

#include <raft/random/rng.cuh>

#include <raft/core/bitset.cuh>
#include <raft/neighbors/ivf_flat.cuh>
#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/sample_filter.cuh>
#include <raft/spatial/knn/knn.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <rmm/mr/host/new_delete_resource.hpp>
#include <rmm/mr/host/pinned_memory_resource.hpp>

#include <thrust/sequence.h>

#include <optional>

namespace raft::bench::spatial {
Expand All @@ -44,11 +49,14 @@ struct params {
size_t n_queries;
/** Number of nearest neighbours to find for every probe. */
size_t k;
/** Ratio of removed indices. */
double removed_ratio;
};

inline auto operator<<(std::ostream& os, const params& p) -> std::ostream&
{
os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k;
os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k << "#"
<< p.removed_ratio;
return os;
}

Expand Down Expand Up @@ -221,6 +229,104 @@ struct brute_force_knn {
}
};

/**
 * Benchmark implementation: IVF-Flat search with an optional bitset-based sample filter.
 *
 * The constructor builds an IVF-Flat index over `data` and passes the row ids
 * `[0, removed_ratio * n_samples)` to the bitset's `set()`. `search` then runs the filtered
 * search when `removed_ratio > 0` and the plain search otherwise.
 *
 * @tparam ValT element type of the dataset and of the queries
 * @tparam IdxT type of the neighbor indices (also used as the extent type of the matrix views)
 */
template <typename ValT, typename IdxT>
struct ivf_flat_filter_knn {
  using dist_t = float;

  std::optional<const raft::neighbors::ivf_flat::index<ValT, IdxT>> index;
  raft::neighbors::ivf_flat::index_params index_params;
  raft::neighbors::ivf_flat::search_params search_params;
  // One bit per dataset row (constructed with ps.n_samples entries).
  // NOTE(review): presumably bits start as "keep" and set() marks rows as removed — confirm
  // against raft::core::bitset defaults.
  raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset_;
  params ps;

  /**
   * Build the index and mark the removed rows.
   *
   * @param handle raft resources used for the build, the allocation and the thrust call
   * @param ps     benchmark parameters (n_samples, n_dims, n_queries, k, removed_ratio)
   * @param data   pointer to the dataset, passed to build() as n_samples rows of n_dims values
   */
  ivf_flat_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data)
    // Initializers are listed in member declaration order (bitset first, then ps), which is the
    // order in which they actually run. Inside this list `ps` names the constructor parameter,
    // so the bitset never reads the not-yet-initialized member.
    : removed_indices_bitset_(handle, ps.n_samples), ps(ps)
  {
    index_params.n_lists = 4096;
    index_params.metric  = raft::distance::DistanceType::L2Expanded;
    index.emplace(raft::neighbors::ivf_flat::build(
      handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));

    // Number of rows to remove; the fractional part is truncated on purpose.
    const auto n_removed = static_cast<int64_t>(ps.removed_ratio * ps.n_samples);
    auto removed_indices = raft::make_device_vector<IdxT, int64_t>(handle, n_removed);
    // Fill with 0, 1, ..., n_removed - 1: the first n_removed rows are the removed ones.
    thrust::sequence(
      resource::get_thrust_policy(handle),
      thrust::device_pointer_cast(removed_indices.data_handle()),
      thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0)));
    removed_indices_bitset_.set(handle, removed_indices.view());
  }

  /**
   * Run one search over all queries.
   *
   * @param handle       raft resources used for the search
   * @param search_items pointer to the queries, viewed as an n_queries x n_dims device matrix
   * @param out_dists    output distances, viewed as an n_queries x k device matrix
   * @param out_idxs     output neighbor indices, viewed as an n_queries x k device matrix
   */
  void search(const raft::device_resources& handle,
              const ValT* search_items,
              dist_t* out_dists,
              IdxT* out_idxs)
  {
    search_params.n_probes = 20;
    auto queries_view =
      raft::make_device_matrix_view<const ValT, IdxT>(search_items, ps.n_queries, ps.n_dims);
    auto neighbors_view = raft::make_device_matrix_view<IdxT, IdxT>(out_idxs, ps.n_queries, ps.k);
    auto distance_view = raft::make_device_matrix_view<dist_t, IdxT>(out_dists, ps.n_queries, ps.k);

    if (ps.removed_ratio > 0) {
      // The filter is only needed on the filtered path, so it is created here.
      auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view());
      raft::neighbors::ivf_flat::search_with_filtering(
        handle, search_params, *index, queries_view, neighbors_view, distance_view, filter);
    } else {
      // Nothing was removed: use the plain search as the unfiltered baseline.
      raft::neighbors::ivf_flat::search(
        handle, search_params, *index, queries_view, neighbors_view, distance_view);
    }
  }
};

/**
 * Benchmark implementation: IVF-PQ search with an optional bitset-based sample filter.
 *
 * The constructor builds an IVF-PQ index over `data` and passes the row ids
 * `[0, removed_ratio * n_samples)` to the bitset's `set()`. `search` then runs the filtered
 * search when `removed_ratio > 0` and the plain search otherwise.
 *
 * @tparam ValT element type of the dataset and of the queries
 * @tparam IdxT type of the neighbor indices
 */
template <typename ValT, typename IdxT>
struct ivf_pq_filter_knn {
  using dist_t = float;

  std::optional<const raft::neighbors::ivf_pq::index<IdxT>> index;
  raft::neighbors::ivf_pq::index_params index_params;
  raft::neighbors::ivf_pq::search_params search_params;
  // One bit per dataset row (constructed with ps.n_samples entries).
  // NOTE(review): presumably bits start as "keep" and set() marks rows as removed — confirm
  // against raft::core::bitset defaults.
  raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset_;
  params ps;

  /**
   * Build the index and mark the removed rows.
   *
   * @param handle raft resources used for the build, the allocation and the thrust call
   * @param ps     benchmark parameters (n_samples, n_dims, n_queries, k, removed_ratio)
   * @param data   pointer to the dataset, viewed as an n_samples x n_dims device matrix
   */
  ivf_pq_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data)
    // Initializers are listed in member declaration order (bitset first, then ps), which is the
    // order in which they actually run. Inside this list `ps` names the constructor parameter,
    // so the bitset never reads the not-yet-initialized member.
    : removed_indices_bitset_(handle, ps.n_samples), ps(ps)
  {
    index_params.n_lists = 4096;
    index_params.metric  = raft::distance::DistanceType::L2Expanded;
    auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
    index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view));

    // Number of rows to remove; the fractional part is truncated on purpose.
    const auto n_removed = static_cast<int64_t>(ps.removed_ratio * ps.n_samples);
    auto removed_indices = raft::make_device_vector<IdxT, int64_t>(handle, n_removed);
    // Fill with 0, 1, ..., n_removed - 1: the first n_removed rows are the removed ones.
    thrust::sequence(
      resource::get_thrust_policy(handle),
      thrust::device_pointer_cast(removed_indices.data_handle()),
      thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0)));
    removed_indices_bitset_.set(handle, removed_indices.view());
  }

  /**
   * Run one search over all queries.
   *
   * @param handle       raft resources used for the search
   * @param search_items pointer to the queries, viewed as an n_queries x n_dims device matrix
   * @param out_dists    output distances, viewed as an n_queries x k device matrix
   * @param out_idxs     output neighbor indices, viewed as an n_queries x k device matrix
   */
  void search(const raft::device_resources& handle,
              const ValT* search_items,
              dist_t* out_dists,
              IdxT* out_idxs)
  {
    search_params.n_probes = 20;
    // NOTE(review): these views use uint32_t extents, unlike ivf_flat_filter_knn which uses
    // IdxT — presumably required by the ivf_pq search signatures; confirm.
    auto queries_view =
      raft::make_device_matrix_view<const ValT, uint32_t>(search_items, ps.n_queries, ps.n_dims);
    auto neighbors_view =
      raft::make_device_matrix_view<IdxT, uint32_t>(out_idxs, ps.n_queries, ps.k);
    auto distance_view =
      raft::make_device_matrix_view<dist_t, uint32_t>(out_dists, ps.n_queries, ps.k);

    if (ps.removed_ratio > 0) {
      // The filter is only needed on the filtered path, so it is created here.
      auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view());
      raft::neighbors::ivf_pq::search_with_filtering(
        handle, search_params, *index, queries_view, neighbors_view, distance_view, filter);
    } else {
      // Nothing was removed: use the plain search as the unfiltered baseline.
      raft::neighbors::ivf_pq::search(
        handle, search_params, *index, queries_view, neighbors_view, distance_view);
    }
  }
};

template <typename ValT, typename IdxT, typename ImplT>
struct knn : public fixture {
explicit knn(const params& p, const TransferStrategy& strategy, const Scope& scope)
Expand Down Expand Up @@ -378,8 +484,15 @@ struct knn : public fixture {
};

inline const std::vector<params> kInputs{
{2000000, 128, 1000, 32}, {10000000, 128, 1000, 32}, {10000, 8192, 1000, 32}};

{2000000, 128, 1000, 32, 0}, {10000000, 128, 1000, 32, 0}, {10000, 8192, 1000, 32, 0}};

// Parameter sets for the filtered-search benchmarks: every combination of the listed values,
// i.e. one fixed problem size with a sweep over removed_ratio (0.0 is the unfiltered baseline).
// Declared `inline` like the other header-scope benchmark tables so that all translation units
// including this header share a single definition.
inline const std::vector<params> kInputsFilter =
  raft::util::itertools::product<params>({size_t(10000000)},                         // n_samples
                                         {size_t(128)},                              // n_dim
                                         {size_t(1000)},                             // n_queries
                                         {size_t(255)},                              // k
                                         {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64}   // removed_ratio
  );
inline const std::vector<TransferStrategy> kAllStrategies{
TransferStrategy::NO_COPY, TransferStrategy::MAP_PINNED, TransferStrategy::MANAGED};
inline const std::vector<TransferStrategy> kNoCopyOnly{TransferStrategy::NO_COPY};
Expand Down
24 changes: 24 additions & 0 deletions cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter
#include "../knn.cuh"

namespace raft::bench::spatial {

// Registration of the filtered IVF-Flat benchmark (ivf_flat_filter_knn) for float data and
// int64_t indices, using the kInputsFilter parameter sets and the kNoCopyOnly strategies.
// NOTE(review): KNN_REGISTER and kScopeFull are defined in ../knn.cuh outside this view —
// presumably the macro instantiates the knn fixture for these arguments; confirm there.
KNN_REGISTER(float, int64_t, ivf_flat_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull);

} // namespace raft::bench::spatial
24 changes: 24 additions & 0 deletions cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter
#include "../knn.cuh"

namespace raft::bench::spatial {

// Registration of the filtered IVF-PQ benchmark (ivf_pq_filter_knn) for float data and
// int64_t indices, using the kInputsFilter parameter sets and the kNoCopyOnly strategies.
// NOTE(review): KNN_REGISTER and kScopeFull are defined in ../knn.cuh outside this view —
// presumably the macro instantiates the knn fixture for these arguments; confirm there.
KNN_REGISTER(float, int64_t, ivf_pq_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull);

} // namespace raft::bench::spatial
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ inline auto build(raft::resources const& handle,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
"unsupported data type");
RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset");
RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists");

index<T, IdxT> index(handle, params, dim);
utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/distance/distance_types.hpp>
#include <raft/matrix/detail/select_warpsort.cuh>
#include <raft/neighbors/ivf_flat_types.hpp>
#include <raft/neighbors/sample_filter_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/cuda_rt_essentials.hpp> // RAFT_CUDA_TRY
#include <raft/util/device_loads_stores.cuh>
Expand Down Expand Up @@ -737,10 +738,11 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)

// This is the vector a given lane/thread handles
const uint32_t vec_id = group_id * WarpSize + lane_id;
const bool valid = vec_id < list_length;
const bool valid =
vec_id < list_length && sample_filter(queries_offset + blockIdx.y, list_id, vec_id);

// Process first shm_assisted_dim dimensions (always using shared memory)
if (valid && sample_filter(queries_offset + blockIdx.y, probe_id, vec_id)) {
if (valid) {
loadAndComputeDist<kUnroll, decltype(compute_dist), Veclen, T, AccT> lc(dist,
compute_dist);
for (int pos = 0; pos < shm_assisted_dim;
Expand Down Expand Up @@ -1096,22 +1098,25 @@ void ivfflat_interleaved_scan(const index<T, IdxT>& index,
rmm::cuda_stream_view stream)
{
const int capacity = bound_by_power_of_two(k);
select_interleaved_scan_kernel<T, AccT, IdxT, IvfSampleFilterT>::run(capacity,
index.veclen(),
select_min,
metric,
index,
queries,
coarse_query_results,
n_queries,
queries_offset,
n_probes,
k,
sample_filter,
neighbors,
distances,
grid_dim_x,
stream);

auto filter_adapter = raft::neighbors::filtering::ivf_to_sample_filter(
index.inds_ptrs().data_handle(), sample_filter);
select_interleaved_scan_kernel<T, AccT, IdxT, decltype(filter_adapter)>::run(capacity,
index.veclen(),
select_min,
metric,
index,
queries,
coarse_query_results,
n_queries,
queries_offset,
n_probes,
k,
filter_adapter,
neighbors,
distances,
grid_dim_x,
stream);
}

} // namespace raft::neighbors::ivf_flat::detail
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,7 @@ auto build(raft::resources const& handle,
"Unsupported data type");

RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset");
RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists");

auto stream = resource::get_cuda_stream(handle);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,25 +180,38 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
half,
raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>,
raft::neighbors::filtering::none_ivf_sample_filter);
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
half,
raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>,
raft::neighbors::filtering::none_ivf_sample_filter);
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
half, half, raft::neighbors::filtering::none_ivf_sample_filter);
half,
half,
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float, half, raft::neighbors::filtering::none_ivf_sample_filter);
float,
half,
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float, float, raft::neighbors::filtering::none_ivf_sample_filter);
float,
float,
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float,
raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>,
raft::neighbors::filtering::none_ivf_sample_filter);
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float,
raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>,
raft::neighbors::filtering::none_ivf_sample_filter);
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);

#undef COMMA

Expand Down
6 changes: 4 additions & 2 deletions cpp/include/raft/neighbors/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,9 @@ inline void search(raft::resources const& handle,
rmm::device_uvector<float> rot_queries(max_queries * index.rot_dim(), stream, mr);
rmm::device_uvector<uint32_t> clusters_to_probe(max_queries * n_probes, stream, mr);

auto search_instance = ivfpq_search<IdxT, IvfSampleFilterT>::fun(params, index.metric());
auto filter_adapter = raft::neighbors::filtering::ivf_to_sample_filter(
index.inds_ptrs().data_handle(), sample_filter);
auto search_instance = ivfpq_search<IdxT, decltype(filter_adapter)>::fun(params, index.metric());

for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) {
uint32_t queries_batch = min(max_queries, n_queries - offset_q);
Expand Down Expand Up @@ -850,7 +852,7 @@ inline void search(raft::resources const& handle,
distances + uint64_t(k) * (offset_q + offset_b),
utils::config<T>::kDivisor / utils::config<float>::kDivisor,
params.preferred_shmem_carveout,
sample_filter);
filter_adapter);
}
}
}
Expand Down
Loading