Skip to content

Commit

Permalink
Add remove function for IVF-Flat
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Sep 18, 2023
1 parent 8d30c9f commit 9661814
Show file tree
Hide file tree
Showing 16 changed files with 131 additions and 33 deletions.
3 changes: 2 additions & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ INSTALL_TARGET=install
BUILD_REPORT_METRICS=""
BUILD_REPORT_INCL_CACHE_STATS=OFF

TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH;UTIL_BENCH"

CACHE_ARGS=""
Expand Down Expand Up @@ -324,6 +324,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then
$CMAKE_TARGET == *"DISTANCE_TEST"* || \
$CMAKE_TARGET == *"MATRIX_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_TEST"* || \
$CMAKE_TARGET == *"SPARSE_DIST_TEST" || \
$CMAKE_TARGET == *"SPARSE_NEIGHBORS_TEST"* || \
Expand Down
15 changes: 15 additions & 0 deletions cpp/include/raft/neighbors/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <raft/neighbors/ivf_list_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/stats/histogram.cuh>
#include <raft/util/bitset.cuh>
#include <raft/util/pow2_utils.cuh>

#include <rmm/cuda_stream_view.hpp>
Expand Down Expand Up @@ -76,6 +77,11 @@ auto clone(const raft::resources& res, const index<T, IdxT>& source) -> index<T,

// Make sure the device pointers point to the new lists
target.recompute_internal_state(res);
target.resize_bitset(res);
copy(target.deleted_indices().data_handle(),
source.deleted_indices().data_handle(),
source.deleted_indices().n_elements(),
stream);

return target;
}
Expand Down Expand Up @@ -281,6 +287,7 @@ void extend(raft::resources const& handle,
stream);
RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min<uint32_t>(dim, 20));
}
index->resize_bitset(handle);
}

/** See raft::neighbors::ivf_flat::extend docs */
Expand Down Expand Up @@ -416,6 +423,7 @@ inline void fill_refinement_index(raft::resources const& handle,
refinement_index->dim(),
refinement_index->veclen());
RAFT_CUDA_TRY(cudaPeekAtLastError());
refinement_index->resize_bitset(handle);
}

template <typename T>
Expand Down Expand Up @@ -490,4 +498,11 @@ void unpack_list_data(
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

/**
 * Record the given sample ids in the deleted-indices bitset of an ivf-flat index.
 *
 * The function only forwards the mutable bitset view returned by `index.deleted_indices()`
 * together with `indices` to `raft::util::bitset_set`; nothing else in the index is accessed
 * here.
 *
 * NOTE(review): whether a removed id is encoded as a set or a cleared bit is decided by
 * `bitset_set`, which is not visible from this file - confirm before relying on the polarity.
 *
 * @tparam T data element type
 * @tparam IdxT type of the indices
 * @param[in] handle RAFT resources
 * @param[inout] index ivf-flat index whose deleted-indices bitset is updated
 * @param[in] indices device vector view of the sample ids to mark
 */
template <typename T, typename IdxT>
void remove(raft::resources const& handle,
            index<T, IdxT>& index,
            raft::device_vector_view<const IdxT, IdxT> indices)
{
  // `index` is a non-const reference, so this picks the mutable view overload.
  auto deletion_mask = index.deleted_indices();
  raft::util::bitset_set(handle, deletion_mask, indices);
}
} // namespace raft::neighbors::ivf_flat::detail
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include <cstdint> // uintX_t
#include <raft/neighbors/ivf_flat_types.hpp> // raft::neighbors::ivf_flat::index
#include <raft/neighbors/sample_filter_types.hpp> // none_ivf_sample_filter
#include <raft/neighbors/sample_filter_types.hpp> // deletion_ivf_filter
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <rmm/cuda_stream_view.hpp> // rmm:cuda_stream_view

Expand Down Expand Up @@ -66,10 +66,10 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i
rmm::cuda_stream_view stream)

instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
float, float, int64_t, raft::neighbors::filtering::deletion_ivf_filter<int64_t>);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
int8_t, int32_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter<int64_t>);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
uint8_t, uint32_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter<int64_t>);

#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
const bool valid = vec_id < list_length;

// Process first shm_assisted_dim dimensions (always using shared memory)
if (valid && sample_filter(queries_offset + blockIdx.y, probe_id, vec_id)) {
if (valid && sample_filter(queries_offset + blockIdx.y, list_id, vec_id)) {
loadAndComputeDist<kUnroll, decltype(compute_dist), Veclen, T, AccT> lc(dist,
compute_dist);
for (int pos = 0; pos < shm_assisted_dim;
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ void search(raft::resources const& handle,
IvfSampleFilterT sample_filter)

instantiate_raft_neighbors_ivf_flat_detail_search(
float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
float, int64_t, raft::neighbors::filtering::deletion_ivf_filter<int64_t>);
instantiate_raft_neighbors_ivf_flat_detail_search(
int8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
int8_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter<int64_t>);
instantiate_raft_neighbors_ivf_flat_detail_search(
uint8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
uint8_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter<int64_t>);

#undef instantiate_raft_neighbors_ivf_flat_detail_search
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include <raft/matrix/detail/select_k.cuh> // matrix::detail::select_k
#include <raft/neighbors/detail/ivf_flat_interleaved_scan.cuh> // interleaved_scan
#include <raft/neighbors/ivf_flat_types.hpp> // raft::neighbors::ivf_flat::index
#include <raft/neighbors/sample_filter_types.hpp> // none_ivf_sample_filter
#include <raft/neighbors/sample_filter_types.hpp> // deletion_ivf_filter
#include <raft/spatial/knn/detail/ann_utils.cuh> // utils::mapping
#include <rmm/mr/device/per_device_resource.hpp> // rmm::device_memory_resource

Expand Down
8 changes: 5 additions & 3 deletions cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace raft::neighbors::ivf_flat::detail {
// backward compatibility.
// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward
// compatible fashion.
constexpr int serialization_version = 4;
constexpr int serialization_version = 5;

// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error
// message.
Expand All @@ -45,7 +45,7 @@ struct check_index_layout {
"paste in the new size and consider updating the serialization logic");
};

template struct check_index_layout<sizeof(index<double, std::uint64_t>), 328>;
template struct check_index_layout<sizeof(index<double, std::uint64_t>), 384>;

/**
* Save the index to file.
Expand Down Expand Up @@ -99,6 +99,7 @@ void serialize(raft::resources const& handle, std::ostream& os, const index<T, I
list_store_spec,
Pow2<kIndexGroupSize>::roundUp(sizes_host(label)));
}
serialize_mdspan(handle, os, index_.deleted_indices().to_mdspan());
resource::sync_stream(handle);
}

Expand Down Expand Up @@ -165,7 +166,8 @@ auto deserialize(raft::resources const& handle, std::istream& is) -> index<T, Id
resource::sync_stream(handle);

index_.recompute_internal_state(handle);

index_.resize_bitset(handle);
deserialize_mdspan(handle, is, index_.deleted_indices().to_mdspan());
return index_;
}

Expand Down
6 changes: 4 additions & 2 deletions cpp/include/raft/neighbors/detail/refine_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ void refine_device(raft::resources const& handle,
neighbor_candidates.data_handle(),
n_queries,
n_candidates);
uint32_t grid_dim_x = 1;
uint32_t grid_dim_x = 1;
const auto& const_index = refinement_index;
raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan<
data_t,
typename raft::spatial::knn::detail::utils::config<data_t>::value_t,
Expand All @@ -99,7 +100,8 @@ void refine_device(raft::resources const& handle,
1,
k,
raft::distance::is_min_close(metric),
raft::neighbors::filtering::none_ivf_sample_filter(),
raft::neighbors::filtering::deletion_ivf_filter(
const_index.deleted_indices(), refinement_index.inds_ptrs().data_handle()),
indices.data_handle(),
distances.data_handle(),
grid_dim_x,
Expand Down
4 changes: 4 additions & 0 deletions cpp/include/raft/neighbors/ivf_flat-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ void search(raft::resources const& handle,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances) RAFT_EXPLICIT;

/**
 * Declaration of ivf_flat::remove for builds that only use explicit instantiations.
 *
 * @tparam T data element type
 * @tparam IdxT type of the indices
 * @param handle RAFT resources
 * @param index ivf-flat index to update
 * @param indices device vector view of the sample ids to remove
 *
 * NOTE(review): the body is supplied by the RAFT_EXPLICIT macro; no explicit instantiation of
 * `remove` is visible next to this declaration - confirm one exists for every (T, IdxT) pair
 * that callers are expected to use.
 */
template <typename T, typename IdxT>
void remove(raft::resources const& handle,
index<T, IdxT>& index,
raft::device_vector_view<const IdxT, IdxT> indices) RAFT_EXPLICIT;
} // namespace raft::neighbors::ivf_flat

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY
Expand Down
47 changes: 36 additions & 11 deletions cpp/include/raft/neighbors/ivf_flat-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ void search_with_filtering(raft::resources const& handle,
/**
* @brief Search ANN using the constructed index using the given filter.
*
* Note: the filter passed to this function replaces the default deletion filter, so samples
* removed by the user are only excluded if the given filter excludes them.
* See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example.
*
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
Expand Down Expand Up @@ -457,16 +458,18 @@ void search(raft::resources const& handle,
float* distances,
rmm::mr::device_memory_resource* mr = nullptr)
{
raft::neighbors::ivf_flat::detail::search(handle,
params,
index,
queries,
n_queries,
k,
neighbors,
distances,
mr,
raft::neighbors::filtering::none_ivf_sample_filter());
raft::neighbors::ivf_flat::detail::search(
handle,
params,
index,
queries,
n_queries,
k,
neighbors,
distances,
mr,
raft::neighbors::filtering::deletion_ivf_filter(index.deleted_indices(),
index.inds_ptrs().data_handle()));
}

/**
Expand All @@ -477,6 +480,7 @@ void search(raft::resources const& handle,
/**
* @brief Search ANN using the constructed index using the given filter.
*
* Note: the filter passed to this function replaces the default deletion filter, so samples
* removed by the user are only excluded if the given filter excludes them.
* See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example.
*
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
Expand Down Expand Up @@ -588,7 +592,28 @@ void search(raft::resources const& handle,
queries,
neighbors,
distances,
raft::neighbors::filtering::none_ivf_sample_filter());
raft::neighbors::filtering::deletion_ivf_filter(
index.deleted_indices(), index.inds_ptrs().data_handle()));
}

/**
 * @brief Mark a set of samples of the index as removed.
 *
 * The given indices are flagged as deleted so that subsequent searches using the default
 * filter skip them. The underlying data is left in place and the size of the index does not
 * change.
 *
 * @tparam T data element type
 * @tparam IdxT type of the indices
 *
 * @param[in] handle RAFT resources
 * @param[inout] index ivf-flat constructed index
 * @param[in] indices device vector view holding the indices to remove
 */
template <typename T, typename IdxT>
void remove(raft::resources const& handle,
            index<T, IdxT>& index,
            raft::device_vector_view<const IdxT, IdxT> indices)
{
  // Thin public wrapper: all the work happens in the detail implementation.
  raft::neighbors::ivf_flat::detail::remove(handle, index, indices);
}

/** @} */
Expand Down
20 changes: 19 additions & 1 deletion cpp/include/raft/neighbors/ivf_flat_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/ivf_list_types.hpp>
#include <raft/util/bitset.cuh>
#include <raft/util/integer_utils.hpp>

#include <thrust/reduce.h>
Expand Down Expand Up @@ -257,7 +258,8 @@ struct index : ann::index {
list_sizes_{make_device_vector<uint32_t, uint32_t>(res, n_lists)},
data_ptrs_{make_device_vector<T*, uint32_t>(res, n_lists)},
inds_ptrs_{make_device_vector<IdxT*, uint32_t>(res, n_lists)},
total_size_{0}
total_size_{0},
deleted_indices_(res, 0)
{
check_consistency();
}
Expand Down Expand Up @@ -298,6 +300,21 @@ struct index : ann::index {
return conservative_memory_allocation_;
}

/**
 * Mutable bitset view of the deleted indices [size].
 *
 * Returns the view produced by `deleted_indices_.view()`; its element type is the non-const
 * `std::uint32_t`, so callers can modify the bits through it.
 */
[[nodiscard]] inline auto deleted_indices() noexcept
-> raft::util::bitset_view<std::uint32_t, IdxT>
{
return deleted_indices_.view();
}
/**
 * Read-only bitset view of the deleted indices [size].
 *
 * Const overload: the returned view has element type `const std::uint32_t`.
 */
[[nodiscard]] inline auto deleted_indices() const noexcept
-> raft::util::bitset_view<const std::uint32_t, IdxT>
{
return deleted_indices_.view();
}

/**
 * Resize the deleted-indices bitset so that its length matches the current size of the index.
 *
 * @param[in] res RAFT resources forwarded to the bitset resize
 */
void resize_bitset(raft::resources const& res)
{
  const auto n_indexed = size();
  deleted_indices_.resize(res, n_indexed);
}

/**
* Update the state of the dependent index members.
*/
Expand Down Expand Up @@ -362,6 +379,7 @@ struct index : ann::index {
device_vector<uint32_t, uint32_t> list_sizes_;
device_matrix<float, uint32_t, row_major> centers_;
std::optional<device_vector<float, uint32_t>> center_norms_;
raft::util::bitset<std::uint32_t, IdxT> deleted_indices_;

// Computed members
device_vector<T*, uint32_t> data_ptrs_;
Expand Down
31 changes: 31 additions & 0 deletions cpp/include/raft/neighbors/sample_filter_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cstdint>

#include <raft/core/detail/macros.hpp>
#include <raft/util/bitset.cuh>

namespace raft::neighbors::filtering {

Expand All @@ -37,6 +38,36 @@ struct none_ivf_sample_filter {
}
};

/**
 * @brief Filter an IVF index with a bitset
 *
 * For a (cluster, position) pair the filter reads the sample id stored at that position of
 * the cluster's id list and returns the bit of the bitset associated with that id.
 *
 * NOTE(review): the filter result is used as "process this sample" by the scan, so a set bit
 * presumably means the sample has NOT been deleted - confirm against the bitset update routine.
 *
 * @tparam index_t Indexing type
 */
template <typename index_t>
struct deletion_ivf_filter {
  // Bitset view queried with the sample ids read from the inverted lists
  const raft::util::bitset_view<const std::uint32_t, index_t> deleted_bitset_;
  // Pointers to the inverted lists (clusters) indices  [n_lists]
  index_t* const* inds_ptr_ = nullptr;

  /**
   * @param deleted_bitset read-only bitset view to test sample ids against
   * @param inds_ptr per-list pointers to the sample ids
   */
  deletion_ivf_filter(const raft::util::bitset_view<const std::uint32_t, index_t> deleted_bitset,
                      index_t* const* inds_ptr)
    : deleted_bitset_{deleted_bitset}, inds_ptr_{inds_ptr}
  {
  }

  /**
   * @param query_ix query index (not used by this filter)
   * @param cluster_ix the current inverted list index
   * @param sample_ix the index of the current sample inside the current inverted list
   * @return the value of the bitset bit for the sample id found at that position
   */
  inline _RAFT_HOST_DEVICE bool operator()(const uint32_t query_ix,
                                           const uint32_t cluster_ix,
                                           const uint32_t sample_ix) const
  {
    // Translate (list, position) into the sample id stored in that list.
    const index_t* list_ids = inds_ptr_[cluster_ix];
    const index_t source_id = list_ids[sample_ix];
    return deleted_bitset_.test(source_id);
  }
};

/**
* If the filtering depends on the index of a sample, then the following
* filter template can be used:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@
rmm::cuda_stream_view stream)

instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
float, float, int64_t, raft::neighbors::filtering::deletion_ivf_filter<int64_t>);

#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@
rmm::cuda_stream_view stream)

instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
int8_t, int32_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter<int64_t>);

#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@
rmm::cuda_stream_view stream)

instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
uint8_t, uint32_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter<int64_t>);

#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan
6 changes: 3 additions & 3 deletions cpp/src/neighbors/detail/ivf_flat_search.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
IvfSampleFilterT sample_filter)

instantiate_raft_neighbors_ivf_flat_detail_search(
float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
float, int64_t, raft::neighbors::filtering::deletion_ivf_filter<int64_t>);
instantiate_raft_neighbors_ivf_flat_detail_search(
int8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
int8_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter<int64_t>);
instantiate_raft_neighbors_ivf_flat_detail_search(
uint8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
uint8_t, int64_t, raft::neighbors::filtering::deletion_ivf_filter<int64_t>);

#undef instantiate_raft_neighbors_ivf_flat_detail_search

0 comments on commit 9661814

Please sign in to comment.