From 5483b82d469c55e288ee01906ad15fb85f5c8edf Mon Sep 17 00:00:00 2001 From: William Hicks Date: Fri, 3 Nov 2023 20:06:41 -0400 Subject: [PATCH 1/4] Add CAGRA support with latest RAFT Update to RAFT 23.12 Update CAGRA integration to improve performance Avoid post-filtering using RAFT's new filtering feature Use RAFT's new device_resources_manager to simplify and optimize resource initialization Update build infratructure to build for all supported CUDA architectures Refactor RAFT integration code to more cleanly separate RAFT code from Knowhere code Avoid exposing RAFT symbols in any Knowhere header Signed-off-by: William Hicks --- CMakeLists.txt | 40 +- benchmark/hdf5/benchmark_float_qps.cpp | 81 ++- benchmark/hdf5/benchmark_knowhere.h | 3 +- cmake/libs/libraft.cmake | 21 +- .../librapids.cmake} | 14 +- include/knowhere/comp/index_param.h | 29 + include/knowhere/comp/knowhere_config.h | 6 + include/knowhere/device_bitset.h | 90 --- src/common/comp/knowhere_config.cc | 19 +- src/common/raft/integration/cagra_index.cu | 21 + src/common/raft/integration/ivf_flat_index.cu | 21 + src/common/raft/integration/ivf_pq_index.cu | 21 + .../raft/integration/raft_initialization.cc | 85 +++ .../raft/integration/raft_initialization.hpp | 30 + .../raft/integration/raft_knowhere_config.hpp | 118 ++++ .../raft/integration/raft_knowhere_index.cuh | 620 +++++++++++++++++ .../raft/integration/raft_knowhere_index.hpp | 125 ++++ .../raft/proto/ivf_to_sample_filter.cuh | 39 ++ src/common/raft/proto/raft_index.cuh | 562 ++++++++++++++++ src/common/raft/proto/raft_index_kind.hpp | 25 + src/common/raft/raft.cu | 39 -- src/common/raft/raft_utils.cc | 46 -- src/common/raft/raft_utils.h | 200 ------ src/index/cagra/cagra.cu | 210 ------ src/index/cagra/cagra_config.h | 47 -- src/index/gpu_raft/gpu_raft.h | 292 ++++++++ src/index/gpu_raft/gpu_raft_cagra.cc | 32 + src/index/gpu_raft/gpu_raft_cagra_config.h | 144 ++++ .../gpu_raft_ivf_flat.cc} | 23 +- src/index/gpu_raft/gpu_raft_ivf_flat_config.h | 75 +++ src/index/gpu_raft/gpu_raft_ivf_pq.cc | 32 + .../gpu_raft_ivf_pq_config.h} | 102 +-- src/index/ivf_raft/ivf_raft.cuh | 631 ------------------ tests/ut/test_gpu_search.cc | 49 +- 34 files changed, 2502 insertions(+), 1390 deletions(-) rename cmake/{utils/fetch_rapids.cmake => libs/librapids.cmake} (71%) delete mode 100644 include/knowhere/device_bitset.h create mode 100644 src/common/raft/integration/cagra_index.cu create mode 100644 src/common/raft/integration/ivf_flat_index.cu create mode 100644 src/common/raft/integration/ivf_pq_index.cu create mode 100644 src/common/raft/integration/raft_initialization.cc create mode 100644 src/common/raft/integration/raft_initialization.hpp create mode 100644 src/common/raft/integration/raft_knowhere_config.hpp create mode 100644 src/common/raft/integration/raft_knowhere_index.cuh create mode 100644 src/common/raft/integration/raft_knowhere_index.hpp create mode 100644 src/common/raft/proto/ivf_to_sample_filter.cuh create mode 100644 src/common/raft/proto/raft_index.cuh create mode 100644 src/common/raft/proto/raft_index_kind.hpp delete mode 100644 src/common/raft/raft.cu delete mode 100644 src/common/raft/raft_utils.cc delete mode 100644 src/common/raft/raft_utils.h delete mode 100644 src/index/cagra/cagra.cu delete mode 100644 src/index/cagra/cagra_config.h create mode 100644 src/index/gpu_raft/gpu_raft.h create mode 100644 src/index/gpu_raft/gpu_raft_cagra.cc create mode 100644 src/index/gpu_raft/gpu_raft_cagra_config.h rename src/index/{ivf_raft/ivf_raft.cu => gpu_raft/gpu_raft_ivf_flat.cc} (50%) create mode 100644 src/index/gpu_raft/gpu_raft_ivf_flat_config.h create mode 100644 src/index/gpu_raft/gpu_raft_ivf_pq.cc rename src/index/{ivf_raft/ivf_raft_config.h => gpu_raft/gpu_raft_ivf_pq_config.h} (55%) delete mode 100644 src/index/ivf_raft/ivf_raft.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index ace217aa5..e0de3e452 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,8 +12,7 @@ # License for the specific language governing permissions and limitations under # the License -cmake_minimum_required(VERSION 3.23.0 FATAL_ERROR) -project(knowhere CXX C) +cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules/") @@ -21,10 +20,20 @@ include(GNUInstallDirs) include(ExternalProject) include(cmake/utils/utils.cmake) +knowhere_option(WITH_RAFT "Build with RAFT indexes" OFF) +if (WITH_RAFT) + if("${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "") + set(CMAKE_CUDA_ARCHITECTURES RAPIDS) + endif() + include(cmake/libs/librapids.cmake) + project(knowhere CXX C CUDA) +else() + project(knowhere CXX C) +endif() + knowhere_option(WITH_UT "Build with UT test" OFF) knowhere_option(WITH_ASAN "Build with ASAN" OFF) knowhere_option(WITH_DISKANN "Build with diskann index" OFF) -knowhere_option(WITH_RAFT "Build with RAFT indexes" OFF) knowhere_option(WITH_BENCHMARK "Build with benchmark" OFF) knowhere_option(WITH_COVERAGE "Build with coverage" OFF) knowhere_option(WITH_CCACHE "Build with ccache" ON) @@ -49,18 +58,6 @@ endif() list( APPEND CMAKE_MODULE_PATH ${CMAKE_BINARY_DIR}/) -if(WITH_RAFT) - if("${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "") - set(CMAKE_CUDA_ARCHITECTURES 86;80;75;70;61) - endif() - enable_language(CUDA) - find_package(CUDAToolkit REQUIRED) - if(${CUDAToolkit_VERSION_MAJOR} GREATER 10) - # cuda11 support --threads for compile some large .cu more efficient - add_compile_options($<$:--threads=4>) - endif() -endif() - add_definitions(-DNOT_COMPILE_FOR_SWIG) include(cmake/utils/compile_flags.cmake) @@ -99,8 +96,7 @@ if(WITH_COVERAGE) endif() knowhere_file_glob(GLOB_RECURSE KNOWHERE_SRCS src/common/*.cc src/index/*.cc - src/io/*.cc src/index/*.cu src/common/raft/*.cu - src/common/raft/*.cc) + src/io/*.cc src/common/*.cu src/index/*.cu src/io/*.cu) set(KNOWHERE_LINKER_LIBS "") @@ -113,13 +109,13 @@ else() endif() knowhere_file_glob(GLOB_RECURSE KNOWHERE_GPU_SRCS src/index/gpu/flat_gpu/*.cc - src/index/gpu/ivf_gpu/*.cc src/index/cagra/*.cu) + src/index/gpu/ivf_gpu/*.cc) list(REMOVE_ITEM KNOWHERE_SRCS ${KNOWHERE_GPU_SRCS}) if(NOT WITH_RAFT) - knowhere_file_glob(GLOB_RECURSE KNOWHERE_RAFT_SRCS src/index/ivf_raft/*.cc - src/index/ivf_raft/*.cu src/index/cagra/*.cu - src/common/raft/*.cu src/common/raft/*.cc) + knowhere_file_glob(GLOB_RECURSE KNOWHERE_RAFT_SRCS + src/common/raft/*.cu src/common/raft/*.cc + src/index/gpu_raft/*.cc) list(REMOVE_ITEM KNOWHERE_SRCS ${KNOWHERE_RAFT_SRCS}) endif() @@ -135,7 +131,7 @@ list(APPEND KNOWHERE_LINKER_LIBS ${FOLLY_LIBRARIES}) add_library(knowhere SHARED ${KNOWHERE_SRCS}) add_dependencies(knowhere ${KNOWHERE_LINKER_LIBS}) if(WITH_RAFT) - list(APPEND KNOWHERE_LINKER_LIBS raft::raft) + list(APPEND KNOWHERE_LINKER_LIBS raft::raft CUDA::cublas CUDA::cusparse CUDA::cusolver) endif() target_link_libraries(knowhere PUBLIC ${KNOWHERE_LINKER_LIBS}) target_include_directories( diff --git a/benchmark/hdf5/benchmark_float_qps.cpp b/benchmark/hdf5/benchmark_float_qps.cpp index 2022f4468..244ef3e9b 100644 --- a/benchmark/hdf5/benchmark_float_qps.cpp +++ b/benchmark/hdf5/benchmark_float_qps.cpp @@ -13,6 +13,7 @@ #include #include +#include #include "benchmark_knowhere.h" #include "knowhere/comp/index_param.h" @@ -73,6 +74,61 @@ class Benchmark_float_qps : public Benchmark_knowhere, public ::testing::Test { } } + void + test_cagra(const knowhere::Json& cfg) { + auto conf = cfg; + + auto find_smallest_max_iters = [&](float expected_recall) -> int32_t { + auto ds_ptr = knowhere::GenDataSet(nq_, dim_, xq_); + auto left = 32; + auto right = 256; + auto max_iterations = left; + + float recall; + while (left <= right) { + max_iterations = left + (right - left) / 2; + conf[knowhere::indexparam::MAX_ITERATIONS] = max_iterations; + + auto result = index_.Search(*ds_ptr, conf, nullptr); + recall = CalcRecall(result.value()->GetIds(), nq_, topk_); + printf( + "[%0.3f s] iterate CAGRA param for recall %.4f: max_iterations=%d, k=%d, " + "R@=%.4f\n", + get_time_diff(), expected_recall, max_iterations, topk_, recall); + std::fflush(stdout); + if (std::abs(recall - expected_recall) <= 0.0001) { + return max_iterations; + } + if (recall < expected_recall) { + left = max_iterations + 1; + } else { + right = max_iterations - 1; + } + } + return left; + }; + + for (auto expected_recall : EXPECTED_RECALLs_) { + conf[knowhere::indexparam::ITOPK_SIZE] = ((int{topk_} + 32 - 1) / 32) * 32; + conf[knowhere::meta::TOPK] = topk_; + conf[knowhere::indexparam::MAX_ITERATIONS] = find_smallest_max_iters(expected_recall); + + printf( + "\n[%0.3f s] %s | %s | k=%d, " + "R@=%.4f\n", + get_time_diff(), ann_test_name_.c_str(), index_type_.c_str(), topk_, + expected_recall); + printf("================================================================================\n"); + for (auto thread_num : THREAD_NUMs_) { + CALC_TIME_SPAN(task(conf, thread_num, nq_)); + printf(" thread_num = %2d, elapse = %6.3fs, VPS = %.3f\n", thread_num, t_diff, nq_ / t_diff); + std::fflush(stdout); + } + printf("================================================================================\n"); + printf("[%.3f s] Test '%s/%s' done\n\n", get_time_diff(), ann_test_name_.c_str(), index_type_.c_str()); + } + } + void test_hnsw(const knowhere::Json& cfg) { auto conf = cfg; @@ -183,10 +239,12 @@ class Benchmark_float_qps : public Benchmark_knowhere, public ::testing::Test { private: void task(const knowhere::Json& conf, int32_t worker_num, int32_t nq_total) { + NVTX3_FUNC_RANGE(); auto worker = [&](int32_t idx_start, int32_t num) { num = std::min(num, nq_total - idx_start); for (int32_t i = 0; i < num; i++) { knowhere::DataSetPtr ds_ptr = knowhere::GenDataSet(1, dim_, (const float*)xq_ + (idx_start + i) * dim_); + auto loop_range = nvtx3::scoped_range{"loop range"}; index_.Search(*ds_ptr, conf, nullptr); } }; @@ -221,6 +279,10 @@ class Benchmark_float_qps : public Benchmark_knowhere, public ::testing::Test { #ifdef KNOWHERE_WITH_GPU knowhere::KnowhereConfig::InitGPUResource(GPU_DEVICE_ID, 2); cfg_[knowhere::meta::DEVICE_ID] = GPU_DEVICE_ID; +#endif +#ifdef KNOWHERE_WITH_RAFT + // knowhere::KnowhereConfig::SetRaftMemPool(24576, 36864); + knowhere::KnowhereConfig::SetRaftMemPool(); #endif } @@ -251,6 +313,9 @@ class Benchmark_float_qps : public Benchmark_knowhere, public ::testing::Test { // SCANN index params const std::vector SCANN_REORDER_K = {256, 512, 768, 1024}; const std::vector SCANN_WITH_RAW_DATA = {true}; + + // CAGRA index params + const std::vector GRAPH_DEGREE_ = {32, 64}; }; TEST_F(Benchmark_float_qps, TEST_IVF_FLAT) { @@ -271,7 +336,7 @@ TEST_F(Benchmark_float_qps, TEST_IVF_FLAT) { } } -TEST_F(Benchmark_float_qps, TEST_IVF_SQ8) { +/* TEST_F(Benchmark_float_qps, TEST_IVF_SQ8) { #ifdef KNOWHERE_WITH_GPU index_type_ = knowhere::IndexEnum::INDEX_FAISS_GPU_IVFSQ8; #else @@ -285,7 +350,7 @@ TEST_F(Benchmark_float_qps, TEST_IVF_SQ8) { create_index(index_file_name, conf); test_ivf(conf); } -} +} */ TEST_F(Benchmark_float_qps, TEST_IVF_PQ) { #ifdef KNOWHERE_WITH_GPU @@ -344,3 +409,15 @@ TEST_F(Benchmark_float_qps, TEST_SCANN) { } } } +TEST_F(Benchmark_float_qps, TEST_CAGRA) { + index_type_ = knowhere::IndexEnum::INDEX_RAFT_CAGRA; + knowhere::Json conf = cfg_; + for (auto gd : GRAPH_DEGREE_) { + conf[knowhere::indexparam::GRAPH_DEGREE] = gd; + conf[knowhere::indexparam::INTERMEDIATE_GRAPH_DEGREE] = gd; + conf[knowhere::indexparam::MAX_ITERATIONS] = 64; + std::string index_file_name = get_index_name({gd}); + create_index(index_file_name, conf); + test_cagra(conf); + } +} diff --git a/benchmark/hdf5/benchmark_knowhere.h b/benchmark/hdf5/benchmark_knowhere.h index 0505d87e2..640e57f8b 100644 --- a/benchmark/hdf5/benchmark_knowhere.h +++ b/benchmark/hdf5/benchmark_knowhere.h @@ -77,11 +77,10 @@ class Benchmark_knowhere : public Benchmark_hdf5 { // IVFFLAT_NM should load raw data if (index_type_ == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT && binary_set.GetByName("RAW_DATA") == nullptr) { knowhere::BinaryPtr bin = std::make_shared(); - bin->data = std::shared_ptr((uint8_t*)xb_); + bin->data = std::shared_ptr((uint8_t*)xb_, [](const uint8_t[]) {}); bin->size = dim_ * nb_ * sizeof(float); binary_set.Append("RAW_DATA", bin); } - index.Deserialize(binary_set, conf); } diff --git a/cmake/libs/libraft.cmake b/cmake/libs/libraft.cmake index 3de1cfaec..4cdeca973 100644 --- a/cmake/libs/libraft.cmake +++ b/cmake/libs/libraft.cmake @@ -14,22 +14,15 @@ # the License. add_definitions(-DKNOWHERE_WITH_RAFT) -include(cmake/utils/fetch_rapids.cmake) -include(rapids-cmake) -include(rapids-cpm) -include(rapids-cuda) -include(rapids-export) -include(rapids-find) - -rapids_cpm_init() +set(RAFT_VERSION "${RAPIDS_VERSION}") +set(RAFT_FORK "wphicks") +set(RAFT_PINNED_TAG "bug-ivf_flat_filter") -set(CMAKE_CUDA_FLAGS - "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") -set(RAPIDS_VERSION 23.04) -set(RAFT_VERSION "${RAPIDS_VERSION}") -set(RAFT_FORK "rapidsai") -set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") +rapids_find_package(CUDAToolkit REQUIRED + BUILD_EXPORT_SET knowhere-exports + INSTALL_EXPORT_SET knowhere-exports +) function(find_and_configure_raft) set(oneValueArgs VERSION FORK PINNED_TAG) diff --git a/cmake/utils/fetch_rapids.cmake b/cmake/libs/librapids.cmake similarity index 71% rename from cmake/utils/fetch_rapids.cmake rename to cmake/libs/librapids.cmake index 56899f2c5..315c4c3d2 100644 --- a/cmake/utils/fetch_rapids.cmake +++ b/cmake/libs/librapids.cmake @@ -13,7 +13,7 @@ # License for the specific language governing permissions and limitations under # the License. -set(RAPIDS_VERSION "23.04") +set(RAPIDS_VERSION 23.12) if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) file( @@ -22,3 +22,15 @@ if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) endif() include(${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) + +include(rapids-cpm) # Dependency tracking +include(rapids-find) # Wrappers for finding packages +include(rapids-cuda) # Common CMake CUDA logic + +rapids_cuda_init_architectures(knowhere) +message(STATUS "INIT: ${CMAKE_CUDA_ARCHITECTURES}") + +rapids_cpm_init() + +set(CMAKE_CUDA_FLAGS + "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 986d07241..be1cbf0ec 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -80,6 +80,35 @@ constexpr const char* M = "m"; // PQ param for IVFPQ constexpr const char* SSIZE = "ssize"; constexpr const char* REORDER_K = "reorder_k"; constexpr const char* WITH_RAW_DATA = "with_raw_data"; +// RAFT Params +constexpr const char* REFINE_RATIO = "refine_ratio"; +// RAFT-specific IVF Params +constexpr const char* KMEANS_N_ITERS = "kmeans_n_iters"; +constexpr const char* KMEANS_TRAINSET_FRACTION = "kmeans_trainset_fraction"; +constexpr const char* ADAPTIVE_CENTERS = "adaptive_centers"; // IVF FLAT +constexpr const char* CODEBOOK_KIND = "codebook_kind"; // IVF PQ +constexpr const char* FORCE_RANDOM_ROTATION = "force_random_rotation"; // IVF PQ +constexpr const char* CONSERVATIVE_MEMORY_ALLOCATION = "conservative_memory_allocation"; // IVF PQ +constexpr const char* LUT_DTYPE = "lut_dtype"; // IVF PQ +constexpr const char* INTERNAL_DISTANCE_DTYPE = "internal_distance_dtype"; // IVF PQ +constexpr const char* PREFERRED_SHMEM_CARVEOUT = "preferred_shmem_carveout"; // IVF PQ + +// CAGRA Params +constexpr const char* INTERMEDIATE_GRAPH_DEGREE = "intermediate_graph_degree"; +constexpr const char* GRAPH_DEGREE = "graph_degree"; +constexpr const char* ITOPK_SIZE = "itopk_size"; +constexpr const char* MAX_QUERIES = "max_queries"; +constexpr const char* BUILD_ALGO = "build_algo"; +constexpr const char* SEARCH_ALGO = "search_algo"; +constexpr const char* TEAM_SIZE = "team_size"; +constexpr const char* SEARCH_WIDTH = "search_width"; +constexpr const char* MIN_ITERATIONS = "min_iterations"; +constexpr const char* MAX_ITERATIONS = "max_iterations"; +constexpr const char* THREAD_BLOCK_SIZE = "thread_block_size"; +constexpr const char* HASHMAP_MODE = "hashmap_mode"; +constexpr const char* HASHMAP_MIN_BITLEN = "hashmap_min_bitlen"; +constexpr const char* HASHMAP_MAX_FILL_RATE = "hashmap_max_fill_rate"; +constexpr const char* NN_DESCENT_NITER = "nn_descent_niter"; // HNSW Params constexpr const char* EFCONSTRUCTION = "efConstruction"; diff --git a/include/knowhere/comp/knowhere_config.h b/include/knowhere/comp/knowhere_config.h index de845b6cf..ba224cca7 100644 --- a/include/knowhere/comp/knowhere_config.h +++ b/include/knowhere/comp/knowhere_config.h @@ -107,6 +107,12 @@ class KnowhereConfig { */ static void SetRaftMemPool(size_t init_size, size_t max_size); + + /** + * Initialize RAFT with defaults + */ + static void + SetRaftMemPool(); }; } // namespace knowhere diff --git a/include/knowhere/device_bitset.h b/include/knowhere/device_bitset.h deleted file mode 100644 index 12532dc4e..000000000 --- a/include/knowhere/device_bitset.h +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (C) 2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#ifndef DEVICE_BITSET_H -#define DEVICE_BITSET_H - -#include "knowhere/bitsetview.h" -#include "raft/core/device_mdarray.hpp" -#include "raft/core/device_resources.hpp" -#include "raft/util/cudart_utils.hpp" - -namespace knowhere { - -struct DeviceBitsetView { - __device__ __host__ - DeviceBitsetView(const DeviceBitsetView& other) - : bits_{other.data()}, num_bits_{other.size()} { - } - __device__ __host__ - DeviceBitsetView(const uint8_t* data, size_t num_bits = size_t{}) - : bits_{data}, num_bits_{num_bits} { - } - - __device__ __host__ bool - empty() const { - return num_bits_ == 0; - } - - __device__ __host__ size_t - size() const { - return num_bits_; - } - - __device__ __host__ size_t - byte_size() const { - return (num_bits_ + 8 - 1) >> 3; - } - - __device__ __host__ const uint8_t* - data() const { - return bits_; - } - - __device__ bool - test(int64_t index) const { - auto result = false; - if (index < num_bits_) { - result = bits_[index >> 3] & (0x1 << (index & 0x7)); - } - return result; - } - - private: - const uint8_t* bits_ = nullptr; - size_t num_bits_ = 0; -}; - -struct DeviceBitset { - DeviceBitset(raft::device_resources& res, BitsetView const& other) - : storage_{[&res, &other]() { - auto result = raft::make_device_vector(res, other.byte_size()); - if (!other.empty()) { - raft::copy(result.data_handle(), other.data(), other.byte_size(), res.get_stream()); - } - return result; - }()}, - num_bits_{other.size()} { - } - - auto - view() { - return DeviceBitsetView{storage_.data_handle(), num_bits_}; - } - - private: - raft::device_vector storage_; - size_t num_bits_; -}; - -} // namespace knowhere - -#endif /* DEVICE_BITSET_H */ diff --git a/src/common/comp/knowhere_config.cc b/src/common/comp/knowhere_config.cc index 126529d21..2596407f5 100644 --- a/src/common/comp/knowhere_config.cc +++ b/src/common/comp/knowhere_config.cc @@ -23,10 +23,10 @@ #ifdef KNOWHERE_WITH_GPU #include "index/gpu/gpu_res_mgr.h" #endif -#include "simd/hook.h" #ifdef KNOWHERE_WITH_RAFT -#include "common/raft/raft_utils.h" +#include "common/raft/integration/raft_initialization.hpp" #endif +#include "simd/hook.h" namespace knowhere { @@ -163,7 +163,20 @@ KnowhereConfig::FreeGPUResource() { void KnowhereConfig::SetRaftMemPool(size_t init_size, size_t max_size) { #ifdef KNOWHERE_WITH_RAFT - raft_utils::set_mem_pool_size(init_size, max_size); + auto config = raft_knowhere::raft_configuration{}; + config.init_mem_pool_size_mb = init_size; + config.max_mem_pool_size_mb = max_size; + // This should probably be a separate configuration option, but fine for now + config.max_workspace_size_mb = max_size; + raft_knowhere::initialize_raft(config); +#endif +} +void +KnowhereConfig::SetRaftMemPool() { + // Overload for default values +#ifdef KNOWHERE_WITH_RAFT + auto config = raft_knowhere::raft_configuration{}; + raft_knowhere::initialize_raft(config); #endif } diff --git a/src/common/raft/integration/cagra_index.cu b/src/common/raft/integration/cagra_index.cu new file mode 100644 index 000000000..4879f5c39 --- /dev/null +++ b/src/common/raft/integration/cagra_index.cu @@ -0,0 +1,21 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/raft/proto/raft_index_kind.hpp" +#include "common/raft/integration/raft_knowhere_index.cuh" +namespace raft_knowhere { +template struct raft_knowhere_index; +} // namespace raft_knowhere diff --git a/src/common/raft/integration/ivf_flat_index.cu b/src/common/raft/integration/ivf_flat_index.cu new file mode 100644 index 000000000..fc759075d --- /dev/null +++ b/src/common/raft/integration/ivf_flat_index.cu @@ -0,0 +1,21 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/raft/proto/raft_index_kind.hpp" +#include "common/raft/integration/raft_knowhere_index.cuh" +namespace raft_knowhere { +template struct raft_knowhere_index; +} // namespace raft_knowhere diff --git a/src/common/raft/integration/ivf_pq_index.cu b/src/common/raft/integration/ivf_pq_index.cu new file mode 100644 index 000000000..9284a0930 --- /dev/null +++ b/src/common/raft/integration/ivf_pq_index.cu @@ -0,0 +1,21 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/raft/proto/raft_index_kind.hpp" +#include "common/raft/integration/raft_knowhere_index.cuh" +namespace raft_knowhere { +template struct raft_knowhere_index; +} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_initialization.cc b/src/common/raft/integration/raft_initialization.cc new file mode 100644 index 000000000..591eecca0 --- /dev/null +++ b/src/common/raft/integration/raft_initialization.cc @@ -0,0 +1,85 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include "common/raft/integration/raft_initialization.hpp" +namespace raft_knowhere { + +void initialize_raft(raft_configuration const& config) { + auto static initialization_flag = std::once_flag{}; + std::call_once( + initialization_flag, + [&config]() { + raft::device_resources_manager::set_streams_per_device(config.streams_per_device); + if (config.stream_pool_size) { + raft::device_resources_manager::set_stream_pools_per_device( + config.stream_pools_per_device, + *(config.stream_pool_size) + ); + } else { + raft::device_resources_manager::set_stream_pools_per_device(config.stream_pools_per_device); + } + if (config.init_mem_pool_size_mb) { + raft::device_resources_manager::set_init_mem_pool_size(*(config.init_mem_pool_size_mb) << 20); + } + if (config.max_mem_pool_size_mb) { + if (*config.max_mem_pool_size_mb > 0) { + raft::device_resources_manager::set_max_mem_pool_size(*(config.max_mem_pool_size_mb) << 20); + } + } else { + raft::device_resources_manager::set_max_mem_pool_size(std::nullopt); + } + if (config.max_workspace_size_mb) { + raft::device_resources_manager::set_workspace_allocation_limit(*(config.max_workspace_size_mb) << 20); + } + auto device_count = []() { + auto result = 0; + RAFT_CUDA_TRY(cudaGetDeviceCount(&result)); + RAFT_EXPECTS(result != 0, "No CUDA devices found"); + return result; + }(); + + for (auto device_id = 0; device_id < device_count; ++device_id) { + auto scoped_device = raft::device_setter{device_id}; + auto workspace_size = std::size_t{}; + if (config.max_workspace_size_mb) { + workspace_size = *(config.max_workspace_size_mb) << 20; + } else { + auto free_mem = std::size_t{}; + auto total_mem = std::size_t{}; + RAFT_CUDA_TRY_NO_THROW(cudaMemGetInfo(&free_mem, &total_mem)); + // Heuristic: If workspace size is not explicitly specified, use half of free memory or a quarter of total + // memory, whichever is larger + workspace_size = std::max(free_mem / std::size_t{2}, total_mem / std::size_t{4}); + } + if(workspace_size > std::size_t{}) { + raft::device_resources_manager::set_workspace_memory_resource( + raft::resource::workspace_resource_factory::default_pool_resource( + workspace_size + ), + device_id + ); + } + } + } + ); +} + +} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_initialization.hpp b/src/common/raft/integration/raft_initialization.hpp new file mode 100644 index 000000000..dbb231166 --- /dev/null +++ b/src/common/raft/integration/raft_initialization.hpp @@ -0,0 +1,30 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +namespace raft_knowhere { +struct raft_configuration { + std::size_t streams_per_device = std::size_t{16}; + std::size_t stream_pools_per_device = std::size_t{}; + std::optional stream_pool_size = std::nullopt; + std::optional init_mem_pool_size_mb = std::nullopt; + std::optional max_mem_pool_size_mb = std::nullopt; + std::optional max_workspace_size_mb = std::nullopt; +}; + +void initialize_raft(raft_configuration const& config); +} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_knowhere_config.hpp b/src/common/raft/integration/raft_knowhere_config.hpp new file mode 100644 index 000000000..475eb1332 --- /dev/null +++ b/src/common/raft/integration/raft_knowhere_config.hpp @@ -0,0 +1,118 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +#include "common/raft/proto/raft_index_kind.hpp" + +namespace raft_knowhere { +// This struct includes all parameters that may be passed to underlying RAFT +// indexes. It is designed to not expose ANY RAFT types in order to cleanly +// separate RAFT from knowhere headers. +struct raft_knowhere_config { + raft_proto::raft_index_kind index_type; + int k = 10; + + // Common Parameters + std::string metric_type = std::string{"L2"}; + float metric_arg = 2.0f; + bool add_data_on_build = true; + float refine_ratio = 1.0f; + + // Shared IVF Parameters + std::optional nlist = std::nullopt; + std::optional nprobe = std::nullopt; + std::optional kmeans_n_iters = std::nullopt; + std::optional kmeans_trainset_fraction = std::nullopt; + + // IVF Flat only Parameters + std::optional adaptive_centers = std::nullopt; + + // IVFPQ only Parameters + std::optional m = std::nullopt; + std::optional nbits = std::nullopt; + std::optional codebook_kind = std::nullopt; + std::optional force_random_rotation = std::nullopt; + std::optional conservative_memory_allocation = std::nullopt; + std::optional lookup_table_dtype = std::nullopt; + std::optional internal_distance_dtype = std::nullopt; + std::optional preferred_shmem_carveout = std::nullopt; + + // CAGRA Parameters + std::optional intermediate_graph_degree = std::nullopt; + std::optional graph_degree = std::nullopt; + std::optional itopk_size = std::nullopt; + std::optional max_queries = std::nullopt; + std::optional build_algo = std::nullopt; + std::optional search_algo = std::nullopt; + std::optional team_size = std::nullopt; + std::optional search_width = std::nullopt; + std::optional min_iterations = std::nullopt; + std::optional max_iterations = std::nullopt; + std::optional thread_block_size = std::nullopt; + std::optional hashmap_mode = std::nullopt; + std::optional hashmap_min_bitlen = std::nullopt; + std::optional hashmap_max_fill_rate = std::nullopt; + std::optional nn_descent_niter = std::nullopt; +}; + +// The following function provides a single source of truth for default values +// of RAFT index configurations. +[[nodiscard]] inline auto validate_raft_knowhere_config(raft_knowhere_config config) { + if (config.index_type == raft_proto::raft_index_kind::ivf_flat || config.index_type == raft_proto::raft_index_kind::ivf_pq) { + config.add_data_on_build = true; + config.nlist = config.nlist.value_or(128); + config.nprobe = config.nprobe.value_or(8); + config.kmeans_n_iters = config.kmeans_n_iters.value_or(20); + config.kmeans_trainset_fraction = config.kmeans_trainset_fraction.value_or(0.5f); + } + if (config.index_type == raft_proto::raft_index_kind::ivf_flat) { + config.adaptive_centers = config.adaptive_centers.value_or(false); + } + if (config.index_type == raft_proto::raft_index_kind::ivf_pq) { + config.m = config.m.value_or(0); + config.nbits = config.nbits.value_or(8); + config.codebook_kind = config.codebook_kind.value_or("PER_SUBSPACE"); + config.force_random_rotation = config.force_random_rotation.value_or(false); + config.conservative_memory_allocation = config.conservative_memory_allocation.value_or(false); + config.lookup_table_dtype = config.lookup_table_dtype.value_or("CUDA_R_32F"); + config.internal_distance_dtype = config.internal_distance_dtype.value_or("CUDA_R_32F"); + config.preferred_shmem_carveout = config.preferred_shmem_carveout.value_or(1.0f); + } + if (config.index_type == raft_proto::raft_index_kind::cagra) { + config.add_data_on_build = true; + config.intermediate_graph_degree = config.intermediate_graph_degree.value_or(128); + config.graph_degree = config.graph_degree.value_or(64); + config.itopk_size = config.itopk_size.value_or(64); + config.max_queries = config.max_queries.value_or(0); + config.build_algo = config.build_algo.value_or("IVF_PQ"); + config.search_algo = config.search_algo.value_or("AUTO"); + config.team_size = config.team_size.value_or(0); + config.search_width = config.search_width.value_or(1); + config.min_iterations = config.min_iterations.value_or(0); + config.max_iterations = config.max_iterations.value_or(0); + config.thread_block_size = config.thread_block_size.value_or(0); + config.hashmap_mode = config.hashmap_mode.value_or("AUTO"); + config.hashmap_min_bitlen = config.hashmap_min_bitlen.value_or(0); + config.hashmap_max_fill_rate = config.hashmap_max_fill_rate.value_or(0.5f); + config.nn_descent_niter = config.nn_descent_niter.value_or(20); + } + return config; +} + +} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_knowhere_index.cuh b/src/common/raft/integration/raft_knowhere_index.cuh new file mode 100644 index 000000000..e64830356 --- /dev/null +++ b/src/common/raft/integration/raft_knowhere_index.cuh @@ -0,0 +1,620 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/raft/proto/raft_index_kind.hpp" +#include "common/raft/proto/raft_index.cuh" +#include "common/raft/integration/raft_knowhere_index.hpp" + +namespace raft_knowhere { +namespace detail { + +// This helper struct maps the generic type of RAFT index to the specific +// instantiation of that index used within knowhere. +template +struct raft_index_type_mapper : std::false_type { +}; + +template <> +struct raft_index_type_mapper : std::true_type { + using data_type = raft_data_t; + using indexing_type = raft_indexing_t; + using type = raft_proto::raft_index; + using underlying_index_type = typename type::vector_index_type; + using index_params_type = typename type::index_params_type; + using search_params_type = typename type::search_params_type; +}; +template <> +struct raft_index_type_mapper : std::true_type { + using data_type = raft_data_t; + using indexing_type = raft_indexing_t; + using type = raft_proto::raft_index; + using underlying_index_type = typename type::vector_index_type; + using index_params_type = typename type::index_params_type; + using search_params_type = typename type::search_params_type; +}; +template <> +struct raft_index_type_mapper : std::true_type { + using data_type = raft_data_t; + using indexing_type = raft_indexing_t; + using type = raft_proto::raft_index; + using underlying_index_type = typename type::vector_index_type; + using index_params_type = typename type::index_params_type; + using search_params_type = typename type::search_params_type; +}; + +} // namespace detail + +template +using raft_index_t = typename detail::raft_index_type_mapper::type; + +template +using raft_index_params_t = typename detail::raft_index_type_mapper::index_params_type; +template +using raft_search_params_t = typename detail::raft_index_type_mapper::search_params_type; + +// Metrics are passed between knowhere and RAFT as strings to avoid tight +// coupling between the implementation details of either one. +[[nodiscard]] inline auto metric_string_to_raft_distance_type(std::string const& metric_string) { + auto result = raft::distance::DistanceType::L2Expanded; + if (metric_string == "L2") { + result = raft::distance::DistanceType::L2Expanded; + } else if (metric_string == "L2SqrtExpanded") { + result = raft::distance::DistanceType::L2SqrtExpanded; + } else if (metric_string == "CosineExpanded") { + result = raft::distance::DistanceType::CosineExpanded; + } else if (metric_string == "L1") { + result = raft::distance::DistanceType::L1; + } else if (metric_string == "L2Unexpanded") { + result = raft::distance::DistanceType::L2Unexpanded; + } else if (metric_string == "L2SqrtUnexpanded") { + result = raft::distance::DistanceType::L2SqrtUnexpanded; + } else if (metric_string == "IP") { + result = raft::distance::DistanceType::InnerProduct; + } else if (metric_string == "Linf") { + result = raft::distance::DistanceType::Linf; + } else if (metric_string == "Canberra") { + result = raft::distance::DistanceType::Canberra; + } else if (metric_string == "LpUnexpanded") { + result = raft::distance::DistanceType::LpUnexpanded; + } else if (metric_string == "CorrelationExpanded") { + result = raft::distance::DistanceType::CorrelationExpanded; + } else if (metric_string == "JACCARD") { + result = raft::distance::DistanceType::JaccardExpanded; + } else if (metric_string == "HellingerExpanded") { + result = raft::distance::DistanceType::HellingerExpanded; + } else if (metric_string == "Haversine") { + result = raft::distance::DistanceType::Haversine; + } else if (metric_string == "BrayCurtis") { + result = raft::distance::DistanceType::BrayCurtis; + } else if (metric_string == "JensenShannon") { + result = raft::distance::DistanceType::JensenShannon; + } else if (metric_string == "HAMMING") { + result = raft::distance::DistanceType::HammingUnexpanded; + } else if (metric_string == "KLDivergence") { + result = raft::distance::DistanceType::KLDivergence; + } else if (metric_string == "RusselRaoExpanded") { + result = raft::distance::DistanceType::RusselRaoExpanded; + } else if (metric_string == "DiceExpanded") { + result = raft::distance::DistanceType::DiceExpanded; + } else if (metric_string == "Precomputed") { + result = raft::distance::DistanceType::Precomputed; + } else { + RAFT_FAIL("Unrecognized metric type %s", metric_string.c_str()); + } + return result; +} + +[[nodiscard]] inline auto codebook_string_to_raft_codebook_gen(std::string const& codebook_string) { + auto result = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; + if (codebook_string == "PER_SUBSPACE") { + result = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; + } else if (codebook_string == "PER_CLUSTER") { + result = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; + } else { + RAFT_FAIL("Unrecognized codebook type %s", codebook_string.c_str()); + } + return result; +} +[[nodiscard]] inline auto build_algo_string_to_cagra_build_algo(std::string + const& algo_string) { + auto result = raft::neighbors::cagra::graph_build_algo::IVF_PQ; + if (algo_string == "IVF_PQ") { + result = raft::neighbors::cagra::graph_build_algo::IVF_PQ; + } else if (algo_string == "NN_DESCENT") { + result = raft::neighbors::cagra::graph_build_algo::NN_DESCENT; + } else { + RAFT_FAIL("Unrecognized CAGRA build algo %s", algo_string.c_str()); + } + return result; +} + +[[nodiscard]] inline auto search_algo_string_to_cagra_search_algo(std::string + const& algo_string) { + auto result = raft::neighbors::cagra::search_algo::AUTO; + if (algo_string == "SINGLE_CTA") { + result = raft::neighbors::cagra::search_algo::SINGLE_CTA; + } else if (algo_string == "MULTI_CTA") { + result = raft::neighbors::cagra::search_algo::MULTI_CTA; + } else if (algo_string == "MULTI_KERNEL") { + result = raft::neighbors::cagra::search_algo::MULTI_KERNEL; + } else if (algo_string == "AUTO") { + result = raft::neighbors::cagra::search_algo::AUTO; + } else { + RAFT_FAIL("Unrecognized CAGRA search algo %s", algo_string.c_str()); + } + return result; +} + +[[nodiscard]] inline auto hashmap_mode_string_to_cagra_hashmap_mode(std::string + const& mode_string) { + auto result = raft::neighbors::cagra::hash_mode::AUTO; + if (mode_string == "HASH") { + result = raft::neighbors::cagra::hash_mode::HASH; + } else if (mode_string == "SMALL") { + result = raft::neighbors::cagra::hash_mode::SMALL; + } else if (mode_string == "AUTO") { + result = raft::neighbors::cagra::hash_mode::AUTO; + } else { + RAFT_FAIL("Unrecognized CAGRA hash mode %s", mode_string.c_str()); + } + return result; +} + +[[nodiscard]] inline auto dtype_string_to_cuda_dtype(std::string + const& dtype_string) { + auto result = CUDA_R_32F; + if (dtype_string == "CUDA_R_16F") { + result = CUDA_R_16F; + } else if (dtype_string == "CUDA_C_16F") { + result = CUDA_C_16F; + } else if (dtype_string == "CUDA_R_16BF") { + result = CUDA_R_16BF; + } else if (dtype_string == "CUDA_R_32F") { + result = CUDA_R_32F; + } else if (dtype_string == "CUDA_C_32F") { + result = CUDA_C_32F; + } else if (dtype_string == "CUDA_R_64F") { + result = CUDA_R_64F; + } else if (dtype_string == "CUDA_C_64F") { + result = CUDA_C_64F; + } else if (dtype_string == "CUDA_R_8I") { + result = CUDA_R_8I; + } else if (dtype_string == "CUDA_C_8I") { + result = CUDA_C_8I; + } else if (dtype_string == "CUDA_R_8U") { + result = CUDA_R_8U; + } else if (dtype_string == "CUDA_C_8U") { + result = CUDA_C_8U; + } else if (dtype_string == "CUDA_R_32I") { + result = CUDA_R_32I; + } else if (dtype_string == "CUDA_C_32I") { + result = CUDA_C_32I; + } else if (dtype_string == "CUDA_R_8F_E4M3") { + result = CUDA_R_8F_E4M3; + } else if (dtype_string == "CUDA_R_8F_E5M2") { + result = CUDA_R_8F_E5M2; + } else { + RAFT_FAIL("Unrecognized dtype %s", dtype_string.c_str()); + } + return result; +} + +// Given a generic config without RAFT symbols, convert to RAFT index build +// parameters +template +[[nodiscard]] auto config_to_index_params(raft_knowhere_config const& raw_config) { + RAFT_EXPECTS(raw_config.index_type == IndexKind, "Incorrect index type for this index"); + auto config = validate_raft_knowhere_config(raw_config); + auto result = raft_index_params_t{}; + + result.metric = metric_string_to_raft_distance_type(config.metric_type); + result.metric_arg = config.metric_arg; + result.add_data_on_build = config.add_data_on_build; + + if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_flat || IndexKind == + raft_proto::raft_index_kind::ivf_pq) { + result.n_lists = *(config.nlist); + result.kmeans_n_iters = *(config.kmeans_n_iters); + result.kmeans_trainset_fraction = *(config.kmeans_trainset_fraction); + result.conservative_memory_allocation = *(config.conservative_memory_allocation); + } + if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_flat) { + result.adaptive_centers = *(config.adaptive_centers); + } + if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_pq) { + result.pq_dim = *(config.m); + result.pq_bits = *(config.nbits); + result.codebook_kind = codebook_string_to_raft_codebook_gen(*(config.codebook_kind)); + result.force_random_rotation = *(config.force_random_rotation); + } + if constexpr (IndexKind == raft_proto::raft_index_kind::cagra) { + result.intermediate_graph_degree = *(config.intermediate_graph_degree); + result.graph_degree = *(config.graph_degree); + result.build_algo = build_algo_string_to_cagra_build_algo(*(config.build_algo)); + result.nn_descent_niter = *(config.nn_descent_niter); + } + return result; +} + +// Given a generic config without RAFT symbols, convert to RAFT index search +// parameters +template +[[nodiscard]] auto config_to_search_params(raft_knowhere_config const& raw_config) { + RAFT_EXPECTS(raw_config.index_type == IndexKind, "Incorrect index type for this index"); + auto config = validate_raft_knowhere_config(raw_config); + auto result = raft_search_params_t{}; + if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_flat || IndexKind + == raft_proto::raft_index_kind::ivf_pq) { + result.n_probes = *(config.nprobe); + } + if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_pq) { + result.lut_dtype = dtype_string_to_cuda_dtype(*(config.lookup_table_dtype)); + result.internal_distance_dtype = dtype_string_to_cuda_dtype(*(config.internal_distance_dtype)); + result.preferred_shmem_carveout = *(config.preferred_shmem_carveout); + } + if constexpr (IndexKind == raft_proto::raft_index_kind::cagra) { + result.max_queries = *(config.max_queries); + result.itopk_size = *(config.itopk_size); + result.max_iterations = *(config.max_iterations); + result.algo = search_algo_string_to_cagra_search_algo(*(config.search_algo)); + result.team_size = *(config.team_size); + result.search_width = *(config.search_width); + result.min_iterations = *(config.min_iterations); + result.thread_block_size = *(config.thread_block_size); + result.hashmap_mode = hashmap_mode_string_to_cagra_hashmap_mode(*(config.hashmap_mode)); + result.hashmap_min_bitlen = *(config.hashmap_min_bitlen); + result.hashmap_max_fill_rate = *(config.hashmap_max_fill_rate); + } + return result; +} + +inline auto select_device_id() { + auto static device_count = []() { + auto result = 0; + RAFT_CUDA_TRY(cudaGetDeviceCount(&result)); + RAFT_EXPECTS(result != 0, "No CUDA devices found"); + return result; + }(); + auto static index_counter = std::atomic{0}; + // Use round-robin assignment to distribute indexes across devices + auto result = index_counter.fetch_add(1) % device_count; + return result; +} + +// This struct is used to connect knowhere to a RAFT index. The implementation +// is provided here, but this header should never be directly included in +// another knowhere header. This ensures that RAFT symbols are not exposed in +// any knowhere header. +template +struct raft_knowhere_index::impl { + auto static constexpr index_kind = IndexKind; + using data_type = raft_data_t; + using indexing_type = raft_indexing_t; + using input_indexing_type = raft_input_indexing_t; + + impl() {} + + private: + using raft_index_type = raft_index_t; + + public: + auto is_trained() const { + return index_.has_value(); + } + [[nodiscard]] auto size() const { + auto result = std::int64_t{}; + if (is_trained()) { + result = index_->size(); + } + return result; + } + [[nodiscard]] auto dim() const { + auto result = std::int64_t{}; + if (is_trained()) { + result = index_->dim(); + } + return result; + } + + void train(raft_knowhere_config const& config, data_type const* data, + knowhere_indexing_type row_count, knowhere_indexing_type feature_count) { + auto scoped_device = raft::device_setter{device_id}; + auto index_params = config_to_index_params(config); + auto const& res = raft::device_resources_manager::get_device_resources(); + auto host_data = raft::make_host_matrix_view(data, row_count, feature_count); + device_dataset_storage = raft::make_device_matrix(res, row_count, feature_count); + auto device_data = device_dataset_storage->view(); + raft::copy(res, device_data, host_data); + index_ = raft_index_type::template build(res, index_params, raft::make_const_mdspan(device_data)); + } + + void add(data_type const* data, knowhere_indexing_type row_count, + knowhere_indexing_type feature_count, knowhere_indexing_type const* new_ids) { + if constexpr (index_kind == raft_proto::raft_index_kind::cagra) { + if (index_) { + RAFT_FAIL("CAGRA does not support adding vectors after training"); + } + } else if constexpr (index_kind == raft_proto::raft_index_kind::ivf_pq){ + if (index_) { + RAFT_FAIL("IVFPQ does not support adding vectors after training"); + } + } else { + if (index_) { + auto const& res = raft::device_resources_manager::get_device_resources(); + raft::resource::set_workspace_to_pool_resource(res); + auto host_data = raft::make_host_matrix_view(data, row_count, feature_count); + device_dataset_storage = raft::make_device_matrix(res, row_count, feature_count); + auto device_data = device_dataset_storage->view(); + raft::copy(res, device_data, host_data); + auto device_ids_storage = std::optional>{}; + if (new_ids != nullptr) { + auto host_ids = raft::make_host_vector_view(new_ids, row_count); + device_ids_storage = raft::make_device_vector(res, row_count); + raft::copy(res, device_ids_storage->view(), host_ids); + } + + if (device_ids_storage) { + index_ = raft_index_type::extend( + res, + raft::make_const_mdspan(device_data), + std::make_optional(raft::make_const_mdspan(device_ids_storage->view())), + *index_ + ); + } else { + index_ = raft_index_type::extend( + res, + raft::make_const_mdspan(device_data), + std::optional>{}, + *index_ + ); + } + } else { + RAFT_FAIL("Index has not yet been trained"); + } + } + } + + auto search( + raft_knowhere_config const& config, + data_type const* data, + knowhere_indexing_type row_count, + knowhere_indexing_type feature_count, + knowhere_bitset_data_type const * bitset_data, + knowhere_bitset_indexing_type bitset_byte_size, + knowhere_bitset_indexing_type bitset_size + ) const { + auto scoped_device = raft::device_setter{device_id}; + auto const& res = raft::device_resources_manager::get_device_resources(); + auto k = knowhere_indexing_type(config.k); + auto search_params = config_to_search_params(config); + + auto host_data = raft::make_host_matrix_view(data, row_count, feature_count); + auto device_data_storage = raft::make_device_matrix(res, row_count, feature_count); + raft::copy(res, device_data_storage.view(), host_data); + + auto device_bitset = std::optional>{}; + + if(bitset_data != nullptr && bitset_byte_size != 0) { + device_bitset = raft::core::bitset(res, bitset_size); + raft::copy(res, device_bitset->to_mdspan(), raft::make_host_vector_view(bitset_data, bitset_byte_size)); + device_bitset->flip(res); + } + + auto output_size = row_count * k; + auto ids = std::unique_ptr(new knowhere_indexing_type[output_size]); + auto distances = std::unique_ptr(new knowhere_data_type[output_size]); + + auto host_ids = raft::make_host_matrix_view(ids.get(), row_count, k); + auto host_distances = raft::make_host_matrix_view(distances.get(), row_count, k); + + auto device_ids_storage = raft::make_device_matrix(res, row_count, k); + auto device_distances_storage = raft::make_device_matrix(res, row_count, k); + auto device_ids = device_ids_storage.view(); + auto device_distances = device_distances_storage.view(); + + RAFT_EXPECTS(index_, "Index has not yet been trained"); + auto dataset_view = device_dataset_storage ? + std::make_optional(device_dataset_storage->view()) : + std::optional>{}; + + + if (device_bitset) { + raft_index_type::search( + res, + *index_, + search_params, + raft::make_const_mdspan(device_data_storage.view()), + device_ids, + device_distances, + config.refine_ratio, + input_indexing_type{}, + dataset_view, + raft::neighbors::filtering::bitset_filter{ + device_bitset->view() + } + ); + } else { + raft_index_type::search( + res, + *index_, + search_params, + raft::make_const_mdspan(device_data_storage.view()), + device_ids, + device_distances, + config.refine_ratio, + input_indexing_type{}, + dataset_view + ); + } + if constexpr (index_kind == raft_proto::raft_index_kind::ivf_pq) { + thrust::replace( + res.get_thrust_policy(), + thrust::device_ptr(device_ids.data_handle()), + thrust::device_ptr(device_ids.data_handle() + output_size), + raft::neighbors::ivf_pq::kOutOfBoundsRecord, + indexing_type{-1} + ); + } + raft::copy(res, host_ids, device_ids); + raft::copy(res, host_distances, device_distances); + return std::make_tuple(ids.release(), distances.release()); + } + void range_search() const { + RAFT_FAIL("Range search not yet implemented for RAFT indexes"); + } + void get_vector_by_id() const { + RAFT_FAIL("Vector reconstruction not yet implemented for RAFT indexes"); + } + void serialize( + std::ostream& os + ) const { + auto scoped_device = raft::device_setter{device_id}; + auto const& res = raft::device_resources_manager::get_device_resources(); + RAFT_EXPECTS(index_, "Index has not yet been trained"); + raft_index_type::template serialize(res, os, *index_); + } + auto static deserialize( + std::istream& is + ) { + auto new_device_id = select_device_id(); + auto scoped_device = raft::device_setter{new_device_id}; + auto const& res = raft::device_resources_manager::get_device_resources(); + return std::make_unique::impl>( + raft_index_type::template deserialize(res, is), + new_device_id + ); + } + void synchronize() const { + auto scoped_device = raft::device_setter{device_id}; + raft::device_resources_manager::get_device_resources().sync_stream(); + } + impl(raft_index_type&& index, int new_device_id) : index_{std::move(index)}, + device_id{new_device_id} {} + + private: + std::optional index_ = std::nullopt; + int device_id = select_device_id(); + std::optional> device_dataset_storage = std::nullopt; +}; + +template +raft_knowhere_index::raft_knowhere_index() : pimpl{new + raft_knowhere_index::impl()} {} + +template +raft_knowhere_index::~raft_knowhere_index() = default; + +template +raft_knowhere_index::raft_knowhere_index(raft_knowhere_index&& other) : pimpl{std::move(other.pimpl)} {} + +template +raft_knowhere_index& raft_knowhere_index::operator=(raft_knowhere_index&& other) { + pimpl = std::move(other.pimpl); + return *this; +} + +template +bool raft_knowhere_index::is_trained() const { + return pimpl->is_trained(); +} + +template +std::int64_t raft_knowhere_index::size() const { + return pimpl->size(); +} + +template +std::int64_t raft_knowhere_index::dim() const { + return pimpl->dim(); +} + +template +void raft_knowhere_index::train( + raft_knowhere_config const& config, + data_type const* data, + knowhere_indexing_type row_count, + knowhere_indexing_type feature_count +) { + return pimpl->train(config, data, row_count, feature_count); +} +template +void raft_knowhere_index::add( + data_type const* data, + knowhere_indexing_type row_count, + knowhere_indexing_type feature_count, + knowhere_indexing_type const* new_ids +) { + return pimpl->add(data, row_count, feature_count, new_ids); +} +template +std::tuple raft_knowhere_index::search( + raft_knowhere_config const& config, + data_type const* data, + knowhere_indexing_type row_count, + knowhere_indexing_type feature_count, + knowhere_bitset_data_type const* bitset_data, + knowhere_bitset_indexing_type bitset_byte_size, + knowhere_bitset_indexing_type bitset_size +) const { + return pimpl->search(config, data, row_count, feature_count, bitset_data, bitset_byte_size, bitset_size); +} + +template +void raft_knowhere_index::range_search() const { + return pimpl->range_search(); +} + +template +void raft_knowhere_index::get_vector_by_id() const { + return pimpl->get_vector_by_id(); +} + +template +void raft_knowhere_index::serialize(std::ostream& os) const { + return pimpl->serialize(os); +} + +template +raft_knowhere_index +raft_knowhere_index::deserialize(std::istream& is) { + return raft_knowhere_index(raft_knowhere_index::impl::deserialize(is)); +} + +template +void raft_knowhere_index::synchronize() const { + return pimpl->synchronize(); +} + +} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_knowhere_index.hpp b/src/common/raft/integration/raft_knowhere_index.hpp new file mode 100644 index 000000000..8082d7623 --- /dev/null +++ b/src/common/raft/integration/raft_knowhere_index.hpp @@ -0,0 +1,125 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include "common/raft/proto/raft_index_kind.hpp" +#include "common/raft/integration/raft_knowhere_config.hpp" + +namespace raft_knowhere { + +using knowhere_data_type = float; +using knowhere_indexing_type = std::int64_t; +using knowhere_bitset_data_type = std::uint8_t; +using knowhere_bitset_indexing_type = std::uint32_t; + +namespace detail { + +template +struct raft_io_type_mapper : std::false_type { +}; + +template <> +struct raft_io_type_mapper : std::true_type { + using data_type = float; + using indexing_type = std::int64_t; + using input_indexing_type = std::int64_t; +}; + +template <> +struct raft_io_type_mapper : std::true_type { + using data_type = float; + using indexing_type = std::int64_t; + using input_indexing_type = std::int64_t; +}; + +template <> +struct raft_io_type_mapper : std::true_type { + using data_type = float; + using indexing_type = std::uint32_t; + using input_indexing_type = std::int64_t; +}; + +} // namespace detail + +template +using raft_data_t = typename detail::raft_io_type_mapper::data_type; + +template +using raft_indexing_t = typename detail::raft_io_type_mapper::indexing_type; + +template +using raft_input_indexing_t = typename detail::raft_io_type_mapper::input_indexing_type; + +template +struct raft_knowhere_index { + + auto static constexpr index_kind = IndexKind; + + using data_type = raft_data_t; + using indexing_type = raft_indexing_t; + using input_indexing_type = raft_input_indexing_t; + + raft_knowhere_index(); + ~raft_knowhere_index(); + + raft_knowhere_index(raft_knowhere_index&& other); + raft_knowhere_index& operator=(raft_knowhere_index&& other); + + bool is_trained() const; + std::int64_t size() const; + std::int64_t dim() const; + + void train( + raft_knowhere_config const&, + data_type const*, + knowhere_indexing_type, + knowhere_indexing_type + ); + void add( + data_type const* data, + knowhere_indexing_type row_count, + knowhere_indexing_type feature_count, + knowhere_indexing_type const* new_ids=nullptr + ); + std::tuple search( + raft_knowhere_config const& config, + data_type const* data, + knowhere_indexing_type row_count, + knowhere_indexing_type feature_count, + knowhere_bitset_data_type const * bitset_data = nullptr, + knowhere_bitset_indexing_type bitset_byte_size = knowhere_bitset_indexing_type{}, + knowhere_bitset_indexing_type bitset_size = knowhere_bitset_indexing_type{} + ) const; + void range_search() const; + void get_vector_by_id() const; + void serialize(std::ostream& os) const; + static raft_knowhere_index deserialize(std::istream& is); + void synchronize() const; + private: + // Use a private implementation to completely separate knowhere headers from + // RAFT headers + struct impl; + std::unique_ptr pimpl; + + raft_knowhere_index(std::unique_ptr&& new_pimpl) : pimpl{std::move(new_pimpl)} {} +}; + +extern template struct raft_knowhere_index; +extern template struct raft_knowhere_index; +extern template struct raft_knowhere_index; + +} diff --git a/src/common/raft/proto/ivf_to_sample_filter.cuh b/src/common/raft/proto/ivf_to_sample_filter.cuh new file mode 100644 index 000000000..f7a4c0dac --- /dev/null +++ b/src/common/raft/proto/ivf_to_sample_filter.cuh @@ -0,0 +1,39 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +namespace raft_proto { + +template +struct ivf_to_sample_filter { + const index_t* const* inds_ptrs_; + const filter_t next_filter_; + + ivf_to_sample_filter(const index_t* const* inds_ptrs, const filter_t next_filter) + : inds_ptrs_{inds_ptrs}, next_filter_{next_filter} {} + + inline __host__ __device__ bool operator()( + // query index + const uint32_t query_ix, + // the current inverted list index + const uint32_t cluster_ix, + // the index of the current sample inside the current inverted list + const uint32_t sample_ix) const + { + return next_filter_(query_ix, inds_ptrs_[cluster_ix][sample_ix]); + } +}; + +} // namespace raft_proto diff --git a/src/common/raft/proto/raft_index.cuh b/src/common/raft/proto/raft_index.cuh new file mode 100644 index 000000000..3362bb8f6 --- /dev/null +++ b/src/common/raft/proto/raft_index.cuh @@ -0,0 +1,562 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/raft/proto/raft_index_kind.hpp" +#include "common/raft/proto/ivf_to_sample_filter.cuh" + +namespace raft_proto { + +auto static const RAFT_NAME = raft::RAFT_NAME; + +namespace detail { +template typename index_template> +struct template_matches_index_kind : std::false_type{}; + +template<> +struct template_matches_index_kind : std::true_type{}; + +template<> +struct template_matches_index_kind : std::true_type{}; + +template<> +struct template_matches_index_kind : std::true_type{}; + +template typename index_template> +auto static constexpr template_matches_index_kind_v = template_matches_index_kind::value; + +} + +template