From a51562489fe4e974d46d470bd6fba760d3cf3f6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Malte=20F=C3=B6rster?= <97973773+mfoerste4@users.noreply.github.com> Date: Thu, 28 Sep 2023 18:01:03 +0200 Subject: [PATCH] add sample_weight parameter to dbscan.fit (#5574) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds the optional 'sample_weight' parameter to the dbscan fit function. Issue: #5556 CC @tfeher Authors: - Malte Förster (https://github.com/mfoerste4) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Dante Gama Dessavre (https://github.com/dantegd) URL: https://github.com/rapidsai/cuml/pull/5574 --- cpp/bench/sg/dbscan.cu | 3 +- cpp/examples/dbscan/dbscan_example.cpp | 3 +- cpp/include/cuml/cluster/dbscan.hpp | 10 +- cpp/src/dbscan/corepoints/compute.cuh | 10 +- cpp/src/dbscan/dbscan.cu | 14 ++- cpp/src/dbscan/dbscan.cuh | 5 +- cpp/src/dbscan/dbscan_api.cpp | 4 +- cpp/src/dbscan/runner.cuh | 54 +++++++++-- cpp/src/dbscan/vertexdeg/algo.cuh | 20 +++- cpp/src/dbscan/vertexdeg/pack.h | 4 + cpp/src/dbscan/vertexdeg/precomputed.cuh | 20 +++- cpp/src/dbscan/vertexdeg/runner.cuh | 6 +- cpp/test/sg/dbscan_test.cu | 115 ++++++++++++++++++++++- python/cuml/cluster/dbscan.pyx | 37 +++++++- python/cuml/cluster/dbscan_mg.pyx | 6 +- 15 files changed, 278 insertions(+), 33 deletions(-) diff --git a/cpp/bench/sg/dbscan.cu b/cpp/bench/sg/dbscan.cu index 938a2e5fb0..fc94bf8dda 100644 --- a/cpp/bench/sg/dbscan.cu +++ b/cpp/bench/sg/dbscan.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,6 +59,7 @@ class Dbscan : public BlobsFixture { raft::distance::L2SqrtUnexpanded, this->data.y.data(), this->core_sample_indices, + nullptr, dParams.max_bytes_per_batch); state.SetItemsProcessed(this->params.nrows * this->params.ncols); }); diff --git a/cpp/examples/dbscan/dbscan_example.cpp b/cpp/examples/dbscan/dbscan_example.cpp index 385aa66ae2..5a8fc01eca 100644 --- a/cpp/examples/dbscan/dbscan_example.cpp +++ b/cpp/examples/dbscan/dbscan_example.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -205,6 +205,7 @@ int main(int argc, char* argv[]) raft::distance::L2SqrtUnexpanded, d_labels, nullptr, + nullptr, max_bytes_per_batch, false); CUDA_RT_CALL(cudaMemcpyAsync( diff --git a/cpp/include/cuml/cluster/dbscan.hpp b/cpp/include/cuml/cluster/dbscan.hpp index e80b9a65dc..b41f50417c 100644 --- a/cpp/include/cuml/cluster/dbscan.hpp +++ b/cpp/include/cuml/cluster/dbscan.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,6 +43,10 @@ namespace Dbscan { * indices of each core point. If the number of core points is less * than n_rows, the right will be padded with -1. 
Setting this to * NULL will prevent calculating the core sample indices + * @param[in] sample_weight (size n_rows) input array containing the + * weight of each sample to be taken instead of a plain sum to + * fulfill the min_pts criteria for core points. + * NULL will default to weights of 1 for all samples * @param[in] max_bytes_per_batch the maximum number of megabytes to be used for * each batch of the pairwise distance calculation. This enables the * trade off between memory usage and algorithm execution time. @@ -60,6 +64,7 @@ void fit(const raft::handle_t& handle, raft::distance::DistanceType metric, int* labels, int* core_sample_indices = nullptr, + float* sample_weight = nullptr, size_t max_bytes_per_batch = 0, int verbosity = CUML_LEVEL_INFO, bool opg = false); @@ -72,6 +77,7 @@ void fit(const raft::handle_t& handle, raft::distance::DistanceType metric, int* labels, int* core_sample_indices = nullptr, + double* sample_weight = nullptr, size_t max_bytes_per_batch = 0, int verbosity = CUML_LEVEL_INFO, bool opg = false); @@ -85,6 +91,7 @@ void fit(const raft::handle_t& handle, raft::distance::DistanceType metric, int64_t* labels, int64_t* core_sample_indices = nullptr, + float* sample_weight = nullptr, size_t max_bytes_per_batch = 0, int verbosity = CUML_LEVEL_INFO, bool opg = false); @@ -97,6 +104,7 @@ void fit(const raft::handle_t& handle, raft::distance::DistanceType metric, int64_t* labels, int64_t* core_sample_indices = nullptr, + double* sample_weight = nullptr, size_t max_bytes_per_batch = 0, int verbosity = CUML_LEVEL_INFO, bool opg = false); diff --git a/cpp/src/dbscan/corepoints/compute.cuh b/cpp/src/dbscan/corepoints/compute.cuh index 0cf8b14457..838b5b4d90 100644 --- a/cpp/src/dbscan/corepoints/compute.cuh +++ b/cpp/src/dbscan/corepoints/compute.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,16 +28,16 @@ namespace CorePoints { /** * Compute the core points from the vertex degrees and min_pts criterion * @param[in] handle cuML handle - * @param[in] vd Vertex degrees + * @param[in] vd Vertex degrees (optionally weighted) * @param[out] mask Boolean core point mask * @param[in] min_pts Core point criterion * @param[in] start_vertex_id First point of the batch * @param[in] batch_size Batch size * @param[in] stream CUDA stream */ -template +template void compute(const raft::handle_t& handle, - const Index_* vd, + const Values_* vd, bool* mask, Index_ min_pts, Index_ start_vertex_id, @@ -47,7 +47,7 @@ void compute(const raft::handle_t& handle, auto counting = thrust::make_counting_iterator(0); thrust::for_each( handle.get_thrust_policy(), counting, counting + batch_size, [=] __device__(Index_ idx) { - mask[idx + start_vertex_id] = vd[idx] >= min_pts; + mask[idx + start_vertex_id] = (Index_)vd[idx] >= min_pts; }); } diff --git a/cpp/src/dbscan/dbscan.cu b/cpp/src/dbscan/dbscan.cu index 4121321546..0d01e94271 100644 --- a/cpp/src/dbscan/dbscan.cu +++ b/cpp/src/dbscan/dbscan.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -31,6 +31,7 @@ void fit(const raft::handle_t& handle, raft::distance::DistanceType metric, int* labels, int* core_sample_indices, + float* sample_weight, size_t max_bytes_per_batch, int verbosity, bool opg) @@ -45,6 +46,7 @@ void fit(const raft::handle_t& handle, metric, labels, core_sample_indices, + sample_weight, max_bytes_per_batch, handle.get_stream(), verbosity); @@ -58,6 +60,7 @@ void fit(const raft::handle_t& handle, metric, labels, core_sample_indices, + sample_weight, max_bytes_per_batch, handle.get_stream(), verbosity); @@ -72,6 +75,7 @@ void fit(const raft::handle_t& handle, raft::distance::DistanceType metric, int* labels, int* core_sample_indices, + double* sample_weight, size_t max_bytes_per_batch, int verbosity, bool opg) @@ -86,6 +90,7 @@ void fit(const raft::handle_t& handle, metric, labels, core_sample_indices, + sample_weight, max_bytes_per_batch, handle.get_stream(), verbosity); @@ -99,6 +104,7 @@ void fit(const raft::handle_t& handle, metric, labels, core_sample_indices, + sample_weight, max_bytes_per_batch, handle.get_stream(), verbosity); @@ -113,6 +119,7 @@ void fit(const raft::handle_t& handle, raft::distance::DistanceType metric, int64_t* labels, int64_t* core_sample_indices, + float* sample_weight, size_t max_bytes_per_batch, int verbosity, bool opg) @@ -127,6 +134,7 @@ void fit(const raft::handle_t& handle, metric, labels, core_sample_indices, + sample_weight, max_bytes_per_batch, handle.get_stream(), verbosity); @@ -140,6 +148,7 @@ void fit(const raft::handle_t& handle, metric, labels, core_sample_indices, + sample_weight, max_bytes_per_batch, handle.get_stream(), verbosity); @@ -154,6 +163,7 @@ void fit(const raft::handle_t& handle, raft::distance::DistanceType metric, int64_t* labels, int64_t* core_sample_indices, + double* sample_weight, size_t max_bytes_per_batch, int verbosity, bool opg) @@ -168,6 +178,7 @@ void fit(const raft::handle_t& handle, metric, labels, core_sample_indices, + sample_weight, max_bytes_per_batch, handle.get_stream(), verbosity); @@ -181,6 +192,7 @@ void fit(const raft::handle_t& handle, metric, labels, core_sample_indices, + sample_weight, max_bytes_per_batch, handle.get_stream(), verbosity); diff --git a/cpp/src/dbscan/dbscan.cuh b/cpp/src/dbscan/dbscan.cuh index ef4be6f66f..2796955793 100644 --- a/cpp/src/dbscan/dbscan.cuh +++ b/cpp/src/dbscan/dbscan.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -104,6 +104,7 @@ void dbscanFitImpl(const raft::handle_t& handle, raft::distance::DistanceType metric, Index_* labels, Index_* core_sample_indices, + T* sample_weight, size_t max_mbytes_per_batch, cudaStream_t stream, int verbosity) @@ -177,6 +178,7 @@ void dbscanFitImpl(const raft::handle_t& handle, min_pts, labels, core_sample_indices, + sample_weight, algo_vd, algo_adj, algo_ccl, @@ -198,6 +200,7 @@ void dbscanFitImpl(const raft::handle_t& handle, min_pts, labels, core_sample_indices, + sample_weight, algo_vd, algo_adj, algo_ccl, diff --git a/cpp/src/dbscan/dbscan_api.cpp b/cpp/src/dbscan/dbscan_api.cpp index 2f4b874623..061c1f6b02 100644 --- a/cpp/src/dbscan/dbscan_api.cpp +++ b/cpp/src/dbscan/dbscan_api.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,6 +47,7 @@ cumlError_t cumlSpDbscanFit(cumlHandle_t handle, raft::distance::L2SqrtUnexpanded, labels, core_sample_indices, + NULL, max_bytes_per_batch, verbosity); } @@ -88,6 +89,7 @@ cumlError_t cumlDpDbscanFit(cumlHandle_t handle, raft::distance::L2SqrtUnexpanded, labels, core_sample_indices, + NULL, max_bytes_per_batch, verbosity); } diff --git a/cpp/src/dbscan/runner.cuh b/cpp/src/dbscan/runner.cuh index c433fa59e8..89e714e058 100644 --- a/cpp/src/dbscan/runner.cuh +++ b/cpp/src/dbscan/runner.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -110,6 +110,7 @@ std::size_t run(const raft::handle_t& handle, Index_ min_pts, Index_* labels, Index_* core_indices, + const Type_f* sample_weight, int algo_vd, int algo_adj, int algo_ccl, @@ -146,6 +147,8 @@ std::size_t run(const raft::handle_t& handle, std::size_t ex_scan_size = raft::alignTo(sizeof(Index_) * batch_size, align); std::size_t row_cnt_size = raft::alignTo(sizeof(Index_) * batch_size, align); std::size_t labels_size = raft::alignTo(sizeof(Index_) * N, align); + std::size_t wght_sum_size = + sample_weight != nullptr ? raft::alignTo(sizeof(Type_f) * batch_size, align) : 0; Index_ MAX_LABEL = std::numeric_limits::max(); @@ -157,8 +160,8 @@ std::size_t run(const raft::handle_t& handle, (unsigned long)batch_size); if (workspace == NULL) { - auto size = - adj_size + core_pts_size + m_size + vd_size + ex_scan_size + row_cnt_size + 2 * labels_size; + auto size = adj_size + core_pts_size + m_size + vd_size + ex_scan_size + row_cnt_size + + 2 * labels_size + wght_sum_size; return size; } @@ -183,6 +186,11 @@ std::size_t run(const raft::handle_t& handle, temp += labels_size; Index_* work_buffer = (Index_*)temp; temp += labels_size; + Type_f* wght_sum = nullptr; + if (sample_weight != nullptr) { + wght_sum = (Type_f*)temp; + temp += wght_sum_size; + } // Compute the mask // 1. Compute the part owned by this worker (reversed order of batches to @@ -196,13 +204,31 @@ std::size_t run(const raft::handle_t& handle, CUML_LOG_DEBUG("--> Computing vertex degrees"); raft::common::nvtx::push_range("Trace::Dbscan::VertexDeg"); - VertexDeg::run( - handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream, metric); + VertexDeg::run(handle, + adj, + vd, + wght_sum, + x, + sample_weight, + eps, + N, + D, + algo_vd, + start_vertex_id, + n_points, + stream, + metric); raft::common::nvtx::pop_range(); CUML_LOG_DEBUG("--> Computing core point mask"); raft::common::nvtx::push_range("Trace::Dbscan::CorePoints"); - CorePoints::compute(handle, vd, core_pts, min_pts, start_vertex_id, n_points, stream); + if (wght_sum != nullptr) { + CorePoints::compute( + handle, wght_sum, core_pts, min_pts, start_vertex_id, n_points, stream); + } else { + CorePoints::compute( + handle, vd, core_pts, min_pts, start_vertex_id, n_points, stream); + } raft::common::nvtx::pop_range(); } // 2. 
Exchange with the other workers @@ -224,8 +250,20 @@ std::size_t run(const raft::handle_t& handle, if (i > 0) { CUML_LOG_DEBUG("--> Computing vertex degrees"); raft::common::nvtx::push_range("Trace::Dbscan::VertexDeg"); - VertexDeg::run( - handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream, metric); + VertexDeg::run(handle, + adj, + vd, + nullptr, + x, + nullptr, + eps, + N, + D, + algo_vd, + start_vertex_id, + n_points, + stream, + metric); raft::common::nvtx::pop_range(); } raft::update_host(&curradjlen, vd + n_points, 1, stream); diff --git a/cpp/src/dbscan/vertexdeg/algo.cuh b/cpp/src/dbscan/vertexdeg/algo.cuh index 817a977bbe..df6a248c89 100644 --- a/cpp/src/dbscan/vertexdeg/algo.cuh +++ b/cpp/src/dbscan/vertexdeg/algo.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -128,6 +128,24 @@ void launcher(const raft::handle_t& handle, return degree; }); RAFT_CUDA_TRY(cudaPeekAtLastError()); + + if (data.weight_sum != nullptr && data.sample_weight != nullptr) { + const value_t* sample_weight = data.sample_weight; + // Reduction of adj to compute the weighted vertex degrees + raft::linalg::coalescedReduction( + data.weight_sum, + data.adj, + data.N, + batch_size, + (value_t)0, + stream, + false, + [sample_weight] __device__(bool adj_ij, index_t j) { + return adj_ij ? sample_weight[j] : (value_t)0; + }, + raft::Sum()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } } } // namespace Algo diff --git a/cpp/src/dbscan/vertexdeg/pack.h b/cpp/src/dbscan/vertexdeg/pack.h index 1f3c551402..b7ef908e9d 100644 --- a/cpp/src/dbscan/vertexdeg/pack.h +++ b/cpp/src/dbscan/vertexdeg/pack.h @@ -28,10 +28,14 @@ struct Pack { * Hence, its length is one more than the number of points */ Index_* vd; + /** weighted vertex degree */ + Type* weight_sum; /** the adjacency matrix */ bool* adj; /** input dataset */ const Type* x; + /** weighted vertex degree */ + const Type* sample_weight; /** epsilon neighborhood thresholding param */ Type eps; /** number of points in the dataset */ diff --git a/cpp/src/dbscan/vertexdeg/precomputed.cuh b/cpp/src/dbscan/vertexdeg/precomputed.cuh index 7725d4b5d2..b2cf9fac2c 100644 --- a/cpp/src/dbscan/vertexdeg/precomputed.cuh +++ b/cpp/src/dbscan/vertexdeg/precomputed.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,6 +81,24 @@ void launcher(const raft::handle_t& handle, return degree; }); RAFT_CUDA_TRY(cudaPeekAtLastError()); + + if (data.weight_sum != nullptr && data.sample_weight != nullptr) { + const value_t* sample_weight = data.sample_weight; + // Reduction of adj to compute the weighted vertex degrees + raft::linalg::coalescedReduction( + data.weight_sum, + data.adj, + data.N, + batch_size, + (value_t)0, + stream, + false, + [sample_weight] __device__(bool adj_ij, long_index_t j) { + return adj_ij ? 
sample_weight[j] : (value_t)0;
+      },
+      raft::Sum());
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+  }
 }
 
 }  // namespace Precomputed
diff --git a/cpp/src/dbscan/vertexdeg/runner.cuh b/cpp/src/dbscan/vertexdeg/runner.cuh
index deded16783..6b7d586b8b 100644
--- a/cpp/src/dbscan/vertexdeg/runner.cuh
+++ b/cpp/src/dbscan/vertexdeg/runner.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -28,7 +28,9 @@ template
 void run(const raft::handle_t& handle,
          bool* adj,
          Index_* vd,
+         Type_f* wght_sum,
          const Type_f* x,
+         const Type_f* sample_weight,
          Type_f eps,
          Index_ N,
          Index_ D,
@@ -38,7 +40,7 @@ void run(const raft::handle_t& handle,
          cudaStream_t stream,
          raft::distance::DistanceType metric)
 {
-  Pack data = {vd, adj, x, eps, N, D};
+  Pack data = {vd, wght_sum, adj, x, sample_weight, eps, N, D};
   switch (algo) {
     case 0:
       ASSERT(
diff --git a/cpp/test/sg/dbscan_test.cu b/cpp/test/sg/dbscan_test.cu
index 6ac408dbe4..1d53df5315 100644
--- a/cpp/test/sg/dbscan_test.cu
+++ b/cpp/test/sg/dbscan_test.cu
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -118,6 +118,7 @@ class DbscanTest : public ::testing::TestWithParam> {
                 params.metric,
                 labels.data(),
                 nullptr,
+                nullptr,
                 params.max_bytes_per_batch);
 
     handle.sync_stream(stream);
@@ -205,6 +206,7 @@ struct DBScan2DArrayInputs {
   T eps;
   int min_pts;
   const int* core_indices;  // Expected core_indices
+  const T* sample_weight = nullptr;
 };
 
 template
 class Dbscan2DSimple : public ::testing::TestWithParam> {
     rmm::device_uvector labels(params.n_row, stream);
     rmm::device_uvector labels_ref(params.n_out, stream);
     rmm::device_uvector core_sample_indices_d(params.n_row, stream);
+    rmm::device_uvector sample_weight_d(params.n_row, stream);
 
     raft::copy(inputs.data(), params.points, params.n_row * 2, stream);
     raft::copy(labels_ref.data(), params.out, params.n_out, stream);
+
+    T* sample_weight = nullptr;
+    if (params.sample_weight != nullptr) {
+      raft::copy(sample_weight_d.data(), params.sample_weight, params.n_row, stream);
+      sample_weight = sample_weight_d.data();
+    }
+
     handle.sync_stream(stream);
 
     Dbscan::fit(handle,
                 params.min_pts,
                 raft::distance::L2SqrtUnexpanded,
                 labels.data(),
-                core_sample_indices_d.data());
+                core_sample_indices_d.data(),
+                sample_weight);
 
     handle.sync_stream(handle.get_stream());
 
@@ -275,6 +286,12 @@ const std::vector test2d1_f = {0, 0, 1, 0, 1, 1, 1, -1, 2, 0, 3, 0, 4, 0};
 const std::vector test2d1_d(test2d1_f.begin(), test2d1_f.end());
 const std::vector test2d1_l = {0, 0, 0, 0, 0, -1, -1};
 const std::vector test2d1c_l = {1, -1, -1, -1, -1, -1, -1};
+// modified for weighted samples --> weights are shifted so that
+// the rightmost point will be a core point as well
+const std::vector test2d1w_f = {1, 2, 1, 1, -1, 1, 3};
+const std::vector test2d1w_d(test2d1w_f.begin(), test2d1w_f.end());
+const std::vector test2d1w_l = {0, 0, 0, 0, 0, 1, 1};
+const std::vector test2d1wc_l = {1, 6, -1, -1, -1, -1, -1};
 
 // The input looks like a long two-barred (orhodox) cross or
 // two stars next to each other:
 const std::vector test2d2_f = {0, 0, 1, 0, 1,
1, 1, -1, 2, 0, 3, 0, 4, 0, const std::vector test2d2_d(test2d2_f.begin(), test2d2_f.end()); const std::vector test2d2_l = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}; const std::vector test2d2c_l = {1, 6, -1, -1, -1, -1, -1, -1, -1, -1}; +// modified for weighted samples --> wheight for the right center +// is negative that the whole right star is noise +const std::vector test2d2w_f = {1, 1, 1, 1, 1, 1, -2, 1, 1, 1}; +const std::vector test2d2w_d(test2d2w_f.begin(), test2d2w_f.end()); +const std::vector test2d2w_l = {0, 0, 0, 0, 0, -1, -1, -1, -1, -1}; +const std::vector test2d2wc_l = {1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; // The input looks like a two-barred (orhodox) cross or // two stars sharing a link: @@ -323,6 +346,10 @@ const std::vector test2d3_d(test2d3_f.begin(), test2d3_f.end()); const std::vector test2d3_l = {0, 0, 0, 0, 1, 1, 1, 1}; const std::vector test2d3c_l = {1, 4, -1, -1, -1, -1, -1, -1, -1}; +// ones for functional sample_weight testing +const std::vector test2d_ones_f = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; +const std::vector test2d_ones_d(test2d_ones_f.begin(), test2d_ones_f.end()); + const std::vector> inputs2d_f = { {test2d1_f.data(), test2d1_l.data(), @@ -345,6 +372,48 @@ const std::vector> inputs2d_f = { 1.1f, 4, test2d3c_l.data()}, + // add dummy sample weights + {test2d1_f.data(), + test2d1_l.data(), + test2d1_f.size() / 2, + test2d1_l.size(), + 1.1f, + 4, + test2d1c_l.data(), + test2d_ones_f.data()}, + {test2d2_f.data(), + test2d2_l.data(), + test2d2_f.size() / 2, + test2d2_l.size(), + 1.1f, + 4, + test2d2c_l.data(), + test2d_ones_f.data()}, + {test2d3_f.data(), + test2d3_l.data(), + test2d3_f.size() / 2, + test2d3_l.size(), + 1.1f, + 4, + test2d3c_l.data(), + test2d_ones_f.data()}, + // special sample_weight cases + {test2d1_f.data(), + test2d1w_l.data(), + test2d1_f.size() / 2, + test2d1w_l.size(), + 1.1f, + 4, + test2d2wc_l.data(), + test2d2w_f.data()}, + {test2d2_f.data(), + test2d2w_l.data(), + test2d2_f.size() / 2, + test2d2w_l.size(), + 1.1f, + 4, + test2d2wc_l.data(), + test2d2w_f.data()}, }; const std::vector> inputs2d_d = { @@ -369,6 +438,48 @@ const std::vector> inputs2d_d = { 1.1, 4, test2d3c_l.data()}, + // add dummy sample weights + {test2d1_d.data(), + test2d1_l.data(), + test2d1_d.size() / 2, + test2d1_l.size(), + 1.1, + 4, + test2d1c_l.data(), + test2d_ones_d.data()}, + {test2d2_d.data(), + test2d2_l.data(), + test2d2_d.size() / 2, + test2d2_l.size(), + 1.1, + 4, + test2d2c_l.data(), + test2d_ones_d.data()}, + {test2d3_d.data(), + test2d3_l.data(), + test2d3_d.size() / 2, + test2d3_l.size(), + 1.1, + 4, + test2d3c_l.data(), + test2d_ones_d.data()}, + // special sample_weight cases + {test2d1_d.data(), + test2d1w_l.data(), + test2d1_d.size() / 2, + test2d1w_l.size(), + 1.1f, + 4, + test2d1wc_l.data(), + test2d1w_d.data()}, + {test2d2_d.data(), + test2d2w_l.data(), + test2d2_d.size() / 2, + test2d2w_l.size(), + 1.1f, + 4, + test2d2wc_l.data(), + test2d2w_d.data()}, }; typedef Dbscan2DSimple Dbscan2DSimple_F; diff --git a/python/cuml/cluster/dbscan.pyx b/python/cuml/cluster/dbscan.pyx index d1265c6458..a42b8cd2ca 100644 --- a/python/cuml/cluster/dbscan.pyx +++ b/python/cuml/cluster/dbscan.pyx @@ -47,6 +47,7 @@ cdef extern from "cuml/cluster/dbscan.hpp" \ DistanceType metric, int *labels, int *core_sample_indices, + float* sample_weight, size_t max_mbytes_per_batch, int verbosity, bool opg) except + @@ -60,6 +61,7 @@ cdef extern from "cuml/cluster/dbscan.hpp" \ DistanceType metric, int *labels, int *core_sample_indices, + double* sample_weight, 
size_t max_mbytes_per_batch, int verbosity, bool opg) except + @@ -73,6 +75,7 @@ cdef extern from "cuml/cluster/dbscan.hpp" \ DistanceType metric, int64_t *labels, int64_t *core_sample_indices, + float* sample_weight, size_t max_mbytes_per_batch, int verbosity, bool opg) except + @@ -86,6 +89,7 @@ cdef extern from "cuml/cluster/dbscan.hpp" \ DistanceType metric, int64_t *labels, int64_t *core_sample_indices, + double* sample_weight, size_t max_mbytes_per_batch, int verbosity, bool opg) except + @@ -235,7 +239,7 @@ class DBSCAN(Base, if self.max_mbytes_per_batch is None: self.max_mbytes_per_batch = 0 - def _fit(self, X, out_dtype, opg) -> "DBSCAN": + def _fit(self, X, out_dtype, opg, sample_weight) -> "DBSCAN": """ Protected auxiliary function for `fit`. Takes an additional parameter opg that is set to `False` for SG, `True` for OPG (multi-GPU) @@ -255,6 +259,13 @@ class DBSCAN(Base, cdef uintptr_t input_ptr = X_m.ptr + cdef uintptr_t sample_weight_ptr = NULL + if sample_weight is not None: + sample_weight_m, _, _, _ = \ + input_to_cuml_array(sample_weight, check_dtype=self.dtype, + check_rows=n_rows, check_cols=1) + sample_weight_ptr = sample_weight_m.ptr + cdef handle_t* handle_ = self.handle.getHandle() self.labels_ = CumlArray.empty(n_rows, dtype=out_dtype, @@ -293,6 +304,7 @@ class DBSCAN(Base, metric, labels_ptr, core_sample_indices_ptr, + sample_weight_ptr, self.max_mbytes_per_batch, self.verbose, opg) @@ -306,6 +318,7 @@ class DBSCAN(Base, metric, labels_ptr, core_sample_indices_ptr, + sample_weight_ptr, self.max_mbytes_per_batch, self.verbose, opg) @@ -321,6 +334,7 @@ class DBSCAN(Base, metric, labels_ptr, core_sample_indices_ptr, + sample_weight_ptr, self.max_mbytes_per_batch, self.verbose, opg) @@ -334,6 +348,7 @@ class DBSCAN(Base, metric, labels_ptr, core_sample_indices_ptr, + sample_weight_ptr, self.max_mbytes_per_batch, self.verbose, opg) @@ -365,7 +380,7 @@ class DBSCAN(Base, return self @generate_docstring(skip_parameters_heading=True) - def fit(self, X, out_dtype="int32") -> "DBSCAN": + def fit(self, X, out_dtype="int32", sample_weight=None) -> "DBSCAN": """ Perform DBSCAN clustering from features. @@ -375,15 +390,21 @@ class DBSCAN(Base, default: "int32". Valid values are { "int32", np.int32, "int64", np.int64}. + sample_weight: array-like of shape (n_samples,), default=None + Weight of each sample, such that a sample with a weight of at + least min_samples is by itself a core sample; a sample with a + negative weight may inhibit its eps-neighbor from being core. + default: None (which is equivalent to weight 1 for all samples). + """ - return self._fit(X, out_dtype, False) + return self._fit(X, out_dtype, False, sample_weight) @generate_docstring(skip_parameters_heading=True, return_values={'name': 'preds', 'type': 'dense', 'description': 'Cluster labels', 'shape': '(n_samples, 1)'}) - def fit_predict(self, X, out_dtype="int32") -> CumlArray: + def fit_predict(self, X, out_dtype="int32", sample_weight=None) -> CumlArray: """ Performs clustering on X and returns cluster labels. @@ -393,8 +414,14 @@ class DBSCAN(Base, default: "int32". Valid values are { "int32", np.int32, "int64", np.int64}. + sample_weight: array-like of shape (n_samples,), default=None + Weight of each sample, such that a sample with a weight of at + least min_samples is by itself a core sample; a sample with a + negative weight may inhibit its eps-neighbor from being core. + default: None (which is equivalent to weight 1 for all samples). 
+ """ - self.fit(X, out_dtype) + self.fit(X, out_dtype, sample_weight) return self.labels_ def get_param_names(self): diff --git a/python/cuml/cluster/dbscan_mg.pyx b/python/cuml/cluster/dbscan_mg.pyx index eb1fd06591..e6cad093ac 100644 --- a/python/cuml/cluster/dbscan_mg.pyx +++ b/python/cuml/cluster/dbscan_mg.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ class DBSCANMG(DBSCAN): super().__init__(**kwargs) @generate_docstring(skip_parameters_heading=True) - def fit(self, X, out_dtype="int32") -> "DBSCANMG": + def fit(self, X, out_dtype="int32", sample_weight=None) -> "DBSCANMG": """ Perform DBSCAN clustering in a multi-node multi-GPU setting. Parameters @@ -42,4 +42,4 @@ class DBSCANMG(DBSCAN): default: "int32". Valid values are { "int32", np.int32, "int64", np.int64}. """ - return self._fit(X, out_dtype, True) + return self._fit(X, out_dtype, True, sample_weight)
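Usage note (not part of the patch): the short sketch below exercises the new 'sample_weight' argument through the Python API added above. The point coordinates, weights, eps and min_samples values are illustrative assumptions only; they demonstrate the weighted core-point criterion, in which a point is a core point when the summed weights of its eps-neighborhood (its own weight included) reach min_samples.

    import numpy as np
    from cuml.cluster import DBSCAN

    # Two small groups of 2D points (made-up values).
    X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
                  [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]], dtype=np.float32)

    # Unit weights (or sample_weight=None) reproduce the unweighted behaviour:
    # each group of three points forms its own cluster.
    labels_plain = DBSCAN(eps=0.5, min_samples=3).fit_predict(X)

    # Weights are passed with the same dtype as X (float32 here).
    # The weight of 3 lets the first point satisfy min_samples on its own,
    # while the negative weight keeps the second group's weighted degree
    # below min_samples, so its points end up as noise (-1).
    w = np.array([3.0, 1.0, 1.0, 1.0, 1.0, -1.0], dtype=np.float32)
    labels_weighted = DBSCAN(eps=0.5, min_samples=3).fit_predict(X, sample_weight=w)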