From 9661814f972e91c32f0568e511b89544923105bd Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 18 Sep 2023 19:55:35 +0200 Subject: [PATCH] Add `remove` function for IVF-Flat --- build.sh | 3 +- .../raft/neighbors/detail/ivf_flat_build.cuh | 15 ++++++ .../detail/ivf_flat_interleaved_scan-ext.cuh | 8 ++-- .../detail/ivf_flat_interleaved_scan-inl.cuh | 2 +- .../neighbors/detail/ivf_flat_search-ext.cuh | 6 +-- .../neighbors/detail/ivf_flat_search-inl.cuh | 2 +- .../neighbors/detail/ivf_flat_serialize.cuh | 8 ++-- .../raft/neighbors/detail/refine_device.cuh | 6 ++- cpp/include/raft/neighbors/ivf_flat-ext.cuh | 4 ++ cpp/include/raft/neighbors/ivf_flat-inl.cuh | 47 ++++++++++++++----- cpp/include/raft/neighbors/ivf_flat_types.hpp | 20 +++++++- .../raft/neighbors/sample_filter_types.hpp | 31 ++++++++++++ ...at_interleaved_scan_float_float_int64_t.cu | 2 +- ...interleaved_scan_int8_t_int32_t_int64_t.cu | 2 +- ...terleaved_scan_uint8_t_uint32_t_int64_t.cu | 2 +- cpp/src/neighbors/detail/ivf_flat_search.cu | 6 +-- 16 files changed, 131 insertions(+), 33 deletions(-) diff --git a/build.sh b/build.sh index 1fa1abbee5..7461b3ca27 100755 --- a/build.sh +++ b/build.sh @@ -78,7 +78,7 @@ INSTALL_TARGET=install BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF -TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" +TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH;UTIL_BENCH" CACHE_ARGS="" @@ -324,6 +324,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then $CMAKE_TARGET == *"DISTANCE_TEST"* || \ $CMAKE_TARGET == *"MATRIX_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \ + $CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_TEST"* || \ $CMAKE_TARGET == *"SPARSE_DIST_TEST" || \ $CMAKE_TARGET == *"SPARSE_NEIGHBORS_TEST"* || \ diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index 9cde1143e0..0b5e654920 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -76,6 +77,11 @@ auto clone(const raft::resources& res, const index& source) -> indexcenter_norms()->data_handle(), std::min(dim, 20)); } + index->resize_bitset(handle); } /** See raft::neighbors::ivf_flat::extend docs */ @@ -416,6 +423,7 @@ inline void fill_refinement_index(raft::resources const& handle, refinement_index->dim(), refinement_index->veclen()); RAFT_CUDA_TRY(cudaPeekAtLastError()); + refinement_index->resize_bitset(handle); } template @@ -490,4 +498,11 @@ void unpack_list_data( RAFT_CUDA_TRY(cudaPeekAtLastError()); } +template +void remove(raft::resources const& handle, + index& index, + raft::device_vector_view indices) +{ + raft::util::bitset_set(handle, index.deleted_indices(), indices); +} } // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh index 47f3e8888c..84231de3db 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -18,7 +18,7 @@ #include // uintX_t #include // raft::neighbors::ivf_flat::index -#include // none_ivf_sample_filter +#include // deletion_ivf_filter #include // RAFT_EXPLICIT #include // rmm:cuda_stream_view @@ -66,10 +66,10 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i rmm::cuda_stream_view stream) instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); + float, float, int64_t, raft::neighbors::filtering::deletion_ivf_filter); instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); + int8_t, int32_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter); instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); + uint8_t, uint32_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 81779668c4..d8e7cce914 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -740,7 +740,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const bool valid = vec_id < list_length; // Process first shm_assisted_dim dimensions (always using shared memory) - if (valid && sample_filter(queries_offset + blockIdx.y, probe_id, vec_id)) { + if (valid && sample_filter(queries_offset + blockIdx.y, list_id, vec_id)) { loadAndComputeDist lc(dist, compute_dist); for (int pos = 0; pos < shm_assisted_dim; diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh index 976d15a61c..80a22b846a 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -55,10 +55,10 @@ void search(raft::resources const& handle, IvfSampleFilterT sample_filter) instantiate_raft_neighbors_ivf_flat_detail_search( - float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); + float, int64_t, raft::neighbors::filtering::deletion_ivf_filter); instantiate_raft_neighbors_ivf_flat_detail_search( - int8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); + int8_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter); instantiate_raft_neighbors_ivf_flat_detail_search( - uint8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); + uint8_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 93eeb0dead..b46824407b 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -26,7 +26,7 @@ #include // matrix::detail::select_k #include // interleaved_scan #include // raft::neighbors::ivf_flat::index -#include // none_ivf_sample_filter +#include // deletion_ivf_filter #include // utils::mapping #include // rmm::device_memory_resource diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh index 61a6046273..7e2f424c37 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh @@ -34,7 +34,7 @@ namespace raft::neighbors::ivf_flat::detail { // backward compatibility. // TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward // compatible fashion. -constexpr int serialization_version = 4; +constexpr int serialization_version = 5; // NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error // message. @@ -45,7 +45,7 @@ struct check_index_layout { "paste in the new size and consider updating the serialization logic"); }; -template struct check_index_layout), 328>; +template struct check_index_layout), 384>; /** * Save the index to file. @@ -99,6 +99,7 @@ void serialize(raft::resources const& handle, std::ostream& os, const index::roundUp(sizes_host(label))); } + serialize_mdspan(handle, os, index_.deleted_indices().to_mdspan()); resource::sync_stream(handle); } @@ -165,7 +166,8 @@ auto deserialize(raft::resources const& handle, std::istream& is) -> index::value_t, @@ -99,7 +100,8 @@ void refine_device(raft::resources const& handle, 1, k, raft::distance::is_min_close(metric), - raft::neighbors::filtering::none_ivf_sample_filter(), + raft::neighbors::filtering::deletion_ivf_filter( + const_index.deleted_indices(), refinement_index.inds_ptrs().data_handle()), indices.data_handle(), distances.data_handle(), grid_dim_x, diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh index 848703c9b5..7fde0b2820 100644 --- a/cpp/include/raft/neighbors/ivf_flat-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-ext.cuh @@ -114,6 +114,10 @@ void search(raft::resources const& handle, raft::device_matrix_view neighbors, raft::device_matrix_view distances) RAFT_EXPLICIT; +template +void remove(raft::resources const& handle, + index& index, + raft::device_vector_view indices) RAFT_EXPLICIT; } // namespace raft::neighbors::ivf_flat #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index a18ee065bf..da0d04176c 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -407,6 +407,7 @@ void search_with_filtering(raft::resources const& handle, /** * @brief Search ANN using the constructed index using the given filter. * + * This will ignore the default filter used for indices deleted by the user. * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. * * Note, this function requires a temporary buffer to store intermediate results between cuda kernel @@ -457,16 +458,18 @@ void search(raft::resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - raft::neighbors::ivf_flat::detail::search(handle, - params, - index, - queries, - n_queries, - k, - neighbors, - distances, - mr, - raft::neighbors::filtering::none_ivf_sample_filter()); + raft::neighbors::ivf_flat::detail::search( + handle, + params, + index, + queries, + n_queries, + k, + neighbors, + distances, + mr, + raft::neighbors::filtering::deletion_ivf_filter(index.deleted_indices(), + index.inds_ptrs().data_handle())); } /** @@ -477,6 +480,7 @@ void search(raft::resources const& handle, /** * @brief Search ANN using the constructed index using the given filter. * + * This will ignore the default filter used for indices deleted by the user. * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. * * Note, this function requires a temporary buffer to store intermediate results between cuda kernel @@ -588,7 +592,28 @@ void search(raft::resources const& handle, queries, neighbors, distances, - raft::neighbors::filtering::none_ivf_sample_filter()); + raft::neighbors::filtering::deletion_ivf_filter( + index.deleted_indices(), index.inds_ptrs().data_handle())); +} + +/** + * @brief Remove the specified indices from the index. + * + * The specified indices to remove are marked as deleted in the index for the next searches, but + * the data is not removed. The index size is not changed. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * @param handle RAFT handle + * @param index ivf-flat constructed index + * @param indices list of indices to remove + */ +template +void remove(raft::resources const& handle, + index& index, + raft::device_vector_view indices) +{ + detail::remove(handle, index, indices); } /** @} */ diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index 180fe2e21b..67dcb6cb0d 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -257,7 +258,8 @@ struct index : ann::index { list_sizes_{make_device_vector(res, n_lists)}, data_ptrs_{make_device_vector(res, n_lists)}, inds_ptrs_{make_device_vector(res, n_lists)}, - total_size_{0} + total_size_{0}, + deleted_indices_(res, 0) { check_consistency(); } @@ -298,6 +300,21 @@ struct index : ann::index { return conservative_memory_allocation_; } + /** Bitset view of deleted indices [size] */ + [[nodiscard]] inline auto deleted_indices() noexcept + -> raft::util::bitset_view + { + return deleted_indices_.view(); + } + [[nodiscard]] inline auto deleted_indices() const noexcept + -> raft::util::bitset_view + { + return deleted_indices_.view(); + } + + /** Resize the bitset to the current size of the index. */ + void resize_bitset(raft::resources const& res) { deleted_indices_.resize(res, size()); } + /** * Update the state of the dependent index members. */ @@ -362,6 +379,7 @@ struct index : ann::index { device_vector list_sizes_; device_matrix centers_; std::optional> center_norms_; + raft::util::bitset deleted_indices_; // Computed members device_vector data_ptrs_; diff --git a/cpp/include/raft/neighbors/sample_filter_types.hpp b/cpp/include/raft/neighbors/sample_filter_types.hpp index 5a301e9d2f..d4f33eb077 100644 --- a/cpp/include/raft/neighbors/sample_filter_types.hpp +++ b/cpp/include/raft/neighbors/sample_filter_types.hpp @@ -20,6 +20,7 @@ #include #include +#include namespace raft::neighbors::filtering { @@ -37,6 +38,36 @@ struct none_ivf_sample_filter { } }; +/** + * @brief Filter an IVF index with a bitset + * + * @tparam index_t Indexing type + */ +template +struct deletion_ivf_filter { + // Pointers to the inverted lists (clusters) indices [n_lists] + const raft::util::bitset_view deleted_bitset_; + index_t* const* inds_ptr_ = nullptr; + + deletion_ivf_filter(const raft::util::bitset_view deleted_bitset, + index_t* const* inds_ptr) + : deleted_bitset_{deleted_bitset}, inds_ptr_{inds_ptr} + { + } + + inline _RAFT_HOST_DEVICE bool operator()( + // query index + const uint32_t query_ix, + // the current inverted list index + const uint32_t cluster_ix, + // the index of the current sample inside the current inverted list + const uint32_t sample_ix) const + { + auto sample_idx = inds_ptr_[cluster_ix][sample_ix]; + return deleted_bitset_.test(sample_idx); + } +}; + /** * If the filtering depends on the index of a sample, then the following * filter template can be used: diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index a1d6cca7d5..ead138010e 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -37,6 +37,6 @@ rmm::cuda_stream_view stream) instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); + float, float, int64_t, raft::neighbors::filtering::deletion_ivf_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index 514301562d..bd0a6bfbbd 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -37,6 +37,6 @@ rmm::cuda_stream_view stream) instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); + int8_t, int32_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index 32698a8e80..88d6a2f23d 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -37,6 +37,6 @@ rmm::cuda_stream_view stream) instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); + uint8_t, uint32_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_search.cu b/cpp/src/neighbors/detail/ivf_flat_search.cu index 9d39607750..c865b2a74f 100644 --- a/cpp/src/neighbors/detail/ivf_flat_search.cu +++ b/cpp/src/neighbors/detail/ivf_flat_search.cu @@ -31,10 +31,10 @@ IvfSampleFilterT sample_filter) instantiate_raft_neighbors_ivf_flat_detail_search( - float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); + float, int64_t, raft::neighbors::filtering::deletion_ivf_filter); instantiate_raft_neighbors_ivf_flat_detail_search( - int8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); + int8_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter); instantiate_raft_neighbors_ivf_flat_detail_search( - uint8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); + uint8_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_search