Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moving low-level iterative solver primitives to raft::solver #923

Draft
wants to merge 42 commits into
base: branch-23.02
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
6873186
MOving gram matrix over to raft
cjnolet Oct 14, 2022
4df4df1
Adding specializations for gram matrix kernels
cjnolet Oct 14, 2022
c9e82cd
Commenting out rbf kernel instantiations for now.
cjnolet Oct 15, 2022
443130b
Fixing style after commenting RBC out
cjnolet Oct 15, 2022
231c3d9
Adding cudart_utils.hpp to init.cuh
cjnolet Oct 15, 2022
1bbd5f0
Style
cjnolet Oct 15, 2022
b663d3e
Fixing typo
cjnolet Oct 15, 2022
94bc17c
Fixing benchmark gramm
cjnolet Oct 15, 2022
f94cc7f
Fixing include
cjnolet Oct 15, 2022
ada42fd
Adding gram test to distances
cjnolet Oct 15, 2022
510cb6d
Adding gram.cu
cjnolet Oct 15, 2022
d1ab18e
Fixing import
cjnolet Oct 15, 2022
915d659
Adding missing curly brace
cjnolet Oct 16, 2022
86b72e6
CHanging namespace
cjnolet Oct 16, 2022
271abfd
Adding logger
cjnolet Oct 16, 2022
8a13f8d
Fixing style
cjnolet Oct 16, 2022
8d7ab1c
Fixing doc
cjnolet Oct 16, 2022
405f387
Pulling solver files over, starting to update qn solver public API
cjnolet Oct 17, 2022
8d2bd0b
Updates
cjnolet Oct 18, 2022
20a9734
Merge branch 'branch-22.12' into fea-2212-iterative_solvers
cjnolet Oct 20, 2022
879e85f
Fixing style
cjnolet Oct 20, 2022
66e1281
Checking in
cjnolet Oct 20, 2022
4131c0e
Exposing solver APIs. It needs some work but it's getting there.
cjnolet Oct 20, 2022
87d620f
Correcting spatial::knn docs to raft::neighbors
cjnolet Oct 20, 2022
febd1d4
Adding neighbors to index
cjnolet Oct 20, 2022
63e2e8b
Updating docs
cjnolet Oct 21, 2022
831c3d2
Making sure we call new cluster namespaced code from deprecated code to
cjnolet Oct 21, 2022
3c8bad1
Fixing style
cjnolet Oct 21, 2022
bb93fc1
Merge branch 'branch-22.12' into doc-2212-neighbors_docs
cjnolet Oct 21, 2022
31047e6
Deprecation warnings
cjnolet Oct 21, 2022
fd8899c
Fixing typo
cjnolet Oct 21, 2022
4968ce8
Fixing hierarchical compile error
cjnolet Oct 21, 2022
6ccf61f
Removing namespace conflict
cjnolet Oct 21, 2022
a36f26e
Many many updates to the docs, including a quick-start
cjnolet Oct 21, 2022
6acbb23
Merge branch 'branch-22.12' into doc-2212-neighbors_docs
cjnolet Oct 21, 2022
004af04
Fixing style
cjnolet Oct 21, 2022
b63f056
Merge branch 'doc-2212-neighbors_docs' into fea-2212-iterative_solvers
cjnolet Oct 21, 2022
b2fc56a
Merge branch 'branch-22.12' into fea-2212-iterative_solvers
cjnolet Oct 27, 2022
337dd02
Fixing style
cjnolet Oct 27, 2022
77a7c0b
A little more cleanup
cjnolet Oct 27, 2022
2adffb9
Merge branch 'branch-22.12' into fea-2212-iterative_solvers
cjnolet Jan 10, 2023
7c4f9de
Merge branch 'branch-23.02' into fea-2212-iterative_solvers
cjnolet Jan 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions cpp/include/raft/cluster/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ void fit(handle_t const& handle,
raft::host_scalar_view<DataT> inertia,
raft::host_scalar_view<IndexT> n_iter)
{
detail::kmeans_fit<DataT, IndexT>(handle, params, X, sample_weight, centroids, inertia, n_iter);
kmeans::fit<DataT, IndexT>(handle, params, X, sample_weight, centroids, inertia, n_iter);
}

/**
Expand Down Expand Up @@ -156,7 +156,7 @@ void predict(handle_t const& handle,
bool normalize_weight,
raft::host_scalar_view<DataT> inertia)
{
detail::kmeans_predict<DataT, IndexT>(
kmeans::predict<DataT, IndexT>(
handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia);
}

Expand Down Expand Up @@ -219,7 +219,7 @@ void fit_predict(handle_t const& handle,
raft::host_scalar_view<DataT> inertia,
raft::host_scalar_view<IndexT> n_iter)
{
detail::kmeans_fit_predict<DataT, IndexT>(
kmeans::fit_predict<DataT, IndexT>(
handle, params, X, sample_weight, centroids, labels, inertia, n_iter);
}

Expand All @@ -245,7 +245,7 @@ void transform(const raft::handle_t& handle,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::device_matrix_view<DataT, IndexT> X_new)
{
detail::kmeans_transform<DataT, IndexT>(handle, params, X, centroids, X_new);
kmeans::transform<DataT, IndexT>(handle, params, X, centroids, X_new);
}

template <typename DataT, typename IndexT>
Expand All @@ -257,8 +257,7 @@ void transform(const raft::handle_t& handle,
IndexT n_features,
DataT* X_new)
{
detail::kmeans_transform<DataT, IndexT>(
handle, params, X, centroids, n_samples, n_features, X_new);
kmeans::transform<DataT, IndexT>(handle, params, X, centroids, n_samples, n_features, X_new);
}

/**
Expand Down Expand Up @@ -571,7 +570,7 @@ void fit_main(const raft::handle_t& handle,
handle, params, X, sample_weights, centroids, inertia, n_iter, workspace);
}

}; // end namespace raft::cluster::kmeans
}; // namespace raft::cluster::kmeans

namespace raft::cluster {

Expand Down
84 changes: 84 additions & 0 deletions cpp/include/raft/solver/coordinate_descent.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/solver/detail/cd.cuh>
#include <raft/solver/solver_types.hpp>

namespace raft::solver::coordinate_descent {

/**
* @brief Minimizes an objective function using the Coordinate Descent solver.
*
* Note: Currently only least squares loss is supported w/ optional lasso and elastic-net penalties:
* f(coef) = 1/2 * || b - Ax ||^2
* + 1/2 * alpha * (1 - l1_ratio) * ||coef||^2
* + alpha * l1_ratio * ||coef||_1
*
* @param[in] handle: Reference of raft::handle_t
* @param[in] A: Input matrix in column-major format (size of n_rows, n_cols)
* @param[in] b: Input vector of labels (size of n_rows)
* @param[in] sample_weights: Optional input vector for sample weights (size n_rows)
* @param[out] x: Output vector of learned coefficients (size of n_cols)
* @param[out] intercept: Optional scalar to hold intercept if desired
*/
template <typename math_t, typename idx_t>
void minimize(const raft::handle_t& handle,
raft::device_matrix_view<math_t, idx_t, col_major> A,
raft::device_vector_view<math_t, idx_t> b,
std::optional < raft::device_vector_view<math_t, idx_t> sample_weights,
raft::device_vector_view<math_t, idx_t> x,
std::optional<raft::device_scalar_view<math_t>> intercept,
cd_params<math_t>& params)
{
RAFT_EXPECTS(A.extent(0) == b.extent(0),
"Number of labels must match the number of rows in input matrix");

if (sample_weights.has_value()) {
RAFT_EXPECTS(A.extent(0) == sample_weights.value().extent(0),
"Number of sample weights must match number of rows in input matrix");
}

RAFT_EXPECTS(x.extent(0) == A.extent(1),
"Objective is linear. The number of coefficients must match the number features in "
"the input matrix");
RAFT_EXPECTS(lossFunct == loss_funct::SQRD_LOSS,
"Only squared loss is supported in the current implementation.");

math_t* intercept_ptr = intercept.has_value() ? intercept.value().data_handle() : nullptr;
math_t* sample_weight_ptr =
sample_weights.has_value() ? sample_weights.value().data_handle() : nullptr;

detail::cdFit(handle,
A.data_handle(),
A.extent(0),
A.extent(1),
b.data_handle(),
x.data_handle(),
intercept_ptr,
intercept.has_value(),
params.normalize,
params.epochs,
params.loss,
params.alpha,
params.l1_ratio,
params.shuffle,
params.tol,
sample_weight_ptr);
}
} // namespace raft::solver::coordinate_descent
Loading