diff --git a/examples/c/CMakeLists.txt b/examples/c/CMakeLists.txt index ec8ca827a..2a7e70522 100644 --- a/examples/c/CMakeLists.txt +++ b/examples/c/CMakeLists.txt @@ -42,3 +42,11 @@ target_link_libraries(CAGRA_C_EXAMPLE PRIVATE cuvs::c_api $") target_link_libraries(L2_C_EXAMPLE PRIVATE cuvs::c_api $) + +add_executable(IVF_FLAT_C_EXAMPLE src/ivf_flat_c_example.c) +target_include_directories(IVF_FLAT_C_EXAMPLE PUBLIC "$") +target_link_libraries(IVF_FLAT_C_EXAMPLE PRIVATE cuvs::c_api $) + +add_executable(IVF_PQ_C_EXAMPLE src/ivf_pq_c_example.c) +target_include_directories(IVF_PQ_C_EXAMPLE PUBLIC "$") +target_link_libraries(IVF_PQ_C_EXAMPLE PRIVATE cuvs::c_api $) diff --git a/examples/c/src/common.h b/examples/c/src/common.h new file mode 100644 index 000000000..60b9b73cf --- /dev/null +++ b/examples/c/src/common.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include + +/** + * @brief Initialize Tensor for kDLFloat. + * + * @param[in] t_d Pointer to a vector + * @param[in] t_shape[] Two-dimensional array, which stores the number of rows and columns of vectors. + * @param[out] t_tensor Stores the initialized DLManagedTensor. + */ +void float_tensor_initialize(float* t_d, int64_t t_shape[2], DLManagedTensor* t_tensor) { + t_tensor->dl_tensor.data = t_d; + t_tensor->dl_tensor.device.device_type = kDLCUDA; + t_tensor->dl_tensor.ndim = 2; + t_tensor->dl_tensor.dtype.code = kDLFloat; + t_tensor->dl_tensor.dtype.bits = 32; + t_tensor->dl_tensor.dtype.lanes = 1; + t_tensor->dl_tensor.shape = t_shape; + t_tensor->dl_tensor.strides = NULL; +} + +/** + * @brief Initialize Tensor for kDLInt. + * + * @param[in] t_d Pointer to a vector + * @param[in] t_shape[] Two-dimensional array, which stores the number of rows and columns of vectors. + * @param[out] t_tensor Stores the initialized DLManagedTensor. + */ +void int_tensor_initialize(int64_t* t_d, int64_t t_shape[], DLManagedTensor* t_tensor) { + t_tensor->dl_tensor.data = t_d; + t_tensor->dl_tensor.device.device_type = kDLCUDA; + t_tensor->dl_tensor.ndim = 2; + t_tensor->dl_tensor.dtype.code = kDLInt; + t_tensor->dl_tensor.dtype.bits = 64; + t_tensor->dl_tensor.dtype.lanes = 1; + t_tensor->dl_tensor.shape = t_shape; + t_tensor->dl_tensor.strides = NULL; +} + +/** + * @brief Fill a vector with random values. + * + * @param[out] Vec Pointer to a vector + * @param[in] n_rows the number of rows in the matrix. + * @param[in] n_cols the number of columns in the matrix. + * @param[in] min Minimum value among random values. + * @param[in] max Maximum value among random values. + */ +void generate_dataset(float * Vec,int n_rows, int n_cols, float min, float max) { + float scale; + float * ptr = Vec; + srand((unsigned int)time(NULL)); + for (int i = 0; i < n_rows; i++) { + for (int j = 0; j < n_cols; j++) { + scale = rand()/(float)RAND_MAX; + ptr = Vec + i * n_cols + j; + *ptr = min + scale * (max - min); + } + } +} + +/** + * @brief print the result. + * + * @param[in] neighbor Pointer to a neighbor vector + * @param[in] distances Pointer to a distances vector. + * @param[in] n_rows the number of rows in the matrix. + * @param[in] n_cols the number of columns in the matrix. + */ +void print_results(int64_t * neighbor, float* distances,int n_rows, int n_cols) { + int64_t * pn = neighbor; + float * pd = distances; + for (int i = 0; i < n_rows; ++i) { + printf("Query %d neighbor indices: =[", i); + for (int j = 0; j < n_cols; ++j) { + pn = neighbor + i * n_cols + j; + printf(" %ld", *pn); + } + printf("]\n"); + printf("Query %d neighbor distances: =[", i); + for (int j = 0; j < n_cols; ++j) { + pd = distances + i * n_cols + j; + printf(" %f", *pd); + } + printf("]\n"); + } +} + diff --git a/examples/c/src/ivf_flat_c_example.c b/examples/c/src/ivf_flat_c_example.c new file mode 100644 index 000000000..c068d04f8 --- /dev/null +++ b/examples/c/src/ivf_flat_c_example.c @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include "common.h" + +void ivf_flat_build_search_simple(cuvsResources_t *res, DLManagedTensor * dataset_tensor, DLManagedTensor * queries_tensor) { + // Create default index params + cuvsIvfFlatIndexParams_t index_params; + cuvsIvfFlatIndexParamsCreate(&index_params); + index_params->n_lists = 1024; // default value + index_params->kmeans_n_iters = 20; // default value + index_params->kmeans_trainset_fraction = 0.1; + //index_params->metric default is L2Expanded + + // Create IVF-Flat index + cuvsIvfFlatIndex_t index; + cuvsIvfFlatIndexCreate(&index); + + printf("Building IVF-Flat index\n"); + // Build the IVF-Flat Index + cuvsError_t build_status = cuvsIvfFlatBuild(*res, index_params, dataset_tensor, index); + if (build_status != CUVS_SUCCESS) { + printf("%s.\n", cuvsGetLastErrorText()); + cuvsIvfFlatIndexDestroy(index); + cuvsIvfFlatIndexParamsDestroy(index_params); + return; + } + + // Create output arrays. + int64_t topk = 10; + int64_t n_queries = queries_tensor->dl_tensor.shape[0]; + + //Allocate memory for `neighbors` and `distances` output + int64_t *neighbors_d; + float *distances_d; + cuvsRMMAlloc(*res, (void**) &neighbors_d, sizeof(int64_t) * n_queries * topk); + cuvsRMMAlloc(*res, (void**) &distances_d, sizeof(float) * n_queries * topk); + + DLManagedTensor neighbors_tensor; + int64_t neighbors_shape[2] = {n_queries, topk}; + int_tensor_initialize(neighbors_d, neighbors_shape, &neighbors_tensor); + + DLManagedTensor distances_tensor; + int64_t distances_shape[2] = {n_queries, topk}; + float_tensor_initialize(distances_d, distances_shape, &distances_tensor); + + // Create default search params + cuvsIvfFlatSearchParams_t search_params; + cuvsIvfFlatSearchParamsCreate(&search_params); + search_params->n_probes = 50; + + // Search the `index` built using `ivfFlatBuild` + cuvsError_t search_status = cuvsIvfFlatSearch(*res, search_params, index, + queries_tensor, &neighbors_tensor, &distances_tensor); + if (build_status != CUVS_SUCCESS) { + printf("%s.\n", cuvsGetLastErrorText()); + } + + int64_t *neighbors = (int64_t *)malloc(n_queries * topk * sizeof(int64_t)); + float *distances = (float *)malloc(n_queries * topk * sizeof(float)); + memset(neighbors, 0, n_queries * topk * sizeof(int64_t)); + memset(distances, 0, n_queries * topk * sizeof(float)); + + cudaMemcpy(neighbors, neighbors_d, sizeof(int64_t) * n_queries * topk, cudaMemcpyDefault); + cudaMemcpy(distances, distances_d, sizeof(float) * n_queries * topk, cudaMemcpyDefault); + + print_results(neighbors, distances, 2, topk); + + free(distances); + free(neighbors); + + cuvsRMMFree(*res, neighbors_d, sizeof(int64_t) * n_queries * topk); + cuvsRMMFree(*res, distances_d, sizeof(float) * n_queries * topk); + + cuvsIvfFlatSearchParamsDestroy(search_params); + cuvsIvfFlatIndexDestroy(index); + cuvsIvfFlatIndexParamsDestroy(index_params); +} + +void ivf_flat_build_extend_search(cuvsResources_t *res, DLManagedTensor * trainset_tensor, DLManagedTensor * dataset_tensor, DLManagedTensor * queries_tensor) { + int64_t *data_indices_d; + int64_t n_dataset = dataset_tensor->dl_tensor.shape[0]; + cuvsRMMAlloc(*res, (void**) &data_indices_d, sizeof(int64_t) * n_dataset); + DLManagedTensor data_indices_tensor; + int64_t data_indices_shape[1] = {n_dataset}; + int_tensor_initialize(data_indices_d, data_indices_shape, &data_indices_tensor); + data_indices_tensor.dl_tensor.ndim = 1; + + printf("\nRun k-means clustering using the training set\n"); + + int64_t *data_indices = (int64_t *)malloc(n_dataset * sizeof(int64_t)); + int64_t * ptr = data_indices; + for (int i = 0; i < n_dataset; i++) { + *ptr = i; + ptr++; + } + ptr = NULL; + cudaMemcpy(data_indices_d, data_indices, sizeof(int64_t) * n_dataset, cudaMemcpyDefault); + + // Create default index params + cuvsIvfFlatIndexParams_t index_params; + cuvsIvfFlatIndexParamsCreate(&index_params); + index_params->n_lists = 100; + index_params->add_data_on_build = false; + //index_params->metric default is L2Expanded + + // Create IVF-Flat index + cuvsIvfFlatIndex_t index; + cuvsIvfFlatIndexCreate(&index); + + // Build the IVF-Flat Index + cuvsError_t build_status = cuvsIvfFlatBuild(*res, index_params, trainset_tensor, index); + if (build_status != CUVS_SUCCESS) { + printf("%s.\n", cuvsGetLastErrorText()); + cuvsIvfFlatIndexDestroy(index); + cuvsIvfFlatIndexParamsDestroy(index_params); + return; + } + + printf("Filling index with the dataset vectors\n"); + cuvsError_t extend_status = cuvsIvfFlatExtend(*res, dataset_tensor, &data_indices_tensor, index); + if (extend_status != CUVS_SUCCESS) { + printf("%s.\n", cuvsGetLastErrorText()); + return; + } + + // Create output arrays. + int64_t topk = 10; + int64_t n_queries = queries_tensor->dl_tensor.shape[0]; + + //Allocate memory for `neighbors` and `distances` output + int64_t *neighbors_d; + float *distances_d; + cuvsRMMAlloc(*res, (void**) &neighbors_d, sizeof(int64_t) * n_queries * topk); + cuvsRMMAlloc(*res, (void**) &distances_d, sizeof(float) * n_queries * topk); + + DLManagedTensor neighbors_tensor; + int64_t neighbors_shape[2] = {n_queries, topk}; + int_tensor_initialize(neighbors_d, neighbors_shape, &neighbors_tensor); + + DLManagedTensor distances_tensor; + int64_t distances_shape[2] = {n_queries, topk}; + float_tensor_initialize(distances_d, distances_shape, &distances_tensor); + + // Create default search params + cuvsIvfFlatSearchParams_t search_params; + cuvsIvfFlatSearchParamsCreate(&search_params); + search_params->n_probes = 10; + + // Search the `index` built using `ivfFlatBuild` + cuvsError_t search_status = cuvsIvfFlatSearch(*res, search_params, index, + queries_tensor, &neighbors_tensor, &distances_tensor); + if (search_status != CUVS_SUCCESS) { + printf("%s.\n", cuvsGetLastErrorText()); + exit(-1); + } + + int64_t *neighbors = (int64_t *)malloc(n_queries * topk * sizeof(int64_t)); + float *distances = (float *)malloc(n_queries * topk * sizeof(float)); + memset(neighbors, 0, n_queries * topk * sizeof(int64_t)); + memset(distances, 0, n_queries * topk * sizeof(float)); + + cudaMemcpy(neighbors, neighbors_d, sizeof(int64_t) * n_queries * topk, cudaMemcpyDefault); + cudaMemcpy(distances, distances_d, sizeof(float) * n_queries * topk, cudaMemcpyDefault); + + print_results(neighbors, distances, 2, topk); + + free(distances); + free(neighbors); + free(data_indices); + cuvsRMMFree(*res, data_indices_d, sizeof(int64_t) * n_dataset); + cuvsRMMFree(*res, neighbors_d, sizeof(int64_t) * n_queries * topk); + cuvsRMMFree(*res, distances_d, sizeof(float) * n_queries * topk); + + cuvsIvfFlatSearchParamsDestroy(search_params); + cuvsIvfFlatIndexDestroy(index); + cuvsIvfFlatIndexParamsDestroy(index_params); +} + +int main() { + // Create input arrays. + int64_t n_samples = 10000; + int64_t n_dim = 3; + int64_t n_queries = 10; + float *dataset = (float *)malloc(n_samples * n_dim * sizeof(float)); + float *queries = (float *)malloc(n_queries * n_dim * sizeof(float)); + generate_dataset(dataset, n_samples, n_dim, -10.0, 10.0); + generate_dataset(queries, n_queries, n_dim, -1.0, 1.0); + + // Create a cuvsResources_t object + cuvsResources_t res; + cuvsResourcesCreate(&res); + + // Allocate memory for `queries` + float *dataset_d; + cuvsRMMAlloc(res, (void**) &dataset_d, sizeof(float) * n_samples * n_dim); + // Use DLPack to represent `dataset_d` as a tensor + cudaMemcpy(dataset_d, dataset, sizeof(float) * n_samples * n_dim, cudaMemcpyDefault); + + DLManagedTensor dataset_tensor; + int64_t dataset_shape[2] = {n_samples,n_dim}; + float_tensor_initialize(dataset_d, dataset_shape, &dataset_tensor); + + // Allocate memory for `queries` + float *queries_d; + cuvsRMMAlloc(res, (void**) &queries_d, sizeof(float) * n_queries * n_dim); + + // Use DLPack to represent `queries` as tensors + cudaMemcpy(queries_d, queries, sizeof(float) * n_queries * n_dim, cudaMemcpyDefault); + + DLManagedTensor queries_tensor; + int64_t queries_shape[2] = {n_queries, n_dim}; + float_tensor_initialize(queries_d, queries_shape, &queries_tensor); + + // Simple build and search example. + ivf_flat_build_search_simple(&res, &dataset_tensor, &queries_tensor); + + float *trainset_d; + int64_t n_trainset = n_samples * 0.1; + float *trainset = (float *)malloc(n_trainset * n_dim * sizeof(float)); + for (int i = 0; i < n_trainset; i++) { + for (int j = 0; j < n_dim; j++) { + *(trainset + i * n_dim + j) = *(dataset + i * n_dim + j); + } + } + cuvsRMMAlloc(res, (void**) &trainset_d, sizeof(float) * n_trainset * n_dim); + cudaMemcpy(trainset_d, trainset, sizeof(float) * n_trainset * n_dim, cudaMemcpyDefault); + DLManagedTensor trainset_tensor; + int64_t trainset_shape[2] = {n_trainset, n_dim}; + float_tensor_initialize(trainset_d, trainset_shape, &trainset_tensor); + + // Build and extend example. + ivf_flat_build_extend_search(&res, &trainset_tensor, &dataset_tensor, &queries_tensor); + + cuvsRMMFree(res, trainset_d, sizeof(float) * n_trainset * n_dim); + cuvsRMMFree(res, queries_d, sizeof(float) * n_queries * n_dim); + cuvsRMMFree(res, dataset_d, sizeof(float) * n_samples * n_dim); + cuvsResourcesDestroy(res); + free(trainset); + free(dataset); + free(queries); +} diff --git a/examples/c/src/ivf_pq_c_example.c b/examples/c/src/ivf_pq_c_example.c new file mode 100644 index 000000000..b6d6b485b --- /dev/null +++ b/examples/c/src/ivf_pq_c_example.c @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include "common.h" + +void ivf_pq_build_search(cuvsResources_t *res, DLManagedTensor * dataset_tensor, DLManagedTensor * queries_tensor) { + // Create default index params + cuvsIvfPqIndexParams_t index_params; + cuvsIvfPqIndexParamsCreate(&index_params); + index_params->n_lists = 1024; // default value + index_params->kmeans_trainset_fraction = 0.1; + //index_params->metric default is L2Expanded + index_params->pq_bits = 8; + index_params->pq_dim = 2; + + // Create IVF-PQ index + cuvsIvfPqIndex_t index; + cuvsIvfPqIndexCreate(&index); + + printf("Building IVF-PQ index\n"); + + // Build the IVF-PQ Index + cuvsError_t build_status = cuvsIvfPqBuild(*res, index_params, dataset_tensor, index); + if (build_status != CUVS_SUCCESS) { + printf("%s.\n", cuvsGetLastErrorText()); + cuvsIvfPqIndexDestroy(index); + cuvsIvfPqIndexParamsDestroy(index_params); + return; + } + + // Create output arrays. + int64_t topk = 10; + int64_t n_queries = queries_tensor->dl_tensor.shape[0]; + + //Allocate memory for `neighbors` and `distances` output + int64_t *neighbors_d; + float *distances_d; + cuvsRMMAlloc(*res, (void**) &neighbors_d, sizeof(int64_t) * n_queries * topk); + cuvsRMMAlloc(*res, (void**) &distances_d, sizeof(float) * n_queries * topk); + + DLManagedTensor neighbors_tensor; + int64_t neighbors_shape[2] = {n_queries, topk}; + int_tensor_initialize(neighbors_d, neighbors_shape, &neighbors_tensor); + + DLManagedTensor distances_tensor; + int64_t distances_shape[2] = {n_queries, topk}; + float_tensor_initialize(distances_d, distances_shape, &distances_tensor); + + // Create default search params + cuvsIvfPqSearchParams_t search_params; + cuvsIvfPqSearchParamsCreate(&search_params); + search_params->n_probes = 50; + search_params->internal_distance_dtype = CUDA_R_16F; + search_params->lut_dtype = CUDA_R_16F; + + // Search the `index` built using `cuvsIvfPqBuild` + cuvsError_t search_status = cuvsIvfPqSearch(*res, search_params, index, + queries_tensor, &neighbors_tensor, &distances_tensor); + if (search_status != CUVS_SUCCESS) { + printf("%s.\n", cuvsGetLastErrorText()); + exit(-1); + } + + int64_t *neighbors = (int64_t *)malloc(n_queries * topk * sizeof(int64_t)); + float *distances = (float *)malloc(n_queries * topk * sizeof(float)); + memset(neighbors, 0, n_queries * topk * sizeof(int64_t)); + memset(distances, 0, n_queries * topk * sizeof(float)); + + cudaMemcpy(neighbors, neighbors_d, sizeof(int64_t) * n_queries * topk, cudaMemcpyDefault); + cudaMemcpy(distances, distances_d, sizeof(float) * n_queries * topk, cudaMemcpyDefault); + + printf("\nOriginal results:\n"); + print_results(neighbors, distances, 2, topk); + + // Re-ranking operation: refine the initial search results by computing exact distances + int64_t topk_refined = 7; + int64_t *neighbors_refined_d; + float *distances_refined_d; + cuvsRMMAlloc(*res, (void**) &neighbors_refined_d, sizeof(int64_t) * n_queries * topk_refined); + cuvsRMMAlloc(*res, (void**) &distances_refined_d, sizeof(float) * n_queries * topk_refined); + + DLManagedTensor neighbors_refined_tensor; + int64_t neighbors_refined_shape[2] = {n_queries, topk_refined}; + int_tensor_initialize(neighbors_refined_d, neighbors_refined_shape, &neighbors_refined_tensor); + + DLManagedTensor distances_refined_tensor; + int64_t distances_refined_shape[2] = {n_queries, topk_refined}; + float_tensor_initialize(distances_refined_d, distances_refined_shape, &distances_refined_tensor); + + // Note, refinement requires the original dataset and the queries. + // Don't forget to specify the same distance metric as used by the index. + cuvsError_t refine_status = cuvsRefine(*res, dataset_tensor, queries_tensor, + &neighbors_tensor, index_params->metric, + &neighbors_refined_tensor, &distances_refined_tensor); + if (refine_status != CUVS_SUCCESS) { + printf("%s.\n", cuvsGetLastErrorText()); + exit(-1); + } + + int64_t *neighbors_refine = (int64_t *)malloc(n_queries * topk_refined * sizeof(int64_t)); + float *distances_refine = (float *)malloc(n_queries * topk_refined * sizeof(float)); + memset(neighbors_refine, 0, n_queries * topk_refined * sizeof(int64_t)); + memset(distances_refine, 0, n_queries * topk_refined * sizeof(float)); + + cudaMemcpy(neighbors_refine, neighbors_refined_d, sizeof(int64_t) * n_queries * topk_refined, cudaMemcpyDefault); + cudaMemcpy(distances_refine, distances_refined_d, sizeof(float) * n_queries * topk_refined, cudaMemcpyDefault); + + printf("\nRefined results:\n"); + print_results(neighbors, distances, 2, topk_refined); + + free(distances_refine); + free(neighbors_refine); + + free(distances); + free(neighbors); + + cuvsRMMFree(*res, neighbors_refined_d, sizeof(int64_t) * n_queries * topk_refined); + cuvsRMMFree(*res, distances_refined_d, sizeof(float) * n_queries * topk_refined); + + cuvsRMMFree(*res, neighbors_d, sizeof(int64_t) * n_queries * topk); + cuvsRMMFree(*res, distances_d, sizeof(float) * n_queries * topk); + + cuvsIvfPqSearchParamsDestroy(search_params); + cuvsIvfPqIndexDestroy(index); + cuvsIvfPqIndexParamsDestroy(index_params); +} + +int main() { + // Create input arrays. + int64_t n_samples = 10000; + int64_t n_dim = 3; + int64_t n_queries = 10; + float *dataset = (float *)malloc(n_samples * n_dim * sizeof(float)); + float *queries = (float *)malloc(n_queries * n_dim * sizeof(float)); + generate_dataset(dataset, n_samples, n_dim, -10.0, 10.0); + generate_dataset(queries, n_queries, n_dim, -1.0, 1.0); + + // Create a cuvsResources_t object + cuvsResources_t res; + cuvsResourcesCreate(&res); + + // Allocate memory for `queries` + float *dataset_d; + cuvsRMMAlloc(res, (void**) &dataset_d, sizeof(float) * n_samples * n_dim); + // Use DLPack to represent `dataset_d` as a tensor + cudaMemcpy(dataset_d, dataset, sizeof(float) * n_samples * n_dim, cudaMemcpyDefault); + + DLManagedTensor dataset_tensor; + int64_t dataset_shape[2] = {n_samples,n_dim}; + float_tensor_initialize(dataset_d, dataset_shape, &dataset_tensor); + + // Allocate memory for `queries` + float *queries_d; + cuvsRMMAlloc(res, (void**) &queries_d, sizeof(float) * n_queries * n_dim); + + // Use DLPack to represent `queries` as tensors + cudaMemcpy(queries_d, queries, sizeof(float) * n_queries * n_dim, cudaMemcpyDefault); + + DLManagedTensor queries_tensor; + int64_t queries_shape[2] = {n_queries, n_dim}; + float_tensor_initialize(queries_d, queries_shape, &queries_tensor); + + // Simple build and search example. + ivf_pq_build_search(&res, &dataset_tensor, &queries_tensor); + + cuvsRMMFree(res, queries_d, sizeof(float) * n_queries * n_dim); + cuvsRMMFree(res, dataset_d, sizeof(float) * n_samples * n_dim); + cuvsResourcesDestroy(res); + free(dataset); + free(queries); +} diff --git a/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py b/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py index 2b4213016..dbee6cd36 100644 --- a/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py +++ b/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py @@ -24,7 +24,7 @@ from pylibraft.common import DeviceResources from rmm.allocators.cupy import rmm_cupy_allocator -from cuvs.neighbors.brute_force import knn +from cuvs.neighbors.brute_force import build, search from .utils import memmap_bin_file, suffix_from_dtype, write_bin @@ -49,7 +49,7 @@ def choose_random_queries(dataset, n_queries): def calc_truth(dataset, queries, k, metric="sqeuclidean"): - handle = DeviceResources() + resources = DeviceResources() n_samples = dataset.shape[0] n = 500000 # batch size for processing neighbors i = 0 @@ -63,8 +63,9 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"): X = cp.asarray(dataset[i : i + n_batch, :], cp.float32) - D, Ind = knn(X, queries, k, metric=metric, handle=handle) - handle.sync() + index = build(X, metric=metric, resources=resources) + D, Ind = search(index, queries, k, resources=resources) + resources.sync() D, Ind = cp.asarray(D), cp.asarray(Ind) Ind += i # shift neighbor index by offset i