diff --git a/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp b/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp
index 3eff920dd8..5772c5e86d 100644
--- a/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp
+++ b/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp
@@ -774,6 +774,66 @@ inline cusolverStatus_t cusolverDnpotrf(cusolverDnHandle_t handle,  // NOLINT
 }
 /** @} */
 
+/**
+ * @defgroup potri cusolver potri operations: inverse of a matrix A using Cholesky
+ * @{
+ */
+template <typename T>
+cusolverStatus_t cusolverDnpotri_bufferSize(
+  cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, T* A, int lda, int* Lwork);
+template <>
+inline cusolverStatus_t cusolverDnpotri_bufferSize(
+  cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, float* A, int lda, int* Lwork)
+{
+  return cusolverDnSpotri_bufferSize(handle, uplo, n, A, lda, Lwork);
+}
+template <>
+inline cusolverStatus_t cusolverDnpotri_bufferSize(
+  cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, double* A, int lda, int* Lwork)
+{
+  return cusolverDnDpotri_bufferSize(handle, uplo, n, A, lda, Lwork);
+}
+
+template <typename T>
+cusolverStatus_t cusolverDnpotri(cusolverDnHandle_t handle,
+                                 cublasFillMode_t uplo,
+                                 int n,
+                                 T* A,
+                                 int lda,
+                                 T* Workspace,
+                                 int Lwork,
+                                 int* devInfo,
+                                 cudaStream_t stream);
+template <>
+inline cusolverStatus_t cusolverDnpotri(cusolverDnHandle_t handle,
+                                        cublasFillMode_t uplo,
+                                        int n,
+                                        float* A,
+                                        int lda,
+                                        float* Workspace,
+                                        int Lwork,
+                                        int* devInfo,
+                                        cudaStream_t stream)
+{
+  RAFT_CUSOLVER_TRY(cusolverDnSetStream(handle, stream));
+  return cusolverDnSpotri(handle, uplo, n, A, lda, Workspace, Lwork, devInfo);
+}
+template <>
+inline cusolverStatus_t cusolverDnpotri(cusolverDnHandle_t handle,
+                                        cublasFillMode_t uplo,
+                                        int n,
+                                        double* A,
+                                        int lda,
+                                        double* Workspace,
+                                        int Lwork,
+                                        int* devInfo,
+                                        cudaStream_t stream)
+{
+  RAFT_CUSOLVER_TRY(cusolverDnSetStream(handle, stream));
+  return cusolverDnDpotri(handle, uplo, n, A, lda, Workspace, Lwork, devInfo);
+}
+/** @} */
+
 /**
  * @defgroup potrs cusolver potrs operations
  * @{
diff --git a/cpp/include/raft/linalg/eig.cuh b/cpp/include/raft/linalg/eig.cuh
index 03e94a10b1..cbbd20d5c0 100644
--- a/cpp/include/raft/linalg/eig.cuh
+++ b/cpp/include/raft/linalg/eig.cuh
@@ -133,7 +133,7 @@ void eig_dc(raft::device_resources const& handle,
             raft::device_vector_view<ValueType, IndexType> eig_vals)
 {
   RAFT_EXPECTS(in.size() == eig_vectors.size(), "Size mismatch between Input and Eigen Vectors");
-  RAFT_EXPECTS(eig_vals.size() == in.extent(1), "Size mismatch between Input and Eigen Values");
+  RAFT_EXPECTS(eig_vals.extent(0) == in.extent(1), "Size mismatch between Input and Eigen Values");
   eigDC(handle,
         in.data_handle(),
diff --git a/cpp/include/raft/linalg/gemv.cuh b/cpp/include/raft/linalg/gemv.cuh
index 96846003f6..df916d6d12 100644
--- a/cpp/include/raft/linalg/gemv.cuh
+++ b/cpp/include/raft/linalg/gemv.cuh
@@ -233,8 +233,8 @@ void gemv(raft::device_resources const& handle,
  * @tparam LayoutPolicyZ layout of Z
  * @param[in] handle raft handle
  * @param[in] A input raft::device_matrix_view of size (M, N)
- * @param[in] x input raft::device_matrix_view of size (N, 1) if A is raft::col_major, else (M, 1)
- * @param[out] y output raft::device_matrix_view of size (M, 1) if A is raft::col_major, else (N, 1)
+ * @param[in] x input raft::device_vector_view of size N if A is raft::col_major, else M
+ * @param[out] y output raft::device_vector_view of size M if A is raft::col_major, else N
 * @param[in] alpha optional raft::host_scalar_view or raft::device_scalar_view, default 1.0
 * @param[in] beta optional raft::host_scalar_view or raft::device_scalar_view, default 0.0
 */
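For orientation (editor's sketch, not part of the diff): downstream code pairs these new `potri` wrappers with the existing `potrf` ones to invert a symmetric positive-definite matrix from its Cholesky factor. Against the raw cuSOLVER API that the wrappers forward to, the sequence looks roughly like this; workspace handling and error checking are simplified for illustration:

```cpp
// Minimal sketch: invert an n x n SPD matrix in place with potrf + potri.
#include <cusolverDn.h>
#include <cuda_runtime.h>
#include <cassert>

// d_A: device pointer, column-major. On exit the lower triangle of d_A
// holds the lower triangle of A^-1.
void spd_inverse(cusolverDnHandle_t h, float* d_A, int n)
{
  int lwork_f = 0, lwork_i = 0;
  cusolverDnSpotrf_bufferSize(h, CUBLAS_FILL_MODE_LOWER, n, d_A, n, &lwork_f);
  cusolverDnSpotri_bufferSize(h, CUBLAS_FILL_MODE_LOWER, n, d_A, n, &lwork_i);
  int lwork = lwork_f > lwork_i ? lwork_f : lwork_i;  // share one workspace

  float* d_work = nullptr;
  int* d_info   = nullptr;
  cudaMalloc(&d_work, sizeof(float) * lwork);
  cudaMalloc(&d_info, sizeof(int));

  // A = L * L^T (Cholesky), then A^-1 from the factor.
  cusolverDnSpotrf(h, CUBLAS_FILL_MODE_LOWER, n, d_A, n, d_work, lwork, d_info);
  cusolverDnSpotri(h, CUBLAS_FILL_MODE_LOWER, n, d_A, n, d_work, lwork, d_info);

  int info = 0;
  cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost);
  assert(info == 0);  // info > 0 means the matrix is not positive definite

  cudaFree(d_work);
  cudaFree(d_info);
}
```

Note that `potri` only writes the selected triangle of A⁻¹; the opposite triangle still holds factorization residue and has to be mirrored explicitly if a full matrix is needed.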
diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh
index ef3a873d90..749acab573 100644
--- a/cpp/include/raft/matrix/detail/matrix.cuh
+++ b/cpp/include/raft/matrix/detail/matrix.cuh
@@ -197,6 +197,47 @@ void sliceMatrix(const m_t* in,
   slice<<<grid, block, 0, stream>>>(in, n_rows, n_cols, out, x1, y1, x2, y2);
 }
 
+/**
+ * @brief Kernel for copying a small matrix into a slice of a bigger matrix;
+ * the slice size must match the size of the small matrix
+ * @param src_d: input matrix
+ * @param n_rows: number of rows of the output matrix
+ * @param n_cols: number of columns of the output matrix
+ * @param dst_d: output matrix
+ * @param x1, y1: coordinates of the top-left point of the wanted area (0-based)
+ * @param x2, y2: coordinates of the bottom-right point of the wanted area
+ * (1-based)
+ */
+template <typename m_t, typename idx_t>
+__global__ void slice_insert(
+  const m_t* src_d, idx_t n_rows, idx_t n_cols, m_t* dst_d, idx_t x1, idx_t y1, idx_t x2, idx_t y2)
+{
+  idx_t idx = threadIdx.x + blockDim.x * blockIdx.x;
+  idx_t dm = x2 - x1, dn = y2 - y1;
+  if (idx < dm * dn) {
+    idx_t i = idx % dm, j = idx / dm;
+    idx_t is = i + x1, js = j + y1;
+    dst_d[is + js * n_rows] = src_d[idx];
+  }
+}
+
+template <typename m_t, typename idx_t>
+void sliceMatrix_insert(const m_t* in,
+                        idx_t n_rows,
+                        idx_t n_cols,
+                        m_t* out,
+                        idx_t x1,
+                        idx_t y1,
+                        idx_t x2,
+                        idx_t y2,
+                        cudaStream_t stream)
+{
+  // Slicing
+  dim3 block(64);
+  dim3 grid(((x2 - x1) * (y2 - y1) + block.x - 1) / block.x);
+  slice_insert<<<grid, block, 0, stream>>>(in, n_rows, n_cols, out, x1, y1, x2, y2);
+}
+
 /**
  * @brief Kernel for copying the upper triangular part of a matrix to another
  * @param src: input matrix with a size of mxn
  * @param dst: output matrix with a size of kxk
@@ -226,6 +267,60 @@ void copyUpperTriangular(const m_t* src, m_t* dst, idx_t n_rows, idx_t n_cols, c
   getUpperTriangular<<<grid, block, 0, stream>>>(src, dst, m, n, k);
 }
 
+/**
+ * @brief Kernel for copying the lower triangular part of a matrix to another
+ * @param src: input matrix with a size of mxn
+ * @param dst: output matrix with a size of kxk
+ * @param n_rows: number of rows of input matrix
+ * @param n_cols: number of columns of input matrix
+ * @param k: min(n_rows, n_cols)
+ */
+template <typename m_t, typename idx_t>
+__global__ void getLowerTriangular(const m_t* src, m_t* dst, idx_t n_rows, idx_t n_cols, idx_t k)
+{
+  idx_t idx = threadIdx.x + blockDim.x * blockIdx.x;
+  idx_t m = n_rows, n = n_cols;
+  if (idx < m * n) {
+    idx_t i = idx % m, j = idx / m;
+    if (i < k && j < k && j <= i) { dst[i + j * k] = src[idx]; }
+  }
+}
+
+template <typename m_t, typename idx_t>
+void copyLowerTriangular(const m_t* src, m_t* dst, idx_t n_rows, idx_t n_cols, cudaStream_t stream)
+{
+  idx_t m = n_rows, n = n_cols;
+  idx_t k = std::min(m, n);
+  dim3 block(64);
+  dim3 grid((m * n + block.x - 1) / block.x);
+  getLowerTriangular<<<grid, block, 0, stream>>>(src, dst, m, n, k);
+}
+
+/**
+ * @brief Kernel for initializing a matrix to the identity: ones on the main
+ * diagonal, zeros elsewhere
+ * @param matrix: matrix of size n_rows x n_cols
+ * @param n_rows: number of rows of the matrix
+ * @param n_cols: number of columns of the matrix
+ */
+template <typename m_t, typename idx_t>
+__global__ void createEyeKernel(m_t* matrix, idx_t n_rows, idx_t n_cols)
+{
+  idx_t idx = threadIdx.x + blockDim.x * blockIdx.x;
+  if (idx < n_rows * n_cols) {
+    idx_t i = idx % n_rows, j = idx / n_rows;
+    matrix[idx] = m_t(j == i);
+  }
+}
+
+template <typename m_t, typename idx_t>
+void createEye(m_t* matrix, idx_t n_rows, idx_t n_cols, cudaStream_t stream)
+{
+  idx_t m = n_rows, n = n_cols;
+  dim3 block(64);
+  dim3 grid((m * n + block.x - 1) / block.x);
+  createEyeKernel<<<grid, block, 0, stream>>>(matrix, n_rows, n_cols);
+}
+
 /**
  * @brief Copy a vector to the diagonal of a matrix
  * @param vec: vector of length k = min(n_rows, n_cols)
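All three kernels above decompose a flat thread index over a column-major buffer the same way: `i = idx % n_rows`, `j = idx / n_rows`. A plain host-side mirror of `slice_insert` (illustrative only, not part of the PR) makes the addressing easy to check against a CPU reference:

```cpp
// Host-side reference mirroring the device kernel's column-major indexing.
#include <vector>
#include <cstddef>

template <typename T>
void slice_insert_ref(const std::vector<T>& src,  // (x2-x1) x (y2-y1), col-major
                      std::vector<T>& dst,        // n_rows x n_cols, col-major
                      std::size_t n_rows,
                      std::size_t x1, std::size_t y1,
                      std::size_t x2, std::size_t y2)
{
  const std::size_t dm = x2 - x1, dn = y2 - y1;
  for (std::size_t idx = 0; idx < dm * dn; ++idx) {
    const std::size_t i = idx % dm, j = idx / dm;  // position inside the block
    dst[(i + x1) + (j + y1) * n_rows] = src[idx];  // position inside dst
  }
}
```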
diff --git a/cpp/include/raft/matrix/init.cuh b/cpp/include/raft/matrix/init.cuh
index ed2fb4d209..1c0234e0bd 100644
--- a/cpp/include/raft/matrix/init.cuh
+++ b/cpp/include/raft/matrix/init.cuh
@@ -69,6 +69,13 @@ void fill(raft::device_resources const& handle,
   linalg::map(handle, inout, raft::const_op{scalar});
 }
 
+/**
+ * @brief Set a matrix to the identity: ones on the main diagonal, zeros elsewhere
+ * @param[in] handle: raft handle
+ * @param[out] inout: matrix to initialize (n_rows x n_cols)
+ */
+template <typename math_t, typename idx_t>
+void eye(const raft::handle_t& handle,
+         raft::device_matrix_view<math_t, idx_t, raft::col_major> inout)
+{
+  detail::createEye(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream());
+}
+
 /** @} */  // end of group matrix_init
 
 }  // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/slice.cuh b/cpp/include/raft/matrix/slice.cuh
index bb92b2b86f..aa0dacf6eb 100644
--- a/cpp/include/raft/matrix/slice.cuh
+++ b/cpp/include/raft/matrix/slice.cuh
@@ -74,6 +74,40 @@ void slice(raft::device_resources const& handle,
                      handle.get_stream());
 }
 
+/**
+ * @brief Insert a small matrix into a slice of a bigger matrix (in-place)
+ * @tparam m_t type of matrix elements
+ * @tparam idx_t integer type used for indexing
+ * @param[in] handle: raft handle
+ * @param[in] in: input matrix (column-major)
+ * @param[out] out: output matrix (column-major)
+ * @param[in] coords: coordinates of the insertion slice
+ * example: Insert a 4x2 matrix into the 2nd and 3rd columns of a 4x3 matrix:
+ * slice_insert(handle, in, out, {0, 1, 4, 3});
+ */
+template <typename m_t, typename idx_t>
+void slice_insert(raft::device_resources const& handle,
+                  raft::device_matrix_view<const m_t, idx_t, raft::col_major> in,
+                  raft::device_matrix_view<m_t, idx_t, raft::col_major> out,
+                  slice_coordinates<idx_t> coords)
+{
+  RAFT_EXPECTS(coords.row2 > coords.row1, "row2 must be > row1");
+  RAFT_EXPECTS(coords.col2 > coords.col1, "col2 must be > col1");
+  RAFT_EXPECTS(coords.row1 >= 0, "row1 must be >= 0");
+  RAFT_EXPECTS(coords.row2 <= out.extent(0), "row2 must be <= number of rows in the output matrix");
+  RAFT_EXPECTS(coords.col1 >= 0, "col1 must be >= 0");
+  RAFT_EXPECTS(coords.col2 <= out.extent(1),
+               "col2 must be <= number of columns in the output matrix");
+
+  detail::sliceMatrix_insert(in.data_handle(),
+                             out.extent(0),
+                             out.extent(1),
+                             out.data_handle(),
+                             coords.row1,
+                             coords.col1,
+                             coords.row2,
+                             coords.col2,
+                             handle.get_stream());
+}
 /** @} */  // end group matrix_slice
 
 }  // namespace raft::matrix
\ No newline at end of file
diff --git a/cpp/include/raft/matrix/triangular.cuh b/cpp/include/raft/matrix/triangular.cuh
index 3c60cc362f..4c77d5329b 100644
--- a/cpp/include/raft/matrix/triangular.cuh
+++ b/cpp/include/raft/matrix/triangular.cuh
@@ -44,6 +44,24 @@ void upper_triangular(raft::device_resources const& handle,
     src.data_handle(), dst.data_handle(), src.extent(0), src.extent(1), handle.get_stream());
 }
 
+/**
+ * @brief Copy the lower triangular part of a matrix to another
+ * @param[in] handle: raft handle
+ * @param[in] src: input matrix with a size of n_rows x n_cols
+ * @param[out] dst: output matrix with a size of kxk, k = min(n_rows, n_cols)
+ */
+template <typename m_t, typename idx_t>
+void lower_triangular(const raft::handle_t& handle,
+                      raft::device_matrix_view<const m_t, idx_t, raft::col_major> src,
+                      raft::device_matrix_view<m_t, idx_t, raft::col_major> dst)
+{
+  auto k = std::min(src.extent(0), src.extent(1));
+  RAFT_EXPECTS(k == dst.extent(0) && k == dst.extent(1),
+               "dst should be of size kxk, k = min(n_rows, n_cols)");
+  detail::copyLowerTriangular(
+    src.data_handle(), dst.data_handle(), src.extent(0), src.extent(1), handle.get_stream());
+}
+
 /** @} */  // end group matrix_triangular
 
 }  // namespace raft::matrix
\ No newline at end of file
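A hypothetical usage sketch of the two new public APIs together (the function name here and the reconstructed template parameters above are assumptions, not PR code): initialize a large identity with `eye()`, then overwrite its leading block via `slice_insert()`:

```cpp
#include <raft/core/device_mdspan.hpp>
#include <raft/matrix/init.cuh>
#include <raft/matrix/slice.cuh>

// Embed `small` into the top-left corner of `big`, leaving the rest as I.
void embed_in_identity(const raft::handle_t& handle,
                       raft::device_matrix_view<const float, int, raft::col_major> small,
                       raft::device_matrix_view<float, int, raft::col_major> big)
{
  raft::matrix::eye(handle, big);  // big = I
  // overwrite the leading small.extent(0) x small.extent(1) block
  raft::matrix::slice_insert(
    handle, small, big,
    raft::matrix::slice_coordinates<int>(0, 0, small.extent(0), small.extent(1)));
}
```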
diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh
new file mode 100644
index 0000000000..323e431e29
--- /dev/null
+++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh
@@ -0,0 +1,1182 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+template <typename value_t, typename index_t>
+auto make_transpose_layout_view(raft::device_matrix_view<value_t, index_t, raft::col_major> mds)
+{
+  return raft::make_device_matrix_view<value_t, index_t, raft::row_major>(
+    mds.data_handle(), mds.extent(1), mds.extent(0));
+}
+template <typename value_t, typename index_t>
+auto make_transpose_layout_view(
+  raft::device_matrix_view<const value_t, index_t, raft::col_major> mds)
+{
+  return raft::make_device_matrix_view<const value_t, index_t, raft::row_major>(
+    mds.data_handle(), mds.extent(1), mds.extent(0));
+}
+
+namespace raft::sparse::solver::detail {
+
+/**
+ * @brief structure that defines the reduction lambda returning the maximum of two elements
+ */
+template <typename DataT>
+struct MaxOp {
+  HDI DataT operator()(DataT a, DataT b) { return maxPrim(a, b); }
+};
+
+template <typename DataT>
+struct isnan_test {
+  HDI int operator()(const DataT a) { return isnan(a); }
+};
+
+/**
+ * @brief Assemble a matrix from a list of blocks
+ * @tparam value_t floating point type used for elements
+ * @tparam index_t integer type used for indexing
+ */
+template <typename value_t, typename index_t>
+void bmat(const raft::handle_t& handle,
+          raft::device_matrix_view<value_t, index_t, raft::col_major> out,
+          const std::vector<raft::device_matrix_view<value_t, index_t, raft::col_major>>& ins,
+          index_t n_blocks)
+{
+  RAFT_EXPECTS(n_blocks * n_blocks == ins.size(), "inconsistent number of blocks");
+  // ins is stored row-major: ins[j + i * n_blocks] is the block at block-row i,
+  // block-column j. cumulative_col[i] tracks the column offset inside block-row i,
+  // cumulative_row[j] the row offset inside block-column j.
+  std::vector<index_t> cumulative_row(n_blocks);
+  std::vector<index_t> cumulative_col(n_blocks);
+  for (index_t i = 0; i < n_blocks; i++) {
+    for (index_t j = 0; j < n_blocks; j++) {
+      raft::matrix::slice_insert(
+        handle,
+        ins[j + i * n_blocks],
+        out,
+        raft::matrix::slice_coordinates<index_t>(
+          cumulative_row[j],
+          cumulative_col[i],
+          cumulative_row[j] + ins[j + i * n_blocks].extent(0),
+          cumulative_col[i] + ins[j + i * n_blocks].extent(1)));
+      // advance by the block's width along the row, and by its height down the column
+      cumulative_col[i] += ins[j + i * n_blocks].extent(1);
+      cumulative_row[j] += ins[j + i * n_blocks].extent(0);
+    }
+  }
+}
+
+/* Modification of copyRows to reindex columns, col_major only
+ * On a 4x3 matrix, indices could be [0, 2] to select columns 0 and 2
+ */
+template <typename m_t, typename idx_t, typename idx_array_t>
+void selectCols(const m_t* in,
+                idx_t n_rows,
+                idx_t n_cols,
+                m_t* out,
+                const idx_array_t* indices,
+                idx_t n_cols_indices,
+                cudaStream_t stream)
+{
+  idx_t size    = n_cols_indices * n_rows;
+  auto counting = thrust::make_counting_iterator(0);
+
+  thrust::for_each(rmm::exec_policy(stream), counting, counting + size, [=] __device__(idx_t idx) {
+    idx_t row     = idx % n_rows;
+    idx_t new_col = idx / n_rows;
+    idx_t old_col = indices[new_col];
+    out[new_col * n_rows + row] = in[old_col * n_rows + row];
+  });
+}
+
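`selectCols` is the gather primitive the solver uses to keep only the still-active columns. A host-side reference of the same column gather (illustrative, not in the PR) for checking the device output:

```cpp
// Host reference of selectCols on a column-major buffer: `out` receives the
// columns of `in` listed in `indices`, in that order.
#include <vector>
#include <cstddef>

template <typename T>
std::vector<T> select_cols_ref(const std::vector<T>& in, std::size_t n_rows,
                               const std::vector<std::size_t>& indices)
{
  std::vector<T> out(indices.size() * n_rows);
  for (std::size_t new_col = 0; new_col < indices.size(); ++new_col)
    for (std::size_t row = 0; row < n_rows; ++row)
      out[new_col * n_rows + row] = in[indices[new_col] * n_rows + row];
  return out;
}
```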
+template <typename value_t, typename index_t>
+void selectColsIf(const raft::handle_t& handle,
+                  raft::device_matrix_view<value_t, index_t, raft::col_major> in,
+                  raft::device_vector_view<index_t, index_t> mask,
+                  raft::device_matrix_view<value_t, index_t, raft::col_major> out)
+{
+  auto stream     = handle.get_stream();
+  auto in_n_cols  = in.extent(1);
+  auto out_n_cols = out.extent(1);
+  auto rangeVec   = raft::make_device_vector<index_t, index_t>(handle, in_n_cols);
+  raft::linalg::range(rangeVec.data_handle(), in_n_cols, stream);
+  // Replace the index of each masked-out column by -1, then sort so that the
+  // indices of the kept columns end up in the last out_n_cols entries.
+  raft::linalg::map(
+    handle,
+    raft::make_const_mdspan(mask),
+    raft::make_const_mdspan(rangeVec.view()),
+    rangeVec.view(),
+    [] __device__(index_t mask_value, index_t idx) { return mask_value == 1 ? idx : -1; });
+  thrust::sort(rmm::exec_policy(stream),
+               rangeVec.data_handle(),
+               rangeVec.data_handle() + rangeVec.size(),
+               thrust::less<index_t>());
+  selectCols(in.data_handle(),
+             in.extent(0),
+             in.extent(1),
+             out.data_handle(),
+             rangeVec.data_handle() + rangeVec.size() - out_n_cols,
+             out_n_cols,
+             stream);
+}
+
+/**
+ * Reverse the order of the eigenvalues/vectors if needed, and truncate the
+ * columns to fit eigVectorTrunc
+ */
+template <typename value_t, typename index_t>
+void truncEig(
+  const raft::handle_t& handle,
+  raft::device_matrix_view<value_t, index_t, raft::col_major> eigVectorin,
+  std::optional<raft::device_matrix_view<value_t, index_t, raft::col_major>> eigVectorTrunc,
+  raft::device_vector_view<value_t, index_t> eigLambda,
+  bool largest)
+{
+  // The eigenvalues are already sorted in ascending order with syevd
+  auto nrows = eigVectorin.extent(0);
+  auto ncols = eigVectorin.extent(1);
+  if (largest) {
+    raft::matrix::col_reverse(handle, eigVectorin);
+    raft::matrix::col_reverse(
+      handle,
+      raft::make_device_matrix_view<value_t, index_t, raft::col_major>(
        eigLambda.data_handle(), 1, eigLambda.extent(0)));
+  }
+  if (eigVectorTrunc.has_value() && ncols > eigVectorTrunc->extent(1))
+    raft::matrix::truncZeroOrigin(eigVectorin.data_handle(),
+                                  nrows,
+                                  eigVectorTrunc->data_handle(),
+                                  nrows,
+                                  eigVectorTrunc->extent(1),
+                                  handle.get_stream());
+}
+
+// C = A * B
+template <typename value_t, typename index_t>
+void spmm(const raft::handle_t& handle,
+          raft::spectral::matrix::sparse_matrix_t<value_t, index_t> A,
+          raft::device_matrix_view<value_t, index_t, raft::col_major> B,
+          raft::device_matrix_view<value_t, index_t, raft::col_major> C,
+          bool transpose_a = false,
+          bool transpose_b = false)
+{
+  auto stream          = handle.get_stream();
+  auto* A_values_      = const_cast<value_t*>(A.values_);
+  auto* A_row_offsets_ = const_cast<index_t*>(A.row_offsets_);
+  auto* A_col_indices_ = const_cast<index_t*>(A.col_indices_);
+  cusparseSpMatDescr_t sparse_A;
+  cusparseDnMatDescr_t dense_B;
+  cusparseDnMatDescr_t dense_C;
+  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
+    &sparse_A, A.nrows_, A.ncols_, A.nnz_, A_row_offsets_, A_col_indices_, A_values_));
+  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(
+    &dense_B, B.extent(0), B.extent(1), B.extent(0), B.data_handle(), CUSPARSE_ORDER_COL));
+  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(
+    &dense_C, C.extent(0), C.extent(1), C.extent(0), C.data_handle(), CUSPARSE_ORDER_COL));
+  // C = alpha * op(A) * op(B) + beta * C
+  value_t alpha    = 1;
+  value_t beta     = 0;
+  size_t buff_size = 0;
+  auto opA = transpose_a ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE;
+  auto opB = transpose_b ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE;
+  raft::sparse::detail::cusparsespmm_bufferSize(handle.get_cusparse_handle(),
+                                                opA,
+                                                opB,
+                                                &alpha,
+                                                sparse_A,
+                                                dense_B,
+                                                &beta,
+                                                dense_C,
+                                                CUSPARSE_SPMM_ALG_DEFAULT,
+                                                &buff_size,
+                                                stream);
+  rmm::device_uvector<value_t> dev_buffer(buff_size / sizeof(value_t), stream);
+  raft::sparse::detail::cusparsespmm(handle.get_cusparse_handle(),
+                                     opA,
+                                     opB,
+                                     &alpha,
+                                     sparse_A,
+                                     dense_B,
+                                     &beta,
+                                     dense_C,
+                                     CUSPARSE_SPMM_ALG_DEFAULT,
+                                     dev_buffer.data(),
+                                     stream);
+
+  cusparseDestroySpMat(sparse_A);
+  cusparseDestroyDnMat(dense_B);
+  cusparseDestroyDnMat(dense_C);
+}
+
+/**
+ * Solve the linear equation A x = b, given the Cholesky factorization of A
+ * The operation is in-place, i.e.
matrix X overwrites matrix B. + */ +template +void cho_solve(const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_matrix_view B, + bool lower = true) +{ + auto thrust_exec_policy = handle.get_thrust_policy(); + auto stream = handle.get_stream(); + auto lda = A.extent(0); + auto dim = A.extent(0); + cublasFillMode_t uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + + rmm::device_uvector info(1, stream); + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrs(handle.get_cusolver_dn_handle(), + uplo, + dim, + B.extent(1), + A.data_handle(), + lda, + B.data_handle(), + dim, + info.data(), + stream)); +} + +template +bool cholesky(const raft::handle_t& handle, + raft::device_matrix_view P, + bool lower = true) +{ + auto thrust_exec_policy = handle.get_thrust_policy(); + auto stream = handle.get_stream(); + int Lwork = 0; + auto lda = P.extent(0); + auto dim = P.extent(0); + cublasFillMode_t uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + + auto P_copy = + raft::make_device_matrix(handle, P.extent(0), P.extent(1)); + raft::copy(P_copy.data_handle(), P.data_handle(), P.size(), stream); + + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf_bufferSize( + handle.get_cusolver_dn_handle(), uplo, dim, P_copy.data_handle(), lda, &Lwork)); + + rmm::device_uvector workspace_decomp(Lwork / sizeof(value_t), stream); + rmm::device_uvector info(1, stream); + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf(handle.get_cusolver_dn_handle(), + uplo, + dim, + P_copy.data_handle(), + lda, + workspace_decomp.data(), + Lwork, + info.data(), + stream)); + int info_h = 0; + raft::update_host(&info_h, info.data(), 1, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + ASSERT(info_h == 0, "lobpcg: error in potrf, info=%d | expected=0", info_h); + + int h_hasnan = thrust::transform_reduce(thrust_exec_policy, + P_copy.data_handle(), + P_copy.data_handle() + P_copy.size(), + isnan_test(), + 0, + thrust::plus()); + + if (h_hasnan != 0) // "lobpcg: error in cholesky, NaN in outputs" + return false; + + raft::matrix::fill(handle, P, value_t(0)); + if (lower) { + raft::matrix::lower_triangular(handle, raft::make_const_mdspan(P_copy.view()), P); + } else { + raft::matrix::upper_triangular(handle, raft::make_const_mdspan(P_copy.view()), P); + } + return true; +} + +template +void inverse(const raft::handle_t& handle, + raft::device_matrix_view P, + raft::device_matrix_view Pinv, + bool transposeP = false) +{ + auto stream = handle.get_stream(); + int Lwork = 0; + auto lda = P.extent(0); + auto dim = P.extent(0); + int info_h = 0; + cublasOperation_t trans = transposeP ? 
CUBLAS_OP_T : CUBLAS_OP_N; + raft::matrix::eye(handle, Pinv); + + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngetrf_bufferSize( + handle.get_cusolver_dn_handle(), dim, dim, P.data_handle(), lda, &Lwork)); + + auto P_copy = raft::make_device_matrix(handle, P.extent(0), P.extent(1)); + raft::copy(P_copy.data_handle(), P.data_handle(), P.size(), stream); + rmm::device_uvector workspace_decomp(Lwork, stream); + rmm::device_uvector info(1, stream); + auto ipiv = raft::make_device_vector(handle, dim); + + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngetrf(handle.get_cusolver_dn_handle(), + dim, + dim, + P_copy.data_handle(), + lda, + workspace_decomp.data(), + ipiv.data_handle(), + info.data(), + stream)); + + raft::update_host(&info_h, info.data(), 1, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + ASSERT(info_h == 0, "lobpcg: error in getrf, info=%d | expected=0", info_h); + + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngetrs(handle.get_cusolver_dn_handle(), + trans, + dim, + dim, + P_copy.data_handle(), + lda, + ipiv.data_handle(), + Pinv.data_handle(), + lda, + info.data(), + stream)); + + raft::update_host(&info_h, info.data(), 1, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + ASSERT(info_h == 0, "lobpcg: error in getrs, info=%d | expected=0", info_h); +} + +template +void apply_constraints(const raft::handle_t& handle, + raft::device_matrix_view V, + raft::device_matrix_view YBY, + raft::device_matrix_view BY, + raft::device_matrix_view Y) +{ + auto stream = handle.get_stream(); + auto YBY_copy = raft::make_device_matrix( + handle, YBY.extent(0), YBY.extent(1)); + raft::copy(YBY_copy.data_handle(), YBY.data_handle(), YBY.size(), stream); + // TODO: Use mdspan gemm with row-major to transpose + auto YBV = + raft::make_device_matrix(handle, BY.extent(1), V.extent(1)); + value_t zero = 0; + value_t one = 1; + raft::linalg::gemm(handle, + true, + false, + YBV.extent(0), + YBV.extent(1), + BY.extent(0), + &one, + BY.data_handle(), + BY.extent(0), + V.data_handle(), + V.extent(0), + &zero, + YBV.data_handle(), + YBV.extent(0), + stream); + + cholesky(handle, YBY_copy.view()); + cho_solve(handle, raft::make_const_mdspan(YBY_copy.view()), YBV.view()); + auto BV = + raft::make_device_matrix(handle, V.extent(0), YBV.extent(1)); + raft::linalg::gemm(handle, Y, YBV.view(), BV.view()); + raft::linalg::subtract(handle, raft::make_const_mdspan(Y), raft::make_const_mdspan(BV.view()), Y); +} + +/** + * Helper function for converting a generalized eigenvalue problem + * A(X) = lambda(B(X)) to standard eigen value problem using cholesky + * transformation + */ +template +bool eigh(const raft::handle_t& handle, + raft::device_matrix_view A, + std::optional> B_opt, + raft::device_matrix_view eigVecs, + raft::device_vector_view eigVals) +{ + auto dim = A.extent(0); + auto AT = raft::make_device_matrix(handle, dim, dim); + raft::linalg::transpose(handle, A, AT.view()); + if (!B_opt.has_value()) { + raft::linalg::eig_dc(handle, raft::make_const_mdspan(AT.view()), eigVecs, eigVals); + return true; + } + auto RTi = raft::make_device_matrix(handle, dim, dim); + auto Ri = raft::make_device_matrix(handle, dim, dim); + auto R = raft::make_device_matrix(handle, dim, dim); + auto F = raft::make_device_matrix(handle, dim, dim); + auto B = B_opt.value(); + bool cho_success = cholesky(handle, B, false); + raft::linalg::transpose(handle, B, R.view()); + + inverse(handle, R.view(), Ri.view(), true); + inverse(handle, R.view(), RTi.view(), false); + + // Reuse the memory of matrix + 
auto& ARi = B; + auto& Fvecs = R; + raft::linalg::gemm(handle, A, Ri.view(), ARi); + raft::linalg::gemm(handle, RTi.view(), ARi, F.view()); + + auto FT = raft::make_device_matrix(handle, dim, dim); + raft::linalg::transpose(handle, F.view(), FT.view()); + + raft::linalg::eig_dc(handle, raft::make_const_mdspan(FT.view()), Fvecs.view(), eigVals); + raft::linalg::gemm(handle, Ri.view(), Fvecs.view(), eigVecs); + return cho_success; +} + +/** + * B-orthonormalize the given block vector using Cholesky + * + * @tparam value_t floating point type used for elements + * @tparam index_t integer type used for indexing + * @param[in] handle: raft handle + * @param[inout] V: dense matrix to normalize + * @param[inout] BV: dense matrix. Use with parameter `bv_is_empty`. + * @param[in] B_opt: optional sparse matrix for normalization + * @param[out] VBV_opt: optional dense matrix containing inverse matrix (shape v[1] * v[1]) + * @param[out] V_max_opt: optional vector containing normalization of V (shape v[1]) + * @param[in] bv_is_empty: True if BV is used as input + * @return success status + */ +template +bool b_orthonormalize( + const raft::handle_t& handle, + raft::device_matrix_view V, + raft::device_matrix_view BV, + std::optional> B_opt = std::nullopt, + std::optional> VBV_opt = std::nullopt, + std::optional> V_max_opt = std::nullopt, + bool bv_is_empty = true) +{ + auto stream = handle.get_stream(); + auto V_max_buffer = rmm::device_uvector(0, stream); + value_t* V_max_ptr = nullptr; + if (!V_max_opt) { // allocate normalization buffer + V_max_buffer.resize(V.extent(1), stream); + V_max_ptr = V_max_buffer.data(); + } else { + V_max_ptr = V_max_opt.value().data_handle(); + } + auto V_max = raft::make_device_vector_view(V_max_ptr, V.extent(1)); + + /*raft::linalg::reduce(handle, + raft::make_device_matrix_view( + V.data_handle(), V.extent(0), V.extent(1)), + V_max, + value_t(0), + raft::linalg::Apply::ALONG_ROWS, + false, + raft::identity_op(), + MaxOp());*/ + // Coalesced reduction + raft::linalg::reduce(V_max.data_handle(), + V.data_handle(), + V.extent(1), + V.extent(0), + value_t(0), + false, + false, + handle.get_stream(), + false, + raft::identity_op(), + MaxOp()); + raft::linalg::binary_div_skip_zero(handle, V, raft::make_const_mdspan(V_max), raft::linalg::Apply::ALONG_ROWS); + + if (!bv_is_empty) { + raft::linalg::binary_div_skip_zero(handle, BV, raft::make_const_mdspan(V_max), raft::linalg::Apply::ALONG_ROWS); + } else { + if (B_opt) + spmm(handle, B_opt.value(), V, BV); + else + raft::copy(BV.data_handle(), V.data_handle(), V.size(), stream); + } + auto VBV_buffer = rmm::device_uvector(0, stream); + value_t* VBV_ptr = nullptr; + if (!VBV_opt) { // allocate normalization buffer + VBV_buffer.resize(V.extent(1) * V.extent(1), stream); + VBV_ptr = VBV_buffer.data(); + } else { + VBV_ptr = VBV_opt.value().data_handle(); + } + auto VBV = raft::make_device_matrix_view( + VBV_ptr, V.extent(1), V.extent(1)); + auto VBVBuffer = raft::make_device_matrix( + handle, VBV.extent(0), VBV.extent(1)); + auto VT = make_transpose_layout_view(V); + + raft::linalg::gemm(handle, VT, BV, VBV); + bool cholesky_success = cholesky(handle, VBV, false); + if (!cholesky_success) { return cholesky_success; } + + inverse(handle, VBV, VBVBuffer.view()); + raft::copy(VBV.data_handle(), VBVBuffer.data_handle(), VBV.size(), stream); + raft::linalg::gemm(handle, V, VBV, V); + if (B_opt) raft::linalg::gemm(handle, BV, VBV, BV); + return true; +} + +template +void lobpcg( + const raft::handle_t& handle, + // IN + 
raft::spectral::matrix::sparse_matrix_t A, // shape=(n,n) + raft::device_matrix_view X, // shape=(n,k) IN OUT Eigvectors + raft::device_vector_view W, // shape=(k) OUT Eigvals + std::optional> B_opt, // shape=(n,n) + std::optional> M_opt, // shape=(n,n) + std::optional> Y_opt, // Constraint + // matrix shape=(n,Y) + value_t tol = 0, + std::int32_t max_iter = 20, + bool largest = true, + int verbosityLevel = 0) +{ + cudaStream_t stream = handle.get_stream(); + auto thrust_exec_policy = handle.get_thrust_policy(); + // auto size_y = 0; + // if (Y_opt.has_value()) size_y = Y_opt.value().extent(1); + auto n = X.extent(0); + auto size_x = X.extent(1); + + /* TODO: DENSE SOLUTION + if ((n - size_y) < (5 * size_x)) { + return; + } */ + if (tol <= 0) { tol = raft::mySqrt(1e-15) * n; } + // Apply constraints to X + /* + auto matrix_BY = raft::make_device_matrix(handle, n, size_y); + if (Y_opt.has_value()) + { + if (B_opt.has_value()) + { + auto B = B_opt.value(); + spmm(handle, Y_opt.value(), B, matrix_BY.view(), false, false); + // TODO + } else { + raft::copy(matrix_BY.data_handle(), Y_opt.value().data_handle(), n * size_y, + handle.get_stream()); + } + // GramYBY + // ApplyConstraints + }*/ + auto BX = raft::make_device_matrix(handle, n, size_x); + auto BXView = BX.view(); + b_orthonormalize(handle, X, BXView, B_opt); + // Compute the initial Ritz vectors: solve the eigenproblem. + auto AX = raft::make_device_matrix(handle, n, size_x); + spmm(handle, A, X, AX.view()); + auto gramXAX = + raft::make_device_matrix(handle, size_x, size_x); + auto XTRowView = make_transpose_layout_view(X); + raft::linalg::gemm(handle, + XTRowView, + AX.view(), gramXAX.view()); + auto eigVectorBuffer = rmm::device_uvector(size_x * size_x, stream); // rmm because of resize + auto eigVectorView = raft::make_device_matrix_view(eigVectorBuffer.data(), size_x, size_x); + auto eigLambda = raft::make_device_vector(handle, size_x); + std::optional> empty_matrix_opt = std::nullopt; + eigh(handle, gramXAX.view(), empty_matrix_opt, eigVectorView, eigLambda.view()); + + truncEig(handle, eigVectorView, empty_matrix_opt, eigLambda.view(), largest); + // Slice not needed for first eigh + // raft::matrix::slice(handle, eigVectorFull, eigVector, raft::matrix::slice_coordinates(0, 0, + // eigVectorFull.extent(0), size_x)); + + raft::linalg::gemm(handle, X, eigVectorView, X); + raft::linalg::gemm(handle, AX.view(), eigVectorView, AX.view()); + if (B_opt) raft::linalg::gemm(handle, BXView, eigVectorView, BXView); + + // Active index set + // TODO: use uint8_t + auto active_mask = raft::make_device_vector(handle, size_x); + auto previousBlockSize = size_x; + + auto ident = rmm::device_uvector(size_x * size_x, stream); + auto identView = raft::make_device_matrix_view( + ident.data(), size_x, size_x); + raft::matrix::eye(handle, identView); + auto identSizeX = raft::make_device_matrix( + handle, size_x, size_x); + raft::matrix::eye(handle, identSizeX.view()); + + auto Pbuffer = rmm::device_uvector(0, stream); + auto APbuffer = rmm::device_uvector(0, stream); + auto BPbuffer = rmm::device_uvector(0, stream); + auto PView = + raft::make_device_matrix_view(Pbuffer.data(), 0, 0); + auto APView = + raft::make_device_matrix_view(APbuffer.data(), 0, 0); + auto BPView = + raft::make_device_matrix_view(BPbuffer.data(), 0, 0); + auto activePbuffer = rmm::device_uvector(0, stream); + auto activeAPbuffer = rmm::device_uvector(0, stream); + auto activeBPbuffer = rmm::device_uvector(0, stream); + auto activePView = + 
raft::make_device_matrix_view<value_t, index_t, raft::col_major>(activePbuffer.data(), 0, 0);
+  auto activeAPView =
+    raft::make_device_matrix_view<value_t, index_t, raft::col_major>(activeAPbuffer.data(), 0, 0);
+  auto activeBPView =
+    raft::make_device_matrix_view<value_t, index_t, raft::col_major>(activeBPbuffer.data(), 0, 0);
+  auto R = raft::make_device_matrix<value_t, index_t, raft::col_major>(handle, n, size_x);
+
+  auto aux = raft::make_device_matrix<value_t, index_t, raft::col_major>(handle, n, size_x);
+  // auto aux_sum = raft::make_device_vector<value_t, index_t>(handle, size_x);
+  auto residual_norms = raft::make_device_vector<value_t, index_t>(handle, size_x);
+  std::int32_t iteration_number = -1;
+  bool restart          = true;
+  bool explicitGramFlag = false;
+  while (iteration_number < max_iter + 1) {
+    iteration_number += 1;
+    // R = AX - aux, where aux holds BX (or X) scaled column-wise by the eigenvalues
+    if (B_opt) {
+      raft::matrix::copy(handle, raft::make_const_mdspan(BXView), aux.view());
+    } else {
+      raft::matrix::copy(handle, raft::make_const_mdspan(X), aux.view());
+    }
+    raft::linalg::binary_mult_skip_zero(handle,
+                                        aux.view(),
+                                        raft::make_const_mdspan(eigLambda.view()),
+                                        raft::linalg::Apply::ALONG_ROWS);
+
+    raft::linalg::subtract(
+      handle, raft::make_const_mdspan(AX.view()), raft::make_const_mdspan(aux.view()), R.view());
+
+    raft::linalg::norm(handle,
+                       make_const_mdspan(R.view()),
+                       residual_norms.view(),
+                       raft::linalg::NormType::L2Norm,
+                       raft::linalg::Apply::ALONG_COLUMNS,
+                       raft::sqrt_op());
+
+    // equivalent of cupy.where: flag the columns whose residual norm is still above tol
+    raft::linalg::unary_op(handle,
+                           raft::make_const_mdspan(residual_norms.view()),
+                           active_mask.view(),
+                           [tol] __device__(value_t rn) { return rn > tol; });
+    if (verbosityLevel > 2) {
+      print_device_vector("active_mask", active_mask.data_handle(), active_mask.size(), std::cout);
+    }
+    index_t currentBlockSize = thrust::count_if(thrust::cuda::par.on(stream),
+                                                active_mask.data_handle(),
+                                                active_mask.data_handle() + active_mask.size(),
+                                                [] __device__(value_t v) { return v > 0; });
+    handle.sync_stream();
+    if (currentBlockSize != previousBlockSize) {
+      previousBlockSize = currentBlockSize;
+      ident.resize(currentBlockSize * currentBlockSize, stream);
+      identView = raft::make_device_matrix_view<value_t, index_t, raft::col_major>(
+        ident.data(), currentBlockSize, currentBlockSize);
+      raft::matrix::eye(handle, identView);
+    }
+
+    if (currentBlockSize == 0) break;
+    if (verbosityLevel > 0) {
+      printf("Iteration: %i\n", iteration_number);
+      printf("current block size: %d\n", currentBlockSize);
+      raft::matrix::print_separators ps{};
+      printf("lambda:\n");
+      raft::matrix::print(handle,
+                          raft::make_device_matrix_view<const value_t, index_t, raft::col_major>(
+                            eigLambda.data_handle(), 1, eigLambda.extent(0)),
+                          ps);
+      printf("residual norms:\n");
+      raft::matrix::print(handle,
+                          raft::make_device_matrix_view<const value_t, index_t, raft::col_major>(
+                            residual_norms.data_handle(), 1, residual_norms.extent(0)),
+                          ps);
+      if (verbosityLevel > 10) {
+        printf("eigBlockVector:\n");
+        raft::matrix::print(handle, make_const_mdspan(eigVectorView), ps);
+      }
+    }
+    // gather the columns of R that are still active
+    auto activeR =
+      raft::make_device_matrix<value_t, index_t, raft::col_major>(handle, n, currentBlockSize);
+
+    selectColsIf(handle, R.view(), active_mask.view(), activeR.view());
+
+    if (iteration_number > 0) {
+      activePbuffer.resize(n * currentBlockSize, stream);
+      activeAPbuffer.resize(n * currentBlockSize, stream);
+      activeBPbuffer.resize(n * currentBlockSize, stream);
+      activePView = raft::make_device_matrix_view<value_t, index_t, raft::col_major>(
+        activePbuffer.data(), n, currentBlockSize);
+      activeAPView = raft::make_device_matrix_view<value_t, index_t, raft::col_major>(
+        activeAPbuffer.data(), n, currentBlockSize);
+      selectColsIf(handle, PView, active_mask.view(), activePView);
+      selectColsIf(handle, APView, active_mask.view(), activeAPView);
+      if (B_opt.has_value()) {
+        activeBPView = raft::make_device_matrix_view<value_t, index_t, raft::col_major>(
+          activeBPbuffer.data(), n, currentBlockSize);
+        selectColsIf(handle, BPView, active_mask.view(), activeBPView);
+      }
+    }
+    if
(M_opt.has_value()) { + // Apply preconditioner T to the active residuals. + auto MRtemp = raft::make_device_matrix( + handle, R.extent(0), currentBlockSize); + spmm(handle, M_opt.value(), activeR.view(), MRtemp.view()); + raft::copy(activeR.data_handle(), MRtemp.data_handle(), MRtemp.size(), stream); + } + // Apply constraints to the preconditioned residuals. + if (Y_opt.has_value()) { + // TODO Constraint + // apply_constraints(handle, X, gramYBY.view(), BY.view(), Y_opt.value()); + } + // B-orthogonalize the preconditioned residuals to X. + if (B_opt.has_value()) { + auto BXTR = raft::make_device_matrix( + handle, BX.extent(1), activeR.extent(1)); + auto XBXTR = raft::make_device_matrix( + handle, X.extent(0), BXTR.extent(1)); + + raft::linalg::gemm(handle, + make_transpose_layout_view(BX.view()), + activeR.view(), BXTR.view()); + raft::linalg::gemm(handle, X, BXTR.view(), XBXTR.view()); + raft::linalg::subtract(handle, + raft::make_const_mdspan(activeR.view()), + raft::make_const_mdspan(XBXTR.view()), + activeR.view()); + } else { + auto XTR = raft::make_device_matrix( + handle, X.extent(1), activeR.extent(1)); + auto XXTR = raft::make_device_matrix( + handle, X.extent(0), XTR.extent(1)); + raft::linalg::gemm(handle, XTRowView, activeR.view(), XTR.view()); + raft::linalg::gemm(handle, X, XTR.view(), XXTR.view()); + raft::linalg::subtract(handle, + raft::make_const_mdspan(activeR.view()), + raft::make_const_mdspan(XXTR.view()), + activeR.view()); + } + // B-orthonormalize the preconditioned residuals. + auto activeBR = raft::make_device_matrix( + handle, activeR.extent(0), activeR.extent(1)); + auto activeBRView = activeBR.view(); + b_orthonormalize(handle, activeR.view(), activeBRView, B_opt); + + auto activeAR = + raft::make_device_matrix(handle, n, activeR.extent(1)); + spmm(handle, A, activeR.view(), activeAR.view()); + + if (iteration_number > 0) { + auto invR = raft::make_device_matrix( + handle, activePView.extent(1), activePView.extent(1)); + auto normal = raft::make_device_vector(handle, activePView.extent(1)); + bool b_orth_success = true; + if (!B_opt.has_value()) { + auto BP = raft::make_device_matrix( + handle, activePView.extent(0), activePView.extent(1)); + b_orth_success = b_orthonormalize(handle, + activePView, + BP.view(), + B_opt, + std::make_optional(invR.view()), + std::make_optional(normal.view())); + } else { + b_orth_success = b_orthonormalize(handle, + activePView, + activeBPView, + B_opt, + std::make_optional(invR.view()), + std::make_optional(normal.view()), + false); + } + if (!b_orth_success) { + restart = true; + } else { + raft::linalg::binary_div_skip_zero(handle, + activeAPView, + raft::make_const_mdspan(normal.view()), + raft::linalg::Apply::ALONG_ROWS); + raft::linalg::gemm(handle, activeAPView, invR.view(), activeAPView); + restart = false; + } + + // Perform the Rayleigh Ritz Procedure: + // Compute symmetric Gram matrices: + value_t myeps = 1; // TODO: std::is_same_t ? 
1e-4 : 1e-8; + if (!explicitGramFlag) { + value_t* residual_norms_max_elem = + thrust::max_element(thrust_exec_policy, + residual_norms.data_handle(), + residual_norms.data_handle() + residual_norms.size()); + value_t residual_norms_max = 0; + raft::copy(&residual_norms_max, residual_norms_max_elem, 1, stream); + handle.sync_stream(); + explicitGramFlag = residual_norms_max > myeps; + } + + if (!B_opt.has_value()) { + // Shared memory assignments to simplify the code + BXView = X; + activeBRView = activeR.view(); + if (!restart) + activeBPView = activePView; + } + } + // Common submatrices + auto gramXAR = + raft::make_device_matrix(handle, size_x, currentBlockSize); + auto gramRAR = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + auto gramXBX = + raft::make_device_matrix(handle, size_x, currentBlockSize); + auto gramRBR = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + auto gramXBR = + raft::make_device_matrix(handle, size_x, currentBlockSize); + raft::linalg::gemm(handle, + XTRowView, + activeAR.view(), + gramXAR.view()); + + raft::linalg::gemm( + handle, + make_transpose_layout_view(activeR.view()), + activeAR.view(), + gramRAR.view()); + + auto device_half = raft::make_device_scalar(handle, 0.5); + if (explicitGramFlag) { + raft::linalg::gemm( + handle, + make_transpose_layout_view(gramRAR.view()), + identView, + gramRAR.view(), + std::make_optional(device_half.view()), + std::make_optional(device_half.view())); + raft::linalg::gemm(handle, + XTRowView, + AX.view(), + gramXAX.view()); + raft::linalg::gemm( + handle, + make_transpose_layout_view(gramXAX.view()), + identView, + gramXAX.view(), + std::make_optional(device_half.view()), + std::make_optional(device_half.view())); + + raft::linalg::gemm(handle, + XTRowView, + BX.view(), + gramXBX.view()); + raft::linalg::gemm( + handle, + make_transpose_layout_view(activeR.view()), + activeBRView, + gramRBR.view()); + raft::linalg::gemm(handle, + XTRowView, + activeBRView, + gramXBR.view()); + } else { + raft::matrix::fill(handle, gramXAX.view(), value_t(0)); + raft::matrix::set_diagonal(handle, make_const_mdspan(eigLambda.view()), gramXAX.view()); + + raft::matrix::eye(handle, gramXBX.view()); + raft::matrix::eye(handle, gramRBR.view()); + raft::matrix::fill(handle, gramXBR.view(), value_t(0)); + } + auto gramDim = gramXAX.extent(1) + gramXAR.extent(1) + currentBlockSize; + auto gramA = raft::make_device_matrix(handle, gramDim, gramDim); + auto gramB = raft::make_device_matrix(handle, gramDim, gramDim); + auto gramAView = gramA.view(); + auto gramBView = gramB.view(); + auto eigLambdaTemp = raft::make_device_vector(handle, gramDim); + auto eigVectorTemp = + raft::make_device_matrix(handle, gramDim, gramDim); + auto eigLambdaTempView = eigLambdaTemp.view(); + auto eigVectorTempView = eigVectorTemp.view(); + auto gramXAP = + raft::make_device_matrix(handle, size_x, currentBlockSize); + auto gramRAP = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + auto gramPAP = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + auto gramXBP = + raft::make_device_matrix(handle, size_x, currentBlockSize); + auto gramRBP = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + auto gramPBP = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + // create transpose mat + auto gramXAPT = raft::make_device_matrix( + handle, gramXAP.extent(1), gramXAP.extent(0)); + auto gramXART = raft::make_device_matrix( + handle, 
gramXAR.extent(1), gramXAR.extent(0)); + auto gramRAPT = raft::make_device_matrix( + handle, gramRAP.extent(1), gramRAP.extent(0)); + auto gramXBPT = raft::make_device_matrix( + handle, gramXBP.extent(1), gramXBP.extent(0)); + auto gramXBRT = raft::make_device_matrix( + handle, gramXBR.extent(1), gramXBR.extent(0)); + auto gramRBPT = raft::make_device_matrix( + handle, gramRBP.extent(1), gramRBP.extent(0)); + raft::linalg::transpose(handle, gramXAR.view(), gramXART.view()); + raft::linalg::transpose(handle, gramXBR.view(), gramXBRT.view()); + + if (!restart) { + raft::linalg::gemm(handle, + XTRowView, + activeAPView, + gramXAP.view()); + raft::linalg::gemm( + handle, + make_transpose_layout_view(activeR.view()), + activeAPView, + gramRAP.view()); + raft::linalg::gemm(handle, + make_transpose_layout_view(activePView), + activeAPView, + gramPAP.view()); + raft::linalg::gemm(handle, + XTRowView, + activeBPView, + gramXBP.view()); + raft::linalg::gemm( + handle, + make_transpose_layout_view(activeR.view()), + activeBPView, + gramRBP.view()); + + if (explicitGramFlag) { + raft::linalg::gemm( + handle, + make_transpose_layout_view(gramPAP.view()), + identView, + gramPAP.view(), + std::make_optional(device_half.view()), + std::make_optional(device_half.view())); + raft::linalg::gemm(handle, + make_transpose_layout_view(activePView), + activeBPView, + gramPBP.view()); + } else { + raft::matrix::eye(handle, gramPBP.view()); + } + raft::linalg::transpose(handle, gramXAP.view(), gramXAPT.view()); + raft::linalg::transpose(handle, gramRAP.view(), gramRAPT.view()); + raft::linalg::transpose(handle, gramXBP.view(), gramXBPT.view()); + raft::linalg::transpose(handle, gramRBP.view(), gramRBPT.view()); + + std::vector> A_blocks = { + gramXAX.view(), gramXAR.view(), gramXAP.view(), gramXART.view(), gramRAR.view(), gramRAP.view(), gramXAPT.view(), gramRAPT.view(), gramPAP.view()}; + std::vector> B_blocks = { + gramXBX.view(), gramXBR.view(), gramXBP.view(), gramXBRT.view(), gramRBR.view(), gramRBP.view(), gramXBPT.view(), gramRBPT.view(), gramPBP.view()}; + gramAView = + raft::make_device_matrix_view(gramA.data_handle(), n, n); + gramBView = + raft::make_device_matrix_view(gramB.data_handle(), n, n); + + bmat(handle, gramAView, A_blocks, 3); + bmat(handle, gramBView, B_blocks, 3); + + // Verbosity print + if (verbosityLevel > 10) { + raft::matrix::print_separators ps{}; + printf("gramA:\n"); + raft::matrix::print(handle, make_const_mdspan(gramAView), ps); + printf("gramB:\n"); + raft::matrix::print(handle, make_const_mdspan(gramBView), ps); + } + bool eig_sucess = + eigh(handle, gramAView, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView); + if (!eig_sucess) restart = true; + } + if (restart) { + gramDim = gramXAX.extent(1) + gramXAR.extent(1); + std::vector> A_blocks = { + gramXAX.view(), gramXAR.view(), gramXART.view(), gramRAR.view()}; + std::vector> B_blocks = { + gramXBX.view(), gramXBR.view(), gramXBRT.view(), gramRBR.view()}; + gramAView = raft::make_device_matrix_view( + gramA.data_handle(), gramDim, gramDim); + gramBView = raft::make_device_matrix_view( + gramB.data_handle(), gramDim, gramDim); + eigLambdaTempView = + raft::make_device_vector_view(eigLambdaTempView.data_handle(), gramDim); + eigVectorTempView = raft::make_device_matrix_view( + eigVectorTempView.data_handle(), gramDim, gramDim); + bmat(handle, gramAView, A_blocks, 2); + bmat(handle, gramBView, B_blocks, 2); + if (verbosityLevel > 10) { + raft::matrix::print_separators ps{}; + printf("gramA:\n"); + 
raft::matrix::print(handle, make_const_mdspan(gramAView), ps); + printf("gramB:\n"); + raft::matrix::print(handle, make_const_mdspan(gramBView), ps); + } + bool eig_sucess = eigh( + handle, gramAView, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView); + ASSERT(eig_sucess, "lobpcg: eigh has failed in lobpcg iterations"); + } + eigVectorBuffer.resize(gramDim * size_x, stream); + eigVectorView = raft::make_device_matrix_view(eigVectorBuffer.data(), gramDim, size_x); + truncEig( + handle, eigVectorTempView, std::make_optional(eigVectorView), eigLambdaTempView, largest); + raft::copy(eigLambda.data_handle(), eigLambdaTempView.data_handle(), size_x, stream); + + // Verbosity print + if (verbosityLevel > 10) { + raft::matrix::print_separators ps{}; + printf("lambdaPostGram:\n"); + raft::matrix::print(handle, raft::make_device_matrix_view(eigLambdaTempView.data_handle(), 1, eigLambdaTempView.extent(0)), ps); + + } + + // Compute Ritz vectors. + auto d_one = raft::make_device_scalar(handle, 1); + auto one = std::make_optional(d_one.view()); + auto eigBlockVectorX = raft::make_device_matrix(handle, size_x, size_x); + auto eigBlockVectorR = raft::make_device_matrix(handle, currentBlockSize, size_x); + auto eigBlockVectorP = raft::make_device_matrix(handle, gramDim - (size_x + currentBlockSize), size_x); + auto pp = raft::make_device_matrix(handle, n, size_x); + auto app = raft::make_device_matrix(handle, n, size_x); + if (B_opt.has_value()) { + auto bpp = raft::make_device_matrix(handle, n, size_x); + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorX.view(), + raft::matrix::slice_coordinates(0, 0, size_x, size_x)); + if (!restart) { + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(), + raft::matrix::slice_coordinates(size_x, 0, size_x + currentBlockSize, size_x)); + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorP.view(), + raft::matrix::slice_coordinates(size_x + currentBlockSize, 0, gramDim, size_x)); + } else { + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(), + raft::matrix::slice_coordinates(size_x, 0, gramDim, size_x)); + } + + raft::linalg::gemm(handle, activeR.view(), eigBlockVectorR.view(), pp.view()); + raft::linalg::gemm(handle, activeAR.view(), eigBlockVectorR.view(), app.view()); + raft::linalg::gemm(handle, activeBRView, eigBlockVectorR.view(), bpp.view()); + if (!restart) { + raft::linalg::gemm(handle, activePView, eigBlockVectorP.view(), pp.view(), one, one); + raft::linalg::gemm(handle, activeAPView, eigBlockVectorP.view(), app.view(), one, one); + raft::linalg::gemm(handle, activeBPView, eigBlockVectorP.view(), bpp.view(), one, one); + } + Pbuffer.resize(n * size_x, stream); + APbuffer.resize(n * size_x, stream); + BPbuffer.resize(n * size_x, stream); + PView = raft::make_device_matrix_view(Pbuffer.data(), n, size_x); + APView = raft::make_device_matrix_view(APbuffer.data(), n, size_x); + BPView = raft::make_device_matrix_view(BPbuffer.data(), n, size_x); + + if (verbosityLevel > 10) { + raft::matrix::print_separators ps{}; + printf("pp:\n"); + raft::matrix::print(handle, make_const_mdspan(pp.view()), ps); + printf("app:\n"); + raft::matrix::print(handle, make_const_mdspan(app.view()), ps); + printf("bpp:\n"); + raft::matrix::print(handle, make_const_mdspan(bpp.view()), ps); + } + raft::copy(PView.data_handle(), pp.data_handle(), pp.size(), stream); + raft::copy(APView.data_handle(), app.data_handle(), app.size(), stream); + 
raft::copy(BPView.data_handle(), bpp.data_handle(), bpp.size(), stream); + + raft::linalg::gemm(handle, X, eigBlockVectorX.view(), pp.view(), one, one); + raft::linalg::gemm(handle, AX.view(), eigBlockVectorX.view(), app.view(), one, one); + raft::linalg::gemm(handle, BXView, eigBlockVectorX.view(), bpp.view(), one, one); + + raft::copy(X.data_handle(), pp.data_handle(), pp.size(), stream); + raft::copy(AX.data_handle(), app.data_handle(), app.size(), stream); + raft::copy(BXView.data_handle(), bpp.data_handle(), bpp.size(), stream); + } else { + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorX.view(), + raft::matrix::slice_coordinates(0, 0, size_x, size_x)); + if (!restart) { + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(), + raft::matrix::slice_coordinates(size_x, 0, size_x + currentBlockSize, size_x)); + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorP.view(), + raft::matrix::slice_coordinates(size_x + currentBlockSize, 0, gramDim, size_x)); + } else { + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(), + raft::matrix::slice_coordinates(size_x, 0, gramDim, size_x)); + } + + raft::linalg::gemm(handle, activeR.view(), eigBlockVectorR.view(), pp.view()); + raft::linalg::gemm(handle, activeAR.view(), eigBlockVectorR.view(), app.view()); + if (!restart) { + raft::linalg::gemm(handle, activePView, eigBlockVectorP.view(), pp.view(), one, one); + raft::linalg::gemm(handle, activeAPView, eigBlockVectorP.view(), app.view(), one, one); + } + Pbuffer.resize(n * size_x, stream); + APbuffer.resize(n * size_x, stream); + PView = raft::make_device_matrix_view(Pbuffer.data(), n, size_x); + APView = raft::make_device_matrix_view(APbuffer.data(), n, size_x); + + raft::copy(PView.data_handle(), pp.data_handle(), pp.size(), stream); + raft::copy(APView.data_handle(), app.data_handle(), app.size(), stream); + + if (verbosityLevel > 10) { + raft::matrix::print_separators ps{}; + printf("pp:\n"); + raft::matrix::print(handle, make_const_mdspan(pp.view()), ps); + printf("app:\n"); + raft::matrix::print(handle, make_const_mdspan(app.view()), ps); + } + + raft::linalg::gemm(handle, X, eigBlockVectorX.view(), pp.view(), one, one); + raft::linalg::gemm(handle, AX.view(), eigBlockVectorX.view(), app.view(), one, one); + + raft::copy(X.data_handle(), pp.data_handle(), pp.size(), stream); + raft::copy(AX.data_handle(), app.data_handle(), app.size(), stream); + } + } + + if (B_opt.has_value()) { + raft::copy(aux.data_handle(), BXView.data_handle(), BXView.size(), stream); + } else { + raft::copy(aux.data_handle(), X.data_handle(), X.size(), stream); + } + raft::linalg::binary_mult_skip_zero(handle, aux.view(), make_const_mdspan(eigLambda.view()), raft::linalg::Apply::ALONG_ROWS); + + raft::linalg::subtract( + handle, raft::make_const_mdspan(AX.view()), raft::make_const_mdspan(aux.view()), R.view()); + + raft::linalg::reduce( + residual_norms.data_handle(), + R.data_handle(), + size_x, + n, + value_t(0), + false, + true, + stream, + false, + raft::sq_op()); + // TODO check reduce sqrt postop raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view()); + + if (verbosityLevel > 0) { + /// TODO add verb + } +} +}; // namespace raft::sparse::solver::detail \ No newline at end of file diff --git a/cpp/include/raft/sparse/solver/lobpcg.cuh b/cpp/include/raft/sparse/solver/lobpcg.cuh new file mode 100644 index 0000000000..36aa909012 --- /dev/null +++ 
b/cpp/include/raft/sparse/solver/lobpcg.cuh @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace raft::sparse::solver { + +template +void lobpcg( + const raft::handle_t& handle, + // IN + raft::spectral::matrix::sparse_matrix_t A, // shape=(n,n) + raft::device_matrix_view X, // shape=(n,k) IN OUT Eigvectors + raft::device_vector_view W, // shape=(k) OUT Eigvals + std::optional> B = + std::nullopt, // shape=(n,n) + std::optional> M = + std::nullopt, // shape=(n,n) + std::optional> Y = + std::nullopt, // Constraint matrix shape=(n,Y) + value_t tol = 0, + std::int32_t max_iter = 20, + bool largest = true) +{ + detail::lobpcg(handle, A, X, W, B, M, Y, tol, max_iter, largest); +} +}; // namespace raft::sparse::solver \ No newline at end of file diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 22e8a9d73c..7a981cdd29 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -235,6 +235,7 @@ if(BUILD_TESTS) test/sparse/csr_transpose.cu test/sparse/degree.cu test/sparse/filter.cu + test/sparse/lobpcg.cu test/sparse/norm.cu test/sparse/reduce.cu test/sparse/row_op.cu diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index 10105203f7..d09ea77b48 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -57,8 +58,20 @@ class MatrixTest : public ::testing::TestWithParam> { } protected: + void test_eye() + { + auto eyemat = raft::make_device_matrix(handle, 4, 5); + raft::matrix::eye(handle, eyemat.view()); + std::vector eye_exp{1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0}; + std::vector eye_act(20); + raft::copy(eye_act.data(), eyemat.data_handle(), eye_act.size(), stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + ASSERT_TRUE(hostVecMatch(eye_exp, eye_act, raft::Compare())); + } + void SetUp() override { + test_eye(); raft::random::RngState r(params.seed); int len = params.n_row * params.n_col; uniform(handle, r, in1.data(), len, T(-1.0), T(1.0)); diff --git a/cpp/test/sparse/lobpcg.cu b/cpp/test/sparse/lobpcg.cu new file mode 100644 index 0000000000..e226772b9e --- /dev/null +++ b/cpp/test/sparse/lobpcg.cu @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include +#include +#include +#include +#include + +#include "../test_utils.cuh" +#include "../test_utils.h" +#include + +#include +#include + +namespace raft { +namespace sparse { + +template +struct CSRMatrixVal { + std::vector row_ind_ptr; + std::vector row_ind; + std::vector values; +}; + +template +struct LOBPCGInputs { + CSRMatrixVal matrix_a; + std::vector init_eigvecs; + std::vector exp_eigvals; + std::vector exp_eigvecs; + idx_t n_components; +}; + +// Helper for b_orthonormalize optional arguments +template +void b_orthonormalize(const raft::handle_t& handle, + raft::device_matrix_view V, + raft::device_matrix_view BV, + b_opt_t&& B_opt = std::nullopt, + vbv_opt_t&& VBV_opt = std::nullopt, + v_max_opt_t&& V_max_opt = std::nullopt, + bool bv_is_empty = true) +{ + std::optional> b = + std::forward(B_opt); + std::optional> vbv = + std::forward(VBV_opt); + std::optional> v_max = + std::forward(V_max_opt); + raft::sparse::solver::detail::b_orthonormalize(handle, V, BV, b, vbv, v_max, bv_is_empty); +} + +template +class LOBPCGTest : public ::testing::TestWithParam> { + public: + LOBPCGTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + ind_a(params.matrix_a.row_ind.size(), stream), + ind_ptr_a(params.matrix_a.row_ind_ptr.size(), stream), + values_a(params.matrix_a.values.size(), stream), + exp_eigvals(params.exp_eigvals.size(), stream), + exp_eigvecs(params.exp_eigvecs.size(), stream), + act_eigvals(params.exp_eigvals.size(), stream), + act_eigvecs(params.exp_eigvecs.size(), stream) + { + } + + protected: + void SetUp() override + { + n_rows_a = params.matrix_a.row_ind_ptr.size() - 1; + nnz_a = params.matrix_a.values.size(); + } + + void test_selectcolsif() + { + auto a = raft::make_device_matrix(handle, 5, 8); + auto c = raft::make_device_matrix(handle, 5, 4); + auto b = raft::make_device_vector(handle, 8); + raft::linalg::range(a.data_handle(), a.size(), handle.get_stream()); + std::vector select_h{0, 1, 1, 1, 0, 0, 0, 1}; + raft::copy(b.data_handle(), select_h.data(), 8, handle.get_stream()); + raft::sparse::solver::detail::selectColsIf(handle, a.view(), b.view(), c.view()); + std::vector res(c.size()); + std::vector expected{5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 35, 36, 37, 38, 39}; + raft::copy(res.data(), c.data_handle(), c.size(), handle.get_stream()); + + ASSERT_TRUE(hostVecMatch(expected, res, raft::CompareApprox(0.0001))); + } + + void test_bmat() + { + auto total = raft::make_device_matrix(handle, 6, 6); + auto x1 = raft::make_device_matrix(handle, 2, 2); + auto x2 = raft::make_device_matrix(handle, 2, 2); + auto x3 = raft::make_device_matrix(handle, 2, 2); + auto x4 = raft::make_device_matrix(handle, 2, 2); + auto x5 = raft::make_device_matrix(handle, 2, 2); + auto x6 = raft::make_device_matrix(handle, 2, 2); + auto x7 = raft::make_device_matrix(handle, 2, 2); + auto x8 = raft::make_device_matrix(handle, 2, 2); + auto x9 = raft::make_device_matrix(handle, 2, 2); + raft::linalg::range(x1.data_handle(), 0, 4, handle.get_stream()); + raft::linalg::range(x2.data_handle(), 4, 8, handle.get_stream()); + raft::linalg::range(x3.data_handle(), 8, 12, handle.get_stream()); + raft::linalg::range(x4.data_handle(), 12, 16, handle.get_stream()); + raft::linalg::range(x5.data_handle(), 16, 20, handle.get_stream()); + raft::linalg::range(x6.data_handle(), 20, 24, handle.get_stream()); + raft::linalg::range(x7.data_handle(), 24, 28, handle.get_stream()); + raft::linalg::range(x8.data_handle(), 28, 32, 
+
+  void test_b_orthonormalize()
+  {
+    idx_t n_rows_v     = n_rows_a;
+    idx_t n_features_v = params.n_components;
+    raft::update_device(
+      act_eigvecs.data(), params.init_eigvecs.data(), act_eigvecs.size(), stream);
+    auto v = raft::make_device_matrix_view<value_t, idx_t, raft::col_major>(
+      act_eigvecs.data(), n_rows_v, n_features_v);
+    auto bv =
+      raft::make_device_matrix<value_t, idx_t, raft::col_major>(handle, n_rows_v, n_features_v);
+    auto vbv = raft::make_device_matrix<value_t, idx_t, raft::col_major>(
+      handle, n_features_v, n_features_v);
+    b_orthonormalize(
+      handle, v, bv.view(), std::nullopt, std::make_optional(vbv.view()), std::nullopt, true);
+    std::vector<value_t> vbv_inv_expected{0.76298383, 0.0, -1.20276028, 1.0791533};
+    std::vector<value_t> vbv_inv_actual(4);
+    raft::copy(vbv_inv_actual.data(), vbv.data_handle(), vbv_inv_actual.size(), stream);
+
+    RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+    ASSERT_TRUE(
+      hostVecMatch(vbv_inv_expected, vbv_inv_actual, raft::CompareApprox<value_t>(0.0001)));
+  }
+
+  void test_eigh()
+  {
+    std::vector<value_t> in_cpu{1.73969722, 0.98719877, 0.73374337, 0.211756781};
+    std::vector<value_t> lambda_cpu{-0.27255666, 2.22401026};
+    std::vector<value_t> vector_cpu{-0.44044489, 0.89777965, 0.89777965, 0.44044489};
+    auto in_gpu     = raft::make_device_matrix<value_t, idx_t, raft::col_major>(handle, 2, 2);
+    auto lambda_gpu = raft::make_device_vector<value_t, idx_t>(handle, 2);
+    auto vector_gpu = raft::make_device_matrix<value_t, idx_t, raft::col_major>(handle, 2, 2);
+    std::optional<raft::device_matrix_view<value_t, idx_t, raft::col_major>> empty_matrix_opt =
+      std::nullopt;
+
+    raft::copy(in_gpu.data_handle(), in_cpu.data(), 4, handle.get_stream());
+    raft::sparse::solver::detail::eigh(
+      handle, in_gpu.view(), empty_matrix_opt, vector_gpu.view(), lambda_gpu.view());
+
+    ASSERT_TRUE(devArrMatchHost(lambda_cpu.data(),
+                                lambda_gpu.data_handle(),
+                                lambda_cpu.size(),
+                                raft::CompareApprox<value_t>(0.0001),
+                                handle.get_stream()));
+    ASSERT_TRUE(devArrMatchHost(vector_cpu.data(),
+                                vector_gpu.data_handle(),
+                                vector_cpu.size(),
+                                raft::CompareApprox<value_t>(0.0001),
+                                handle.get_stream()));
+  }
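+
+  // Annotation (an inference, not stated in the change): in_cpu is not
+  // symmetric as stored (0.98719877 vs 0.73374337 across the diagonal), yet
+  // the expected eigenvalues match the symmetric matrix taken from the lower
+  // triangle,
+  //     [1.73969722   0.98719877]
+  //     [0.98719877   0.211756781]
+  // whose trace ~1.9515 and determinant ~-0.6062 give roots ~-0.27255666 and
+  // ~2.22401026. So eigh() presumably forwards to a cuSOLVER syevd-style
+  // solver that reads only one triangle (likely CUBLAS_FILL_MODE_LOWER).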
+
+  void Run()
+  {
+    test_eigh();
+    test_bmat();
+    test_selectcolsif();
+    test_b_orthonormalize();
+    raft::update_device(
+      ind_a.data(), params.matrix_a.row_ind.data(), params.matrix_a.row_ind.size(), stream);
+    raft::update_device(ind_ptr_a.data(),
+                        params.matrix_a.row_ind_ptr.data(),
+                        params.matrix_a.row_ind_ptr.size(),
+                        stream);
+    raft::update_device(
+      values_a.data(), params.matrix_a.values.data(), params.matrix_a.values.size(), stream);
+    raft::update_device(exp_eigvals.data(), params.exp_eigvals.data(), exp_eigvals.size(), stream);
+    raft::update_device(exp_eigvecs.data(), params.exp_eigvecs.data(), exp_eigvecs.size(), stream);
+
+    raft::update_device(
+      act_eigvecs.data(), params.init_eigvecs.data(), act_eigvecs.size(), stream);
+
+    auto matA = raft::spectral::matrix::sparse_matrix_t<idx_t, value_t>(
+      handle, ind_ptr_a.data(), ind_a.data(), values_a.data(), n_rows_a, n_rows_a, nnz_a);
+    raft::sparse::solver::lobpcg(
+      handle,
+      matA,
+      raft::make_device_matrix_view<value_t, idx_t, raft::col_major>(
+        act_eigvecs.data(), n_rows_a, params.n_components),
+      raft::make_device_vector_view<value_t, idx_t>(act_eigvals.data(), params.n_components));
+
+    std::vector<value_t> X_CPU(n_rows_a * params.n_components);
+    std::vector<value_t> W_CPU(params.n_components);
+    raft::copy(X_CPU.data(), act_eigvecs.data(), X_CPU.size(), stream);
+    raft::copy(W_CPU.data(), act_eigvals.data(), W_CPU.size(), stream);
+    ASSERT_TRUE(raft::devArrMatch(exp_eigvecs.data(),
+                                  act_eigvecs.data(),
+                                  exp_eigvecs.size(),
+                                  raft::CompareApprox<value_t>(0.0001),
+                                  stream));
+    ASSERT_TRUE(raft::devArrMatch(exp_eigvals.data(),
+                                  act_eigvals.data(),
+                                  exp_eigvals.size(),
+                                  raft::CompareApprox<value_t>(0.0001),
+                                  stream));
+  }
+
+ protected:
+  raft::handle_t handle;
+  cudaStream_t stream;
+
+  LOBPCGInputs<value_t, idx_t> params;
+  idx_t n_rows_a, nnz_a;
+  rmm::device_uvector<idx_t> ind_a, ind_ptr_a;
+  rmm::device_uvector<value_t> values_a, exp_eigvals, exp_eigvecs, act_eigvals, act_eigvecs;
+};
+
+using LOBPCGTestF = LOBPCGTest<float, int>;
+TEST_P(LOBPCGTestF, Result) { Run(); }
+
+using LOBPCGTestD = LOBPCGTest<double, int>;
+TEST_P(LOBPCGTestD, Result) { Run(); }
+
+const std::vector<LOBPCGInputs<float, int>> lobpcg_inputs_f = {
+  {{{0, 4, 10, 14, 19, 24, 28},
+    {0, 2, 3, 5, 0, 1, 2, 3, 4, 5, 0, 2, 3, 5, 1, 2, 3, 4, 5, 0, 2, 3, 4, 5, 0, 2, 3, 4},
+    {0.37911922, 0.11567201, 0.5135106,  0.08968836, 0.73450965, 0.26432646, 0.21985123,
+     0.74888277, 0.34753734, 0.11204864, 0.82902676, 0.53023521, 0.24047095, 0.37913592,
+     0.60975031, 0.60746519, 0.96833343, 0.30845102, 0.88653955, 0.43530847, 0.32938903,
+     0.82477561, 0.20858375, 0.24755519, 0.23677223, 0.73957246, 0.09050876, 0.86530489}},
+   {0.08319983, 0.17758466, 0.93301819, 0.67171826, 0.19967821, 0.30873092,
+    0.35005079, 0.56035486, 0.64176631, 0.93904784, 0.38935935, 0.97182089},
+   {2.61153278, 0.85782948},
+   {-0.38272064, -0.25160901, -0.48684676, -0.50752949, -0.43005954, -0.33265696,
+    -0.39778489, 0.2539629, -0.37506003, 0.72637041, 0.02727131, -0.32900198},
+   2}};
+const std::vector<LOBPCGInputs<double, int>> lobpcg_inputs_d = {
+  {{{0, 4, 10, 14, 19, 24, 28},
+    {0, 2, 3, 5, 0, 1, 2, 3, 4, 5, 0, 2, 3, 5, 1, 2, 3, 4, 5, 0, 2, 3, 4, 5, 0, 2, 3, 4},
+    {0.37911922, 0.11567201, 0.5135106,  0.08968836, 0.73450965, 0.26432646, 0.21985123,
+     0.74888277, 0.34753734, 0.11204864, 0.82902676, 0.53023521, 0.24047095, 0.37913592,
+     0.60975031, 0.60746519, 0.96833343, 0.30845102, 0.88653955, 0.43530847, 0.32938903,
+     0.82477561, 0.20858375, 0.24755519, 0.23677223, 0.73957246, 0.09050876, 0.86530489}},
+   {0.08319983, 0.17758466, 0.93301819, 0.67171826, 0.19967821, 0.30873092,
+    0.35005079, 0.56035486, 0.64176631, 0.93904784, 0.38935935, 0.97182089},
+   {2.61153278, 0.85782948},
+   {-0.38272064, -0.25160901, -0.48684676, -0.50752949, -0.43005954, -0.33265696,
+    -0.39778489, 0.2539629, -0.37506003, 0.72637041, 0.02727131, -0.32900198},
+   2}};
+
+INSTANTIATE_TEST_CASE_P(LOBPCGTest, LOBPCGTestF, ::testing::ValuesIn(lobpcg_inputs_f));
+INSTANTIATE_TEST_CASE_P(LOBPCGTest, LOBPCGTestD, ::testing::ValuesIn(lobpcg_inputs_d));
+
+}  // namespace sparse
+}  // namespace raft
+
+/*
+
+a = cupyx.scipy.sparse.random(6, 6, 0.8, 'csr')
+a.indptr = array([ 0,  4, 10, 14, 19, 24, 28], dtype=int32)
+
+a.indices = array([0, 2, 3, 5, 0, 1, 2, 3, 4, 5, 0, 2, 3, 5, 1, 2, 3, 4, 5, 0, 2, 3,
+       4, 5, 0, 2, 3, 4], dtype=int32)
+
+a.data = array([0.37911922, 0.11567201, 0.5135106 , 0.08968836, 0.73450965,
+       0.26432646, 0.21985123, 0.74888277, 0.34753734, 0.11204864,
+       0.82902676, 0.53023521, 0.24047095, 0.37913592, 0.60975031,
+       0.60746519, 0.96833343, 0.30845102, 0.88653955, 0.43530847,
+       0.32938903, 0.82477561, 0.20858375, 0.24755519, 0.23677223,
+       0.73957246, 0.09050876, 0.86530489])
+
+a.todense() =
+np.matrix([[0.37911922, 0.        , 0.11567201, 0.5135106 , 0.        , 0.08968836],
+           [0.73450965, 0.26432646, 0.21985123, 0.74888277, 0.34753734, 0.11204864],
+           [0.82902676, 0.        , 0.53023521, 0.24047095, 0.        , 0.37913592],
+           [0.        , 0.60975031, 0.60746519, 0.96833343, 0.30845102, 0.88653955],
+           [0.43530847, 0.        , 0.32938903, 0.82477561, 0.20858375, 0.24755519],
+           [0.23677223, 0.        , 0.73957246, 0.09050876, 0.86530489, 0.        ]])
+x = np.random.rand(6,2)
+x = np.array([[0.08319983, 0.35005079],
+              [0.17758466, 0.56035486],
+              [0.93301819, 0.64176631],
+              [0.67171826, 0.93904784],
+              [0.19967821, 0.38935935],
+              [0.30873092, 0.97182089]])
+
+lobpcg(a, x) = (array([2.61153278, 0.85782948]),
+                array([[-0.38272064, -0.39778489],
+                       [-0.25160901,  0.2539629 ],
+                       [-0.48684676, -0.37506003],
+                       [-0.50752949,  0.72637041],
+                       [-0.43005954,  0.02727131],
+                       [-0.33265696, -0.32900198]]))
+ */
\ No newline at end of file
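As a sanity check on the reference data recorded in the comment block above, the CSR triplet can be expanded on the host and compared against a.todense(). The standalone sketch below (plain C++, independent of RAFT; the program itself is illustrative and not part of the change) does exactly that for the matrix in this test.

// Expand the CSR reference data into a dense row-major matrix and print it;
// the output should reproduce a.todense() from the comment above.
#include <cstdio>
#include <vector>

int main()
{
  const int n = 6;
  std::vector<int> indptr{0, 4, 10, 14, 19, 24, 28};
  std::vector<int> indices{0, 2, 3, 5, 0, 1, 2, 3, 4, 5, 0, 2, 3, 5,
                           1, 2, 3, 4, 5, 0, 2, 3, 4, 5, 0, 2, 3, 4};
  std::vector<double> data{
    0.37911922, 0.11567201, 0.5135106,  0.08968836, 0.73450965, 0.26432646,
    0.21985123, 0.74888277, 0.34753734, 0.11204864, 0.82902676, 0.53023521,
    0.24047095, 0.37913592, 0.60975031, 0.60746519, 0.96833343, 0.30845102,
    0.88653955, 0.43530847, 0.32938903, 0.82477561, 0.20858375, 0.24755519,
    0.23677223, 0.73957246, 0.09050876, 0.86530489};
  std::vector<double> dense(n * n, 0.0);
  // Row `row` owns entries indptr[row] .. indptr[row + 1] - 1; indices[j]
  // gives the column of data[j].
  for (int row = 0; row < n; ++row)
    for (int j = indptr[row]; j < indptr[row + 1]; ++j)
      dense[row * n + indices[j]] = data[j];
  for (int row = 0; row < n; ++row) {
    for (int col = 0; col < n; ++col)
      std::printf("%10.8f ", dense[row * n + col]);
    std::printf("\n");
  }
  return 0;
}

For instance, row 0 spans indptr[0]=0 to indptr[1]=4, placing values {0.37911922, 0.11567201, 0.5135106, 0.08968836} at columns {0, 2, 3, 5}, which matches the first row of a.todense().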