diff --git a/CMakeLists.txt b/CMakeLists.txt index 0b7184100..157d6903d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,8 +12,7 @@ # License for the specific language governing permissions and limitations under # the License -cmake_minimum_required(VERSION 3.23.0 FATAL_ERROR) -project(knowhere CXX C) +cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules/") @@ -21,10 +20,20 @@ include(GNUInstallDirs) include(ExternalProject) include(cmake/utils/utils.cmake) +knowhere_option(WITH_RAFT "Build with RAFT indexes" OFF) +if (WITH_RAFT) + if("${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "") + set(CMAKE_CUDA_ARCHITECTURES RAPIDS) + endif() + include(cmake/libs/librapids.cmake) + project(knowhere CXX C CUDA) +else() + project(knowhere CXX C) +endif() + knowhere_option(WITH_UT "Build with UT test" OFF) knowhere_option(WITH_ASAN "Build with ASAN" OFF) knowhere_option(WITH_DISKANN "Build with diskann index" OFF) -knowhere_option(WITH_RAFT "Build with RAFT indexes" OFF) knowhere_option(WITH_BENCHMARK "Build with benchmark" OFF) knowhere_option(WITH_COVERAGE "Build with coverage" OFF) knowhere_option(WITH_CCACHE "Build with ccache" ON) @@ -64,18 +73,6 @@ endif() list( APPEND CMAKE_MODULE_PATH ${CMAKE_BINARY_DIR}/) -if(WITH_RAFT) - if("${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "") - set(CMAKE_CUDA_ARCHITECTURES 86;80;75;70;61) - endif() - enable_language(CUDA) - find_package(CUDAToolkit REQUIRED) - if(${CUDAToolkit_VERSION_MAJOR} GREATER 10) - # cuda11 support --threads for compile some large .cu more efficient - add_compile_options($<$:--threads=4>) - endif() -endif() - add_definitions(-DNOT_COMPILE_FOR_SWIG) include(cmake/utils/compile_flags.cmake) @@ -113,8 +110,7 @@ if(WITH_COVERAGE) endif() knowhere_file_glob(GLOB_RECURSE KNOWHERE_SRCS src/common/*.cc src/index/*.cc - src/io/*.cc src/index/*.cu src/common/raft/*.cu - src/common/raft/*.cc) + src/io/*.cc src/common/*.cu src/index/*.cu src/io/*.cu) set(KNOWHERE_LINKER_LIBS "") @@ -127,13 +123,13 @@ else() endif() knowhere_file_glob(GLOB_RECURSE KNOWHERE_GPU_SRCS src/index/gpu/flat_gpu/*.cc - src/index/gpu/ivf_gpu/*.cc src/index/cagra/*.cu) + src/index/gpu/ivf_gpu/*.cc) list(REMOVE_ITEM KNOWHERE_SRCS ${KNOWHERE_GPU_SRCS}) if(NOT WITH_RAFT) - knowhere_file_glob(GLOB_RECURSE KNOWHERE_RAFT_SRCS src/index/ivf_raft/*.cc - src/index/ivf_raft/*.cu src/index/cagra/*.cu - src/common/raft/*.cu src/common/raft/*.cc) + knowhere_file_glob(GLOB_RECURSE KNOWHERE_RAFT_SRCS + src/common/raft/*.cu src/common/raft/*.cc + src/index/gpu_raft/*.cc) list(REMOVE_ITEM KNOWHERE_SRCS ${KNOWHERE_RAFT_SRCS}) endif() @@ -150,7 +146,7 @@ list(APPEND KNOWHERE_LINKER_LIBS ${FOLLY_LIBRARIES}) add_library(knowhere SHARED ${KNOWHERE_SRCS}) add_dependencies(knowhere ${KNOWHERE_LINKER_LIBS}) if(WITH_RAFT) - list(APPEND KNOWHERE_LINKER_LIBS raft::raft) + list(APPEND KNOWHERE_LINKER_LIBS raft::raft CUDA::cublas CUDA::cusparse CUDA::cusolver) endif() target_link_libraries(knowhere PUBLIC ${KNOWHERE_LINKER_LIBS}) target_include_directories( diff --git a/benchmark/hdf5/benchmark_float_qps.cpp b/benchmark/hdf5/benchmark_float_qps.cpp index 2022f4468..bc87e2a44 100644 --- a/benchmark/hdf5/benchmark_float_qps.cpp +++ b/benchmark/hdf5/benchmark_float_qps.cpp @@ -73,6 +73,60 @@ class Benchmark_float_qps : public Benchmark_knowhere, public ::testing::Test { } } + void + test_cagra(const knowhere::Json& cfg) { + auto conf = cfg; + + auto find_smallest_max_iters = [&](float expected_recall) -> int32_t { + auto ds_ptr = knowhere::GenDataSet(nq_, dim_, xq_); + auto left = 32; + auto right = 256; + auto max_iterations = left; + + float recall; + while (left <= right) { + max_iterations = left + (right - left) / 2; + conf[knowhere::indexparam::MAX_ITERATIONS] = max_iterations; + + auto result = index_.Search(*ds_ptr, conf, nullptr); + recall = CalcRecall(result.value()->GetIds(), nq_, topk_); + printf( + "[%0.3f s] iterate CAGRA param for recall %.4f: max_iterations=%d, k=%d, " + "R@=%.4f\n", + get_time_diff(), expected_recall, max_iterations, topk_, recall); + std::fflush(stdout); + if (std::abs(recall - expected_recall) <= 0.0001) { + return max_iterations; + } + if (recall < expected_recall) { + left = max_iterations + 1; + } else { + right = max_iterations - 1; + } + } + return left; + }; + + for (auto expected_recall : EXPECTED_RECALLs_) { + conf[knowhere::indexparam::ITOPK_SIZE] = ((int{topk_} + 32 - 1) / 32) * 32; + conf[knowhere::meta::TOPK] = topk_; + conf[knowhere::indexparam::MAX_ITERATIONS] = find_smallest_max_iters(expected_recall); + + printf( + "\n[%0.3f s] %s | %s | k=%d, " + "R@=%.4f\n", + get_time_diff(), ann_test_name_.c_str(), index_type_.c_str(), topk_, expected_recall); + printf("================================================================================\n"); + for (auto thread_num : THREAD_NUMs_) { + CALC_TIME_SPAN(task(conf, thread_num, nq_)); + printf(" thread_num = %2d, elapse = %6.3fs, VPS = %.3f\n", thread_num, t_diff, nq_ / t_diff); + std::fflush(stdout); + } + printf("================================================================================\n"); + printf("[%.3f s] Test '%s/%s' done\n\n", get_time_diff(), ann_test_name_.c_str(), index_type_.c_str()); + } + } + void test_hnsw(const knowhere::Json& cfg) { auto conf = cfg; @@ -221,6 +275,9 @@ class Benchmark_float_qps : public Benchmark_knowhere, public ::testing::Test { #ifdef KNOWHERE_WITH_GPU knowhere::KnowhereConfig::InitGPUResource(GPU_DEVICE_ID, 2); cfg_[knowhere::meta::DEVICE_ID] = GPU_DEVICE_ID; +#endif +#ifdef KNOWHERE_WITH_RAFT + knowhere::KnowhereConfig::SetRaftMemPool(); #endif } @@ -251,6 +308,9 @@ class Benchmark_float_qps : public Benchmark_knowhere, public ::testing::Test { // SCANN index params const std::vector SCANN_REORDER_K = {256, 512, 768, 1024}; const std::vector SCANN_WITH_RAW_DATA = {true}; + + // CAGRA index params + const std::vector GRAPH_DEGREE_ = {32, 64}; }; TEST_F(Benchmark_float_qps, TEST_IVF_FLAT) { @@ -344,3 +404,15 @@ TEST_F(Benchmark_float_qps, TEST_SCANN) { } } } +TEST_F(Benchmark_float_qps, TEST_CAGRA) { + index_type_ = knowhere::IndexEnum::INDEX_RAFT_CAGRA; + knowhere::Json conf = cfg_; + for (auto gd : GRAPH_DEGREE_) { + conf[knowhere::indexparam::GRAPH_DEGREE] = gd; + conf[knowhere::indexparam::INTERMEDIATE_GRAPH_DEGREE] = gd; + conf[knowhere::indexparam::MAX_ITERATIONS] = 64; + std::string index_file_name = get_index_name({gd}); + create_index(index_file_name, conf); + test_cagra(conf); + } +} diff --git a/benchmark/hdf5/benchmark_knowhere.h b/benchmark/hdf5/benchmark_knowhere.h index 0505d87e2..909130381 100644 --- a/benchmark/hdf5/benchmark_knowhere.h +++ b/benchmark/hdf5/benchmark_knowhere.h @@ -81,7 +81,6 @@ class Benchmark_knowhere : public Benchmark_hdf5 { bin->size = dim_ * nb_ * sizeof(float); binary_set.Append("RAW_DATA", bin); } - index.Deserialize(binary_set, conf); } diff --git a/cmake/libs/libraft.cmake b/cmake/libs/libraft.cmake index 3de1cfaec..514793933 100644 --- a/cmake/libs/libraft.cmake +++ b/cmake/libs/libraft.cmake @@ -14,22 +14,15 @@ # the License. add_definitions(-DKNOWHERE_WITH_RAFT) -include(cmake/utils/fetch_rapids.cmake) -include(rapids-cmake) -include(rapids-cpm) -include(rapids-cuda) -include(rapids-export) -include(rapids-find) - -rapids_cpm_init() - -set(CMAKE_CUDA_FLAGS - "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") - -set(RAPIDS_VERSION 23.04) set(RAFT_VERSION "${RAPIDS_VERSION}") set(RAFT_FORK "rapidsai") -set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") +set(RAFT_PINNED_TAG "branch-23.12") + + +rapids_find_package(CUDAToolkit REQUIRED + BUILD_EXPORT_SET knowhere-exports + INSTALL_EXPORT_SET knowhere-exports +) function(find_and_configure_raft) set(oneValueArgs VERSION FORK PINNED_TAG) diff --git a/cmake/utils/fetch_rapids.cmake b/cmake/libs/librapids.cmake similarity index 71% rename from cmake/utils/fetch_rapids.cmake rename to cmake/libs/librapids.cmake index 56899f2c5..315c4c3d2 100644 --- a/cmake/utils/fetch_rapids.cmake +++ b/cmake/libs/librapids.cmake @@ -13,7 +13,7 @@ # License for the specific language governing permissions and limitations under # the License. -set(RAPIDS_VERSION "23.04") +set(RAPIDS_VERSION 23.12) if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) file( @@ -22,3 +22,15 @@ if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) endif() include(${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) + +include(rapids-cpm) # Dependency tracking +include(rapids-find) # Wrappers for finding packages +include(rapids-cuda) # Common CMake CUDA logic + +rapids_cuda_init_architectures(knowhere) +message(STATUS "INIT: ${CMAKE_CUDA_ARCHITECTURES}") + +rapids_cpm_init() + +set(CMAKE_CUDA_FLAGS + "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 675ccc35c..82f748145 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -81,6 +81,35 @@ constexpr const char* M = "m"; // PQ param for IVFPQ constexpr const char* SSIZE = "ssize"; constexpr const char* REORDER_K = "reorder_k"; constexpr const char* WITH_RAW_DATA = "with_raw_data"; +// RAFT Params +constexpr const char* REFINE_RATIO = "refine_ratio"; +// RAFT-specific IVF Params +constexpr const char* KMEANS_N_ITERS = "kmeans_n_iters"; +constexpr const char* KMEANS_TRAINSET_FRACTION = "kmeans_trainset_fraction"; +constexpr const char* ADAPTIVE_CENTERS = "adaptive_centers"; // IVF FLAT +constexpr const char* CODEBOOK_KIND = "codebook_kind"; // IVF PQ +constexpr const char* FORCE_RANDOM_ROTATION = "force_random_rotation"; // IVF PQ +constexpr const char* CONSERVATIVE_MEMORY_ALLOCATION = "conservative_memory_allocation"; // IVF PQ +constexpr const char* LUT_DTYPE = "lut_dtype"; // IVF PQ +constexpr const char* INTERNAL_DISTANCE_DTYPE = "internal_distance_dtype"; // IVF PQ +constexpr const char* PREFERRED_SHMEM_CARVEOUT = "preferred_shmem_carveout"; // IVF PQ + +// CAGRA Params +constexpr const char* INTERMEDIATE_GRAPH_DEGREE = "intermediate_graph_degree"; +constexpr const char* GRAPH_DEGREE = "graph_degree"; +constexpr const char* ITOPK_SIZE = "itopk_size"; +constexpr const char* MAX_QUERIES = "max_queries"; +constexpr const char* BUILD_ALGO = "build_algo"; +constexpr const char* SEARCH_ALGO = "search_algo"; +constexpr const char* TEAM_SIZE = "team_size"; +constexpr const char* SEARCH_WIDTH = "search_width"; +constexpr const char* MIN_ITERATIONS = "min_iterations"; +constexpr const char* MAX_ITERATIONS = "max_iterations"; +constexpr const char* THREAD_BLOCK_SIZE = "thread_block_size"; +constexpr const char* HASHMAP_MODE = "hashmap_mode"; +constexpr const char* HASHMAP_MIN_BITLEN = "hashmap_min_bitlen"; +constexpr const char* HASHMAP_MAX_FILL_RATE = "hashmap_max_fill_rate"; +constexpr const char* NN_DESCENT_NITER = "nn_descent_niter"; // HNSW Params constexpr const char* EFCONSTRUCTION = "efConstruction"; diff --git a/include/knowhere/comp/knowhere_config.h b/include/knowhere/comp/knowhere_config.h index de845b6cf..ba224cca7 100644 --- a/include/knowhere/comp/knowhere_config.h +++ b/include/knowhere/comp/knowhere_config.h @@ -107,6 +107,12 @@ class KnowhereConfig { */ static void SetRaftMemPool(size_t init_size, size_t max_size); + + /** + * Initialize RAFT with defaults + */ + static void + SetRaftMemPool(); }; } // namespace knowhere diff --git a/include/knowhere/device_bitset.h b/include/knowhere/device_bitset.h deleted file mode 100644 index 12532dc4e..000000000 --- a/include/knowhere/device_bitset.h +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (C) 2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#ifndef DEVICE_BITSET_H -#define DEVICE_BITSET_H - -#include "knowhere/bitsetview.h" -#include "raft/core/device_mdarray.hpp" -#include "raft/core/device_resources.hpp" -#include "raft/util/cudart_utils.hpp" - -namespace knowhere { - -struct DeviceBitsetView { - __device__ __host__ - DeviceBitsetView(const DeviceBitsetView& other) - : bits_{other.data()}, num_bits_{other.size()} { - } - __device__ __host__ - DeviceBitsetView(const uint8_t* data, size_t num_bits = size_t{}) - : bits_{data}, num_bits_{num_bits} { - } - - __device__ __host__ bool - empty() const { - return num_bits_ == 0; - } - - __device__ __host__ size_t - size() const { - return num_bits_; - } - - __device__ __host__ size_t - byte_size() const { - return (num_bits_ + 8 - 1) >> 3; - } - - __device__ __host__ const uint8_t* - data() const { - return bits_; - } - - __device__ bool - test(int64_t index) const { - auto result = false; - if (index < num_bits_) { - result = bits_[index >> 3] & (0x1 << (index & 0x7)); - } - return result; - } - - private: - const uint8_t* bits_ = nullptr; - size_t num_bits_ = 0; -}; - -struct DeviceBitset { - DeviceBitset(raft::device_resources& res, BitsetView const& other) - : storage_{[&res, &other]() { - auto result = raft::make_device_vector(res, other.byte_size()); - if (!other.empty()) { - raft::copy(result.data_handle(), other.data(), other.byte_size(), res.get_stream()); - } - return result; - }()}, - num_bits_{other.size()} { - } - - auto - view() { - return DeviceBitsetView{storage_.data_handle(), num_bits_}; - } - - private: - raft::device_vector storage_; - size_t num_bits_; -}; - -} // namespace knowhere - -#endif /* DEVICE_BITSET_H */ diff --git a/src/common/comp/knowhere_config.cc b/src/common/comp/knowhere_config.cc index 126529d21..50bb23b85 100644 --- a/src/common/comp/knowhere_config.cc +++ b/src/common/comp/knowhere_config.cc @@ -23,10 +23,10 @@ #ifdef KNOWHERE_WITH_GPU #include "index/gpu/gpu_res_mgr.h" #endif -#include "simd/hook.h" #ifdef KNOWHERE_WITH_RAFT -#include "common/raft/raft_utils.h" +#include "common/raft/integration/raft_initialization.hpp" #endif +#include "simd/hook.h" namespace knowhere { @@ -163,7 +163,20 @@ KnowhereConfig::FreeGPUResource() { void KnowhereConfig::SetRaftMemPool(size_t init_size, size_t max_size) { #ifdef KNOWHERE_WITH_RAFT - raft_utils::set_mem_pool_size(init_size, max_size); + auto config = raft_knowhere::raft_configuration{}; + config.init_mem_pool_size_mb = init_size; + config.max_mem_pool_size_mb = max_size; + // This should probably be a separate configuration option, but fine for now + config.max_workspace_size_mb = max_size; + raft_knowhere::initialize_raft(config); +#endif +} +void +KnowhereConfig::SetRaftMemPool() { + // Overload for default values +#ifdef KNOWHERE_WITH_RAFT + auto config = raft_knowhere::raft_configuration{}; + raft_knowhere::initialize_raft(config); #endif } diff --git a/src/common/raft/integration/cagra_index.cu b/src/common/raft/integration/cagra_index.cu new file mode 100644 index 000000000..4879f5c39 --- /dev/null +++ b/src/common/raft/integration/cagra_index.cu @@ -0,0 +1,21 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/raft/proto/raft_index_kind.hpp" +#include "common/raft/integration/raft_knowhere_index.cuh" +namespace raft_knowhere { +template struct raft_knowhere_index; +} // namespace raft_knowhere diff --git a/src/common/raft/integration/ivf_flat_index.cu b/src/common/raft/integration/ivf_flat_index.cu new file mode 100644 index 000000000..fc759075d --- /dev/null +++ b/src/common/raft/integration/ivf_flat_index.cu @@ -0,0 +1,21 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/raft/proto/raft_index_kind.hpp" +#include "common/raft/integration/raft_knowhere_index.cuh" +namespace raft_knowhere { +template struct raft_knowhere_index; +} // namespace raft_knowhere diff --git a/src/common/raft/integration/ivf_pq_index.cu b/src/common/raft/integration/ivf_pq_index.cu new file mode 100644 index 000000000..9284a0930 --- /dev/null +++ b/src/common/raft/integration/ivf_pq_index.cu @@ -0,0 +1,21 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/raft/proto/raft_index_kind.hpp" +#include "common/raft/integration/raft_knowhere_index.cuh" +namespace raft_knowhere { +template struct raft_knowhere_index; +} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_initialization.cc b/src/common/raft/integration/raft_initialization.cc new file mode 100644 index 000000000..59bb6f842 --- /dev/null +++ b/src/common/raft/integration/raft_initialization.cc @@ -0,0 +1,79 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/raft/integration/raft_initialization.hpp" + +#include + +#include +#include +#include +#include +namespace raft_knowhere { + +void +initialize_raft(raft_configuration const& config) { + auto static initialization_flag = std::once_flag{}; + std::call_once(initialization_flag, [&config]() { + raft::device_resources_manager::set_streams_per_device(config.streams_per_device); + if (config.stream_pool_size) { + raft::device_resources_manager::set_stream_pools_per_device(config.stream_pools_per_device, + *(config.stream_pool_size)); + } else { + raft::device_resources_manager::set_stream_pools_per_device(config.stream_pools_per_device); + } + if (config.init_mem_pool_size_mb) { + raft::device_resources_manager::set_init_mem_pool_size(*(config.init_mem_pool_size_mb) << 20); + } + if (config.max_mem_pool_size_mb) { + if (*config.max_mem_pool_size_mb > 0) { + raft::device_resources_manager::set_max_mem_pool_size(*(config.max_mem_pool_size_mb) << 20); + } + } else { + raft::device_resources_manager::set_max_mem_pool_size(std::nullopt); + } + if (config.max_workspace_size_mb) { + raft::device_resources_manager::set_workspace_allocation_limit(*(config.max_workspace_size_mb) << 20); + } + auto device_count = []() { + auto result = 0; + RAFT_CUDA_TRY(cudaGetDeviceCount(&result)); + RAFT_EXPECTS(result != 0, "No CUDA devices found"); + return result; + }(); + + for (auto device_id = 0; device_id < device_count; ++device_id) { + auto scoped_device = raft::device_setter{device_id}; + auto workspace_size = std::size_t{}; + if (config.max_workspace_size_mb) { + workspace_size = *(config.max_workspace_size_mb) << 20; + } else { + auto free_mem = std::size_t{}; + auto total_mem = std::size_t{}; + RAFT_CUDA_TRY_NO_THROW(cudaMemGetInfo(&free_mem, &total_mem)); + // Heuristic: If workspace size is not explicitly specified, use half of free memory or a quarter of + // total memory, whichever is larger + workspace_size = std::max(free_mem / std::size_t{2}, total_mem / std::size_t{4}); + } + if (workspace_size > std::size_t{}) { + raft::device_resources_manager::set_workspace_memory_resource( + raft::resource::workspace_resource_factory::default_pool_resource(workspace_size), device_id); + } + } + }); +} + +} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_initialization.hpp b/src/common/raft/integration/raft_initialization.hpp new file mode 100644 index 000000000..a26f48e4c --- /dev/null +++ b/src/common/raft/integration/raft_initialization.hpp @@ -0,0 +1,31 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +namespace raft_knowhere { +struct raft_configuration { + std::size_t streams_per_device = std::size_t{16}; + std::size_t stream_pools_per_device = std::size_t{}; + std::optional stream_pool_size = std::nullopt; + std::optional init_mem_pool_size_mb = std::nullopt; + std::optional max_mem_pool_size_mb = std::nullopt; + std::optional max_workspace_size_mb = std::nullopt; +}; + +void +initialize_raft(raft_configuration const& config); +} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_knowhere_config.hpp b/src/common/raft/integration/raft_knowhere_config.hpp new file mode 100644 index 000000000..e29e452ef --- /dev/null +++ b/src/common/raft/integration/raft_knowhere_config.hpp @@ -0,0 +1,120 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +#include "common/raft/proto/raft_index_kind.hpp" + +namespace raft_knowhere { +// This struct includes all parameters that may be passed to underlying RAFT +// indexes. It is designed to not expose ANY RAFT types in order to cleanly +// separate RAFT from knowhere headers. +struct raft_knowhere_config { + raft_proto::raft_index_kind index_type; + int k = 10; + + // Common Parameters + std::string metric_type = std::string{"L2"}; + float metric_arg = 2.0f; + bool add_data_on_build = true; + float refine_ratio = 1.0f; + + // Shared IVF Parameters + std::optional nlist = std::nullopt; + std::optional nprobe = std::nullopt; + std::optional kmeans_n_iters = std::nullopt; + std::optional kmeans_trainset_fraction = std::nullopt; + + // IVF Flat only Parameters + std::optional adaptive_centers = std::nullopt; + + // IVFPQ only Parameters + std::optional m = std::nullopt; + std::optional nbits = std::nullopt; + std::optional codebook_kind = std::nullopt; + std::optional force_random_rotation = std::nullopt; + std::optional conservative_memory_allocation = std::nullopt; + std::optional lookup_table_dtype = std::nullopt; + std::optional internal_distance_dtype = std::nullopt; + std::optional preferred_shmem_carveout = std::nullopt; + + // CAGRA Parameters + std::optional intermediate_graph_degree = std::nullopt; + std::optional graph_degree = std::nullopt; + std::optional itopk_size = std::nullopt; + std::optional max_queries = std::nullopt; + std::optional build_algo = std::nullopt; + std::optional search_algo = std::nullopt; + std::optional team_size = std::nullopt; + std::optional search_width = std::nullopt; + std::optional min_iterations = std::nullopt; + std::optional max_iterations = std::nullopt; + std::optional thread_block_size = std::nullopt; + std::optional hashmap_mode = std::nullopt; + std::optional hashmap_min_bitlen = std::nullopt; + std::optional hashmap_max_fill_rate = std::nullopt; + std::optional nn_descent_niter = std::nullopt; +}; + +// The following function provides a single source of truth for default values +// of RAFT index configurations. +[[nodiscard]] inline auto +validate_raft_knowhere_config(raft_knowhere_config config) { + if (config.index_type == raft_proto::raft_index_kind::ivf_flat || + config.index_type == raft_proto::raft_index_kind::ivf_pq) { + config.add_data_on_build = true; + config.nlist = config.nlist.value_or(128); + config.nprobe = config.nprobe.value_or(8); + config.kmeans_n_iters = config.kmeans_n_iters.value_or(20); + config.kmeans_trainset_fraction = config.kmeans_trainset_fraction.value_or(0.5f); + } + if (config.index_type == raft_proto::raft_index_kind::ivf_flat) { + config.adaptive_centers = config.adaptive_centers.value_or(false); + } + if (config.index_type == raft_proto::raft_index_kind::ivf_pq) { + config.m = config.m.value_or(0); + config.nbits = config.nbits.value_or(8); + config.codebook_kind = config.codebook_kind.value_or("PER_SUBSPACE"); + config.force_random_rotation = config.force_random_rotation.value_or(false); + config.conservative_memory_allocation = config.conservative_memory_allocation.value_or(false); + config.lookup_table_dtype = config.lookup_table_dtype.value_or("CUDA_R_32F"); + config.internal_distance_dtype = config.internal_distance_dtype.value_or("CUDA_R_32F"); + config.preferred_shmem_carveout = config.preferred_shmem_carveout.value_or(1.0f); + } + if (config.index_type == raft_proto::raft_index_kind::cagra) { + config.add_data_on_build = true; + config.intermediate_graph_degree = config.intermediate_graph_degree.value_or(128); + config.graph_degree = config.graph_degree.value_or(64); + config.itopk_size = config.itopk_size.value_or(64); + config.max_queries = config.max_queries.value_or(0); + config.build_algo = config.build_algo.value_or("IVF_PQ"); + config.search_algo = config.search_algo.value_or("AUTO"); + config.team_size = config.team_size.value_or(0); + config.search_width = config.search_width.value_or(1); + config.min_iterations = config.min_iterations.value_or(0); + config.max_iterations = config.max_iterations.value_or(0); + config.thread_block_size = config.thread_block_size.value_or(0); + config.hashmap_mode = config.hashmap_mode.value_or("AUTO"); + config.hashmap_min_bitlen = config.hashmap_min_bitlen.value_or(0); + config.hashmap_max_fill_rate = config.hashmap_max_fill_rate.value_or(0.5f); + config.nn_descent_niter = config.nn_descent_niter.value_or(20); + } + return config; +} + +} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_knowhere_index.cuh b/src/common/raft/integration/raft_knowhere_index.cuh new file mode 100644 index 000000000..805c72730 --- /dev/null +++ b/src/common/raft/integration/raft_knowhere_index.cuh @@ -0,0 +1,620 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/raft/proto/raft_index_kind.hpp" +#include "common/raft/proto/raft_index.cuh" +#include "common/raft/integration/raft_knowhere_index.hpp" + +namespace raft_knowhere { +namespace detail { + +// This helper struct maps the generic type of RAFT index to the specific +// instantiation of that index used within knowhere. +template +struct raft_index_type_mapper : std::false_type { +}; + +template <> +struct raft_index_type_mapper : std::true_type { + using data_type = raft_data_t; + using indexing_type = raft_indexing_t; + using type = raft_proto::raft_index; + using underlying_index_type = typename type::vector_index_type; + using index_params_type = typename type::index_params_type; + using search_params_type = typename type::search_params_type; +}; +template <> +struct raft_index_type_mapper : std::true_type { + using data_type = raft_data_t; + using indexing_type = raft_indexing_t; + using type = raft_proto::raft_index; + using underlying_index_type = typename type::vector_index_type; + using index_params_type = typename type::index_params_type; + using search_params_type = typename type::search_params_type; +}; +template <> +struct raft_index_type_mapper : std::true_type { + using data_type = raft_data_t; + using indexing_type = raft_indexing_t; + using type = raft_proto::raft_index; + using underlying_index_type = typename type::vector_index_type; + using index_params_type = typename type::index_params_type; + using search_params_type = typename type::search_params_type; +}; + +} // namespace detail + +template +using raft_index_t = typename detail::raft_index_type_mapper::type; + +template +using raft_index_params_t = typename detail::raft_index_type_mapper::index_params_type; +template +using raft_search_params_t = typename detail::raft_index_type_mapper::search_params_type; + +// Metrics are passed between knowhere and RAFT as strings to avoid tight +// coupling between the implementation details of either one. +[[nodiscard]] inline auto metric_string_to_raft_distance_type(std::string const& metric_string) { + auto result = raft::distance::DistanceType::L2Expanded; + if (metric_string == "L2") { + result = raft::distance::DistanceType::L2Expanded; + } else if (metric_string == "L2SqrtExpanded") { + result = raft::distance::DistanceType::L2SqrtExpanded; + } else if (metric_string == "CosineExpanded") { + result = raft::distance::DistanceType::CosineExpanded; + } else if (metric_string == "L1") { + result = raft::distance::DistanceType::L1; + } else if (metric_string == "L2Unexpanded") { + result = raft::distance::DistanceType::L2Unexpanded; + } else if (metric_string == "L2SqrtUnexpanded") { + result = raft::distance::DistanceType::L2SqrtUnexpanded; + } else if (metric_string == "IP") { + result = raft::distance::DistanceType::InnerProduct; + } else if (metric_string == "Linf") { + result = raft::distance::DistanceType::Linf; + } else if (metric_string == "Canberra") { + result = raft::distance::DistanceType::Canberra; + } else if (metric_string == "LpUnexpanded") { + result = raft::distance::DistanceType::LpUnexpanded; + } else if (metric_string == "CorrelationExpanded") { + result = raft::distance::DistanceType::CorrelationExpanded; + } else if (metric_string == "JACCARD") { + result = raft::distance::DistanceType::JaccardExpanded; + } else if (metric_string == "HellingerExpanded") { + result = raft::distance::DistanceType::HellingerExpanded; + } else if (metric_string == "Haversine") { + result = raft::distance::DistanceType::Haversine; + } else if (metric_string == "BrayCurtis") { + result = raft::distance::DistanceType::BrayCurtis; + } else if (metric_string == "JensenShannon") { + result = raft::distance::DistanceType::JensenShannon; + } else if (metric_string == "HAMMING") { + result = raft::distance::DistanceType::HammingUnexpanded; + } else if (metric_string == "KLDivergence") { + result = raft::distance::DistanceType::KLDivergence; + } else if (metric_string == "RusselRaoExpanded") { + result = raft::distance::DistanceType::RusselRaoExpanded; + } else if (metric_string == "DiceExpanded") { + result = raft::distance::DistanceType::DiceExpanded; + } else if (metric_string == "Precomputed") { + result = raft::distance::DistanceType::Precomputed; + } else { + RAFT_FAIL("Unrecognized metric type %s", metric_string.c_str()); + } + return result; +} + +[[nodiscard]] inline auto codebook_string_to_raft_codebook_gen(std::string const& codebook_string) { + auto result = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; + if (codebook_string == "PER_SUBSPACE") { + result = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; + } else if (codebook_string == "PER_CLUSTER") { + result = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; + } else { + RAFT_FAIL("Unrecognized codebook type %s", codebook_string.c_str()); + } + return result; +} +[[nodiscard]] inline auto build_algo_string_to_cagra_build_algo(std::string + const& algo_string) { + auto result = raft::neighbors::cagra::graph_build_algo::IVF_PQ; + if (algo_string == "IVF_PQ") { + result = raft::neighbors::cagra::graph_build_algo::IVF_PQ; + } else if (algo_string == "NN_DESCENT") { + result = raft::neighbors::cagra::graph_build_algo::NN_DESCENT; + } else { + RAFT_FAIL("Unrecognized CAGRA build algo %s", algo_string.c_str()); + } + return result; +} + +[[nodiscard]] inline auto search_algo_string_to_cagra_search_algo(std::string + const& algo_string) { + auto result = raft::neighbors::cagra::search_algo::AUTO; + if (algo_string == "SINGLE_CTA") { + result = raft::neighbors::cagra::search_algo::SINGLE_CTA; + } else if (algo_string == "MULTI_CTA") { + result = raft::neighbors::cagra::search_algo::MULTI_CTA; + } else if (algo_string == "MULTI_KERNEL") { + result = raft::neighbors::cagra::search_algo::MULTI_KERNEL; + } else if (algo_string == "AUTO") { + result = raft::neighbors::cagra::search_algo::AUTO; + } else { + RAFT_FAIL("Unrecognized CAGRA search algo %s", algo_string.c_str()); + } + return result; +} + +[[nodiscard]] inline auto hashmap_mode_string_to_cagra_hashmap_mode(std::string + const& mode_string) { + auto result = raft::neighbors::cagra::hash_mode::AUTO; + if (mode_string == "HASH") { + result = raft::neighbors::cagra::hash_mode::HASH; + } else if (mode_string == "SMALL") { + result = raft::neighbors::cagra::hash_mode::SMALL; + } else if (mode_string == "AUTO") { + result = raft::neighbors::cagra::hash_mode::AUTO; + } else { + RAFT_FAIL("Unrecognized CAGRA hash mode %s", mode_string.c_str()); + } + return result; +} + +[[nodiscard]] inline auto dtype_string_to_cuda_dtype(std::string + const& dtype_string) { + auto result = CUDA_R_32F; + if (dtype_string == "CUDA_R_16F") { + result = CUDA_R_16F; + } else if (dtype_string == "CUDA_C_16F") { + result = CUDA_C_16F; + } else if (dtype_string == "CUDA_R_16BF") { + result = CUDA_R_16BF; + } else if (dtype_string == "CUDA_R_32F") { + result = CUDA_R_32F; + } else if (dtype_string == "CUDA_C_32F") { + result = CUDA_C_32F; + } else if (dtype_string == "CUDA_R_64F") { + result = CUDA_R_64F; + } else if (dtype_string == "CUDA_C_64F") { + result = CUDA_C_64F; + } else if (dtype_string == "CUDA_R_8I") { + result = CUDA_R_8I; + } else if (dtype_string == "CUDA_C_8I") { + result = CUDA_C_8I; + } else if (dtype_string == "CUDA_R_8U") { + result = CUDA_R_8U; + } else if (dtype_string == "CUDA_C_8U") { + result = CUDA_C_8U; + } else if (dtype_string == "CUDA_R_32I") { + result = CUDA_R_32I; + } else if (dtype_string == "CUDA_C_32I") { + result = CUDA_C_32I; + } else if (dtype_string == "CUDA_R_8F_E4M3") { + result = CUDA_R_8F_E4M3; + } else if (dtype_string == "CUDA_R_8F_E5M2") { + result = CUDA_R_8F_E5M2; + } else { + RAFT_FAIL("Unrecognized dtype %s", dtype_string.c_str()); + } + return result; +} + +// Given a generic config without RAFT symbols, convert to RAFT index build +// parameters +template +[[nodiscard]] auto config_to_index_params(raft_knowhere_config const& raw_config) { + RAFT_EXPECTS(raw_config.index_type == IndexKind, "Incorrect index type for this index"); + auto config = validate_raft_knowhere_config(raw_config); + auto result = raft_index_params_t{}; + + result.metric = metric_string_to_raft_distance_type(config.metric_type); + result.metric_arg = config.metric_arg; + result.add_data_on_build = config.add_data_on_build; + + if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_flat || IndexKind == + raft_proto::raft_index_kind::ivf_pq) { + result.n_lists = *(config.nlist); + result.kmeans_n_iters = *(config.kmeans_n_iters); + result.kmeans_trainset_fraction = *(config.kmeans_trainset_fraction); + result.conservative_memory_allocation = *(config.conservative_memory_allocation); + } + if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_flat) { + result.adaptive_centers = *(config.adaptive_centers); + } + if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_pq) { + result.pq_dim = *(config.m); + result.pq_bits = *(config.nbits); + result.codebook_kind = codebook_string_to_raft_codebook_gen(*(config.codebook_kind)); + result.force_random_rotation = *(config.force_random_rotation); + } + if constexpr (IndexKind == raft_proto::raft_index_kind::cagra) { + result.intermediate_graph_degree = *(config.intermediate_graph_degree); + result.graph_degree = *(config.graph_degree); + result.build_algo = build_algo_string_to_cagra_build_algo(*(config.build_algo)); + result.nn_descent_niter = *(config.nn_descent_niter); + } + return result; +} + +// Given a generic config without RAFT symbols, convert to RAFT index search +// parameters +template +[[nodiscard]] auto config_to_search_params(raft_knowhere_config const& raw_config) { + RAFT_EXPECTS(raw_config.index_type == IndexKind, "Incorrect index type for this index"); + auto config = validate_raft_knowhere_config(raw_config); + auto result = raft_search_params_t{}; + if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_flat || IndexKind + == raft_proto::raft_index_kind::ivf_pq) { + result.n_probes = *(config.nprobe); + } + if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_pq) { + result.lut_dtype = dtype_string_to_cuda_dtype(*(config.lookup_table_dtype)); + result.internal_distance_dtype = dtype_string_to_cuda_dtype(*(config.internal_distance_dtype)); + result.preferred_shmem_carveout = *(config.preferred_shmem_carveout); + } + if constexpr (IndexKind == raft_proto::raft_index_kind::cagra) { + result.max_queries = *(config.max_queries); + result.itopk_size = *(config.itopk_size); + result.max_iterations = *(config.max_iterations); + result.algo = search_algo_string_to_cagra_search_algo(*(config.search_algo)); + result.team_size = *(config.team_size); + result.search_width = *(config.search_width); + result.min_iterations = *(config.min_iterations); + result.thread_block_size = *(config.thread_block_size); + result.hashmap_mode = hashmap_mode_string_to_cagra_hashmap_mode(*(config.hashmap_mode)); + result.hashmap_min_bitlen = *(config.hashmap_min_bitlen); + result.hashmap_max_fill_rate = *(config.hashmap_max_fill_rate); + } + return result; +} + +inline auto select_device_id() { + auto static device_count = []() { + auto result = 0; + RAFT_CUDA_TRY(cudaGetDeviceCount(&result)); + RAFT_EXPECTS(result != 0, "No CUDA devices found"); + return result; + }(); + auto static index_counter = std::atomic{0}; + // Use round-robin assignment to distribute indexes across devices + auto result = index_counter.fetch_add(1) % device_count; + return result; +} + +// This struct is used to connect knowhere to a RAFT index. The implementation +// is provided here, but this header should never be directly included in +// another knowhere header. This ensures that RAFT symbols are not exposed in +// any knowhere header. +template +struct raft_knowhere_index::impl { + auto static constexpr index_kind = IndexKind; + using data_type = raft_data_t; + using indexing_type = raft_indexing_t; + using input_indexing_type = raft_input_indexing_t; + + impl() {} + + private: + using raft_index_type = raft_index_t; + + public: + auto is_trained() const { + return index_.has_value(); + } + [[nodiscard]] auto size() const { + auto result = std::int64_t{}; + if (is_trained()) { + result = index_->size(); + } + return result; + } + [[nodiscard]] auto dim() const { + auto result = std::int64_t{}; + if (is_trained()) { + result = index_->dim(); + } + return result; + } + + void train(raft_knowhere_config const& config, data_type const* data, + knowhere_indexing_type row_count, knowhere_indexing_type feature_count) { + auto scoped_device = raft::device_setter{device_id}; + auto index_params = config_to_index_params(config); + auto const& res = raft::device_resources_manager::get_device_resources(); + auto host_data = raft::make_host_matrix_view(data, row_count, feature_count); + device_dataset_storage = raft::make_device_matrix(res, row_count, feature_count); + auto device_data = device_dataset_storage->view(); + raft::copy(res, device_data, host_data); + index_ = raft_index_type::template build(res, index_params, raft::make_const_mdspan(device_data)); + } + + void add(data_type const* data, knowhere_indexing_type row_count, + knowhere_indexing_type feature_count, knowhere_indexing_type const* new_ids) { + if constexpr (index_kind == raft_proto::raft_index_kind::cagra) { + if (index_) { + RAFT_FAIL("CAGRA does not support adding vectors after training"); + } + } else if constexpr (index_kind == raft_proto::raft_index_kind::ivf_pq){ + if (index_) { + RAFT_FAIL("IVFPQ does not support adding vectors after training"); + } + } else { + if (index_) { + auto const& res = raft::device_resources_manager::get_device_resources(); + raft::resource::set_workspace_to_pool_resource(res); + auto host_data = raft::make_host_matrix_view(data, row_count, feature_count); + device_dataset_storage = raft::make_device_matrix(res, row_count, feature_count); + auto device_data = device_dataset_storage->view(); + raft::copy(res, device_data, host_data); + auto device_ids_storage = std::optional>{}; + if (new_ids != nullptr) { + auto host_ids = raft::make_host_vector_view(new_ids, row_count); + device_ids_storage = raft::make_device_vector(res, row_count); + raft::copy(res, device_ids_storage->view(), host_ids); + } + + if (device_ids_storage) { + index_ = raft_index_type::extend( + res, + raft::make_const_mdspan(device_data), + std::make_optional(raft::make_const_mdspan(device_ids_storage->view())), + *index_ + ); + } else { + index_ = raft_index_type::extend( + res, + raft::make_const_mdspan(device_data), + std::optional>{}, + *index_ + ); + } + } else { + RAFT_FAIL("Index has not yet been trained"); + } + } + } + + auto search( + raft_knowhere_config const& config, + data_type const* data, + knowhere_indexing_type row_count, + knowhere_indexing_type feature_count, + knowhere_bitset_data_type const * bitset_data, + knowhere_bitset_indexing_type bitset_byte_size, + knowhere_bitset_indexing_type bitset_size + ) const { + auto scoped_device = raft::device_setter{device_id}; + auto const& res = raft::device_resources_manager::get_device_resources(); + auto k = knowhere_indexing_type(config.k); + auto search_params = config_to_search_params(config); + + auto host_data = raft::make_host_matrix_view(data, row_count, feature_count); + auto device_data_storage = raft::make_device_matrix(res, row_count, feature_count); + raft::copy(res, device_data_storage.view(), host_data); + + auto device_bitset = std::optional>{}; + + if(bitset_data != nullptr && bitset_byte_size != 0) { + device_bitset = raft::core::bitset(res, bitset_size); + raft::copy(res, device_bitset->to_mdspan(), raft::make_host_vector_view(bitset_data, bitset_byte_size)); + device_bitset->flip(res); + } + + auto output_size = row_count * k; + auto ids = std::unique_ptr(new knowhere_indexing_type[output_size]); + auto distances = std::unique_ptr(new knowhere_data_type[output_size]); + + auto host_ids = raft::make_host_matrix_view(ids.get(), row_count, k); + auto host_distances = raft::make_host_matrix_view(distances.get(), row_count, k); + + auto device_ids_storage = raft::make_device_matrix(res, row_count, k); + auto device_distances_storage = raft::make_device_matrix(res, row_count, k); + auto device_ids = device_ids_storage.view(); + auto device_distances = device_distances_storage.view(); + + RAFT_EXPECTS(index_, "Index has not yet been trained"); + auto dataset_view = device_dataset_storage ? + std::make_optional(device_dataset_storage->view()) : + std::optional>{}; + + + if (device_bitset) { + raft_index_type::search( + res, + *index_, + search_params, + raft::make_const_mdspan(device_data_storage.view()), + device_ids, + device_distances, + config.refine_ratio, + input_indexing_type{}, + dataset_view, + raft::neighbors::filtering::bitset_filter{ + device_bitset->view() + } + ); + } else { + raft_index_type::search( + res, + *index_, + search_params, + raft::make_const_mdspan(device_data_storage.view()), + device_ids, + device_distances, + config.refine_ratio, + input_indexing_type{}, + dataset_view + ); + } + if constexpr (index_kind == raft_proto::raft_index_kind::ivf_pq) { + thrust::replace( + res.get_thrust_policy(), + thrust::device_ptr(device_ids.data_handle()), + thrust::device_ptr(device_ids.data_handle() + output_size), + raft::neighbors::ivf_pq::kOutOfBoundsRecord, + indexing_type{-1} + ); + } + raft::copy(res, host_ids, device_ids); + raft::copy(res, host_distances, device_distances); + return std::make_tuple(ids.release(), distances.release()); + } + void range_search() const { + RAFT_FAIL("Range search not yet implemented for RAFT indexes"); + } + void get_vector_by_id() const { + RAFT_FAIL("Vector reconstruction not yet implemented for RAFT indexes"); + } + void serialize( + std::ostream& os + ) const { + auto scoped_device = raft::device_setter{device_id}; + auto const& res = raft::device_resources_manager::get_device_resources(); + RAFT_EXPECTS(index_, "Index has not yet been trained"); + raft_index_type::template serialize(res, os, *index_); + } + auto static deserialize( + std::istream& is + ) { + auto new_device_id = select_device_id(); + auto scoped_device = raft::device_setter{new_device_id}; + auto const& res = raft::device_resources_manager::get_device_resources(); + return std::make_unique::impl>( + raft_index_type::template deserialize(res, is), + new_device_id + ); + } + void synchronize() const { + auto scoped_device = raft::device_setter{device_id}; + raft::device_resources_manager::get_device_resources().sync_stream(); + } + impl(raft_index_type&& index, int new_device_id) : index_{std::move(index)}, + device_id{new_device_id} {} + + private: + std::optional index_ = std::nullopt; + int device_id = select_device_id(); + std::optional> device_dataset_storage = std::nullopt; +}; + +template +raft_knowhere_index::raft_knowhere_index() : pimpl{new + raft_knowhere_index::impl()} {} + +template +raft_knowhere_index::~raft_knowhere_index() = default; + +template +raft_knowhere_index::raft_knowhere_index(raft_knowhere_index&& other) : pimpl{std::move(other.pimpl)} {} + +template +raft_knowhere_index& raft_knowhere_index::operator=(raft_knowhere_index&& other) { + pimpl = std::move(other.pimpl); + return *this; +} + +template +bool raft_knowhere_index::is_trained() const { + return pimpl->is_trained(); +} + +template +std::int64_t raft_knowhere_index::size() const { + return pimpl->size(); +} + +template +std::int64_t raft_knowhere_index::dim() const { + return pimpl->dim(); +} + +template +void raft_knowhere_index::train( + raft_knowhere_config const& config, + data_type const* data, + knowhere_indexing_type row_count, + knowhere_indexing_type feature_count +) { + return pimpl->train(config, data, row_count, feature_count); +} +template +void raft_knowhere_index::add( + data_type const* data, + knowhere_indexing_type row_count, + knowhere_indexing_type feature_count, + knowhere_indexing_type const* new_ids +) { + return pimpl->add(data, row_count, feature_count, new_ids); +} +template +std::tuple raft_knowhere_index::search( + raft_knowhere_config const& config, + data_type const* data, + knowhere_indexing_type row_count, + knowhere_indexing_type feature_count, + knowhere_bitset_data_type const* bitset_data, + knowhere_bitset_indexing_type bitset_byte_size, + knowhere_bitset_indexing_type bitset_size +) const { + return pimpl->search(config, data, row_count, feature_count, bitset_data, bitset_byte_size, bitset_size); +} + +template +void raft_knowhere_index::range_search() const { + return pimpl->range_search(); +} + +template +void raft_knowhere_index::get_vector_by_id() const { + return pimpl->get_vector_by_id(); +} + +template +void raft_knowhere_index::serialize(std::ostream& os) const { + return pimpl->serialize(os); +} + +template +raft_knowhere_index +raft_knowhere_index::deserialize(std::istream& is) { + return raft_knowhere_index(raft_knowhere_index::impl::deserialize(is)); +} + +template +void raft_knowhere_index::synchronize() const { + return pimpl->synchronize(); +} + +} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_knowhere_index.hpp b/src/common/raft/integration/raft_knowhere_index.hpp new file mode 100644 index 000000000..f26a7f0c7 --- /dev/null +++ b/src/common/raft/integration/raft_knowhere_index.hpp @@ -0,0 +1,124 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include + +#include "common/raft/integration/raft_knowhere_config.hpp" +#include "common/raft/proto/raft_index_kind.hpp" + +namespace raft_knowhere { + +using knowhere_data_type = float; +using knowhere_indexing_type = std::int64_t; +using knowhere_bitset_data_type = std::uint8_t; +using knowhere_bitset_indexing_type = std::uint32_t; + +namespace detail { + +template +struct raft_io_type_mapper : std::false_type {}; + +template <> +struct raft_io_type_mapper : std::true_type { + using data_type = float; + using indexing_type = std::int64_t; + using input_indexing_type = std::int64_t; +}; + +template <> +struct raft_io_type_mapper : std::true_type { + using data_type = float; + using indexing_type = std::int64_t; + using input_indexing_type = std::int64_t; +}; + +template <> +struct raft_io_type_mapper : std::true_type { + using data_type = float; + using indexing_type = std::uint32_t; + using input_indexing_type = std::int64_t; +}; + +} // namespace detail + +template +using raft_data_t = typename detail::raft_io_type_mapper::data_type; + +template +using raft_indexing_t = typename detail::raft_io_type_mapper::indexing_type; + +template +using raft_input_indexing_t = typename detail::raft_io_type_mapper::input_indexing_type; + +template +struct raft_knowhere_index { + auto static constexpr index_kind = IndexKind; + + using data_type = raft_data_t; + using indexing_type = raft_indexing_t; + using input_indexing_type = raft_input_indexing_t; + + raft_knowhere_index(); + ~raft_knowhere_index(); + + raft_knowhere_index(raft_knowhere_index&& other); + raft_knowhere_index& + operator=(raft_knowhere_index&& other); + + bool + is_trained() const; + std::int64_t + size() const; + std::int64_t + dim() const; + + void + train(raft_knowhere_config const&, data_type const*, knowhere_indexing_type, knowhere_indexing_type); + void + add(data_type const* data, knowhere_indexing_type row_count, knowhere_indexing_type feature_count, + knowhere_indexing_type const* new_ids = nullptr); + std::tuple + search(raft_knowhere_config const& config, data_type const* data, knowhere_indexing_type row_count, + knowhere_indexing_type feature_count, knowhere_bitset_data_type const* bitset_data = nullptr, + knowhere_bitset_indexing_type bitset_byte_size = knowhere_bitset_indexing_type{}, + knowhere_bitset_indexing_type bitset_size = knowhere_bitset_indexing_type{}) const; + void + range_search() const; + void + get_vector_by_id() const; + void + serialize(std::ostream& os) const; + static raft_knowhere_index + deserialize(std::istream& is); + void + synchronize() const; + + private: + // Use a private implementation to completely separate knowhere headers from + // RAFT headers + struct impl; + std::unique_ptr pimpl; + + raft_knowhere_index(std::unique_ptr&& new_pimpl) : pimpl{std::move(new_pimpl)} { + } +}; + +extern template struct raft_knowhere_index; +extern template struct raft_knowhere_index; +extern template struct raft_knowhere_index; + +} // namespace raft_knowhere diff --git a/src/common/raft/proto/raft_index.cuh b/src/common/raft/proto/raft_index.cuh new file mode 100644 index 000000000..ff25b7e04 --- /dev/null +++ b/src/common/raft/proto/raft_index.cuh @@ -0,0 +1,583 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/raft/proto/raft_index_kind.hpp" + +namespace raft_proto { + +auto static const RAFT_NAME = raft::RAFT_NAME; + +namespace detail { +template typename index_template> +struct template_matches_index_kind : std::false_type{}; + +template<> +struct template_matches_index_kind : std::true_type{}; + +template<> +struct template_matches_index_kind : std::true_type{}; + +template<> +struct template_matches_index_kind : std::true_type{}; + +template typename index_template> +auto static constexpr template_matches_index_kind_v = template_matches_index_kind::value; + +} + +template