Skip to content

Commit

Permalink
add sample_weight parameter to dbscan.fit (#5574)
Browse files Browse the repository at this point in the history
This PR adds the optional 'sample_weight' parameter to the dbscan fit function. 
Issue: #5556 

CC @tfeher

Authors:
  - Malte Förster (https://github.com/mfoerste4)
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #5574
  • Loading branch information
mfoerste4 authored Sep 28, 2023
1 parent 3e5c8e9 commit a515624
Show file tree
Hide file tree
Showing 15 changed files with 278 additions and 33 deletions.
3 changes: 2 additions & 1 deletion cpp/bench/sg/dbscan.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -59,6 +59,7 @@ class Dbscan : public BlobsFixture<D, int> {
raft::distance::L2SqrtUnexpanded,
this->data.y.data(),
this->core_sample_indices,
nullptr,
dParams.max_bytes_per_batch);
state.SetItemsProcessed(this->params.nrows * this->params.ncols);
});
Expand Down
3 changes: 2 additions & 1 deletion cpp/examples/dbscan/dbscan_example.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -205,6 +205,7 @@ int main(int argc, char* argv[])
raft::distance::L2SqrtUnexpanded,
d_labels,
nullptr,
nullptr,
max_bytes_per_batch,
false);
CUDA_RT_CALL(cudaMemcpyAsync(
Expand Down
10 changes: 9 additions & 1 deletion cpp/include/cuml/cluster/dbscan.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,6 +43,10 @@ namespace Dbscan {
* indices of each core point. If the number of core points is less
* than n_rows, the right will be padded with -1. Setting this to
* NULL will prevent calculating the core sample indices
* @param[in] sample_weight (size n_rows) input array containing the
* weight of each sample to be taken instead of a plain sum to
* fulfill the min_pts criteria for core points.
* NULL will default to weights of 1 for all samples
* @param[in] max_bytes_per_batch the maximum number of megabytes to be used for
* each batch of the pairwise distance calculation. This enables the
* trade off between memory usage and algorithm execution time.
Expand All @@ -60,6 +64,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int* labels,
int* core_sample_indices = nullptr,
float* sample_weight = nullptr,
size_t max_bytes_per_batch = 0,
int verbosity = CUML_LEVEL_INFO,
bool opg = false);
Expand All @@ -72,6 +77,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int* labels,
int* core_sample_indices = nullptr,
double* sample_weight = nullptr,
size_t max_bytes_per_batch = 0,
int verbosity = CUML_LEVEL_INFO,
bool opg = false);
Expand All @@ -85,6 +91,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int64_t* labels,
int64_t* core_sample_indices = nullptr,
float* sample_weight = nullptr,
size_t max_bytes_per_batch = 0,
int verbosity = CUML_LEVEL_INFO,
bool opg = false);
Expand All @@ -97,6 +104,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int64_t* labels,
int64_t* core_sample_indices = nullptr,
double* sample_weight = nullptr,
size_t max_bytes_per_batch = 0,
int verbosity = CUML_LEVEL_INFO,
bool opg = false);
Expand Down
10 changes: 5 additions & 5 deletions cpp/src/dbscan/corepoints/compute.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,16 +28,16 @@ namespace CorePoints {
/**
* Compute the core points from the vertex degrees and min_pts criterion
* @param[in] handle cuML handle
* @param[in] vd Vertex degrees
* @param[in] vd Vertex degrees (optionally weighted)
* @param[out] mask Boolean core point mask
* @param[in] min_pts Core point criterion
* @param[in] start_vertex_id First point of the batch
* @param[in] batch_size Batch size
* @param[in] stream CUDA stream
*/
template <typename Index_ = int>
template <typename Values_ = int, typename Index_ = int>
void compute(const raft::handle_t& handle,
const Index_* vd,
const Values_* vd,
bool* mask,
Index_ min_pts,
Index_ start_vertex_id,
Expand All @@ -47,7 +47,7 @@ void compute(const raft::handle_t& handle,
auto counting = thrust::make_counting_iterator<Index_>(0);
thrust::for_each(
handle.get_thrust_policy(), counting, counting + batch_size, [=] __device__(Index_ idx) {
mask[idx + start_vertex_id] = vd[idx] >= min_pts;
mask[idx + start_vertex_id] = (Index_)vd[idx] >= min_pts;
});
}

Expand Down
14 changes: 13 additions & 1 deletion cpp/src/dbscan/dbscan.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,6 +31,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int* labels,
int* core_sample_indices,
float* sample_weight,
size_t max_bytes_per_batch,
int verbosity,
bool opg)
Expand All @@ -45,6 +46,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -58,6 +60,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -72,6 +75,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int* labels,
int* core_sample_indices,
double* sample_weight,
size_t max_bytes_per_batch,
int verbosity,
bool opg)
Expand All @@ -86,6 +90,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -99,6 +104,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -113,6 +119,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int64_t* labels,
int64_t* core_sample_indices,
float* sample_weight,
size_t max_bytes_per_batch,
int verbosity,
bool opg)
Expand All @@ -127,6 +134,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -140,6 +148,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -154,6 +163,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int64_t* labels,
int64_t* core_sample_indices,
double* sample_weight,
size_t max_bytes_per_batch,
int verbosity,
bool opg)
Expand All @@ -168,6 +178,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -181,6 +192,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand Down
5 changes: 4 additions & 1 deletion cpp/src/dbscan/dbscan.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -104,6 +104,7 @@ void dbscanFitImpl(const raft::handle_t& handle,
raft::distance::DistanceType metric,
Index_* labels,
Index_* core_sample_indices,
T* sample_weight,
size_t max_mbytes_per_batch,
cudaStream_t stream,
int verbosity)
Expand Down Expand Up @@ -177,6 +178,7 @@ void dbscanFitImpl(const raft::handle_t& handle,
min_pts,
labels,
core_sample_indices,
sample_weight,
algo_vd,
algo_adj,
algo_ccl,
Expand All @@ -198,6 +200,7 @@ void dbscanFitImpl(const raft::handle_t& handle,
min_pts,
labels,
core_sample_indices,
sample_weight,
algo_vd,
algo_adj,
algo_ccl,
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/dbscan/dbscan_api.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -47,6 +47,7 @@ cumlError_t cumlSpDbscanFit(cumlHandle_t handle,
raft::distance::L2SqrtUnexpanded,
labels,
core_sample_indices,
NULL,
max_bytes_per_batch,
verbosity);
}
Expand Down Expand Up @@ -88,6 +89,7 @@ cumlError_t cumlDpDbscanFit(cumlHandle_t handle,
raft::distance::L2SqrtUnexpanded,
labels,
core_sample_indices,
NULL,
max_bytes_per_batch,
verbosity);
}
Expand Down
54 changes: 46 additions & 8 deletions cpp/src/dbscan/runner.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -110,6 +110,7 @@ std::size_t run(const raft::handle_t& handle,
Index_ min_pts,
Index_* labels,
Index_* core_indices,
const Type_f* sample_weight,
int algo_vd,
int algo_adj,
int algo_ccl,
Expand Down Expand Up @@ -146,6 +147,8 @@ std::size_t run(const raft::handle_t& handle,
std::size_t ex_scan_size = raft::alignTo<std::size_t>(sizeof(Index_) * batch_size, align);
std::size_t row_cnt_size = raft::alignTo<std::size_t>(sizeof(Index_) * batch_size, align);
std::size_t labels_size = raft::alignTo<std::size_t>(sizeof(Index_) * N, align);
std::size_t wght_sum_size =
sample_weight != nullptr ? raft::alignTo<std::size_t>(sizeof(Type_f) * batch_size, align) : 0;

Index_ MAX_LABEL = std::numeric_limits<Index_>::max();

Expand All @@ -157,8 +160,8 @@ std::size_t run(const raft::handle_t& handle,
(unsigned long)batch_size);

if (workspace == NULL) {
auto size =
adj_size + core_pts_size + m_size + vd_size + ex_scan_size + row_cnt_size + 2 * labels_size;
auto size = adj_size + core_pts_size + m_size + vd_size + ex_scan_size + row_cnt_size +
2 * labels_size + wght_sum_size;
return size;
}

Expand All @@ -183,6 +186,11 @@ std::size_t run(const raft::handle_t& handle,
temp += labels_size;
Index_* work_buffer = (Index_*)temp;
temp += labels_size;
Type_f* wght_sum = nullptr;
if (sample_weight != nullptr) {
wght_sum = (Type_f*)temp;
temp += wght_sum_size;
}

// Compute the mask
// 1. Compute the part owned by this worker (reversed order of batches to
Expand All @@ -196,13 +204,31 @@ std::size_t run(const raft::handle_t& handle,

CUML_LOG_DEBUG("--> Computing vertex degrees");
raft::common::nvtx::push_range("Trace::Dbscan::VertexDeg");
VertexDeg::run<Type_f, Index_>(
handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream, metric);
VertexDeg::run<Type_f, Index_>(handle,
adj,
vd,
wght_sum,
x,
sample_weight,
eps,
N,
D,
algo_vd,
start_vertex_id,
n_points,
stream,
metric);
raft::common::nvtx::pop_range();

CUML_LOG_DEBUG("--> Computing core point mask");
raft::common::nvtx::push_range("Trace::Dbscan::CorePoints");
CorePoints::compute<Index_>(handle, vd, core_pts, min_pts, start_vertex_id, n_points, stream);
if (wght_sum != nullptr) {
CorePoints::compute<Type_f, Index_>(
handle, wght_sum, core_pts, min_pts, start_vertex_id, n_points, stream);
} else {
CorePoints::compute<Index_, Index_>(
handle, vd, core_pts, min_pts, start_vertex_id, n_points, stream);
}
raft::common::nvtx::pop_range();
}
// 2. Exchange with the other workers
Expand All @@ -224,8 +250,20 @@ std::size_t run(const raft::handle_t& handle,
if (i > 0) {
CUML_LOG_DEBUG("--> Computing vertex degrees");
raft::common::nvtx::push_range("Trace::Dbscan::VertexDeg");
VertexDeg::run<Type_f, Index_>(
handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream, metric);
VertexDeg::run<Type_f, Index_>(handle,
adj,
vd,
nullptr,
x,
nullptr,
eps,
N,
D,
algo_vd,
start_vertex_id,
n_points,
stream,
metric);
raft::common::nvtx::pop_range();
}
raft::update_host(&curradjlen, vd + n_points, 1, stream);
Expand Down
Loading

0 comments on commit a515624

Please sign in to comment.