diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index 920249172f..e795a02d46 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -61,14 +61,9 @@ struct pointer_residency_count { auto [on_device, on_host] = pointer_residency_count::run(ptrs...); cudaPointerAttributes attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); - switch (attr.type) { - case cudaMemoryTypeUnregistered: return std::make_tuple(on_device, on_host + 1); - case cudaMemoryTypeHost: - return std::make_tuple(on_device + int(attr.devicePointer == ptr), on_host + 1); - case cudaMemoryTypeDevice: return std::make_tuple(on_device + 1, on_host); - case cudaMemoryTypeManaged: return std::make_tuple(on_device + 1, on_host + 1); - default: return std::make_tuple(on_device, on_host); - } + if (attr.devicePointer || attr.type == cudaMemoryTypeDevice) { ++on_device; } + if (attr.hostPointer || attr.type == cudaMemoryTypeUnregistered) { ++on_host; } + return std::make_tuple(on_device, on_host); } }; diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index f9e7f521be..48b8525cad 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -223,12 +223,13 @@ void print_vector(const char* variable_name, const T* ptr, size_t componentsCoun { cudaPointerAttributes attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); - if (attr.hostPointer != nullptr) { + if (attr.hostPointer) { print_host_vector(variable_name, reinterpret_cast(attr.hostPointer), componentsCount, out); } else if (attr.type == cudaMemoryTypeUnregistered) { print_host_vector(variable_name, ptr, componentsCount, out); } else { - print_device_vector(variable_name, ptr, componentsCount, out); + print_device_vector( + variable_name, reinterpret_cast(attr.devicePointer), componentsCount, out); } } /** @} */