Skip to content

Commit

Permalink
Fix compilation issues
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Apr 19, 2023
1 parent 03c1760 commit 6694ae4
Showing 1 changed file with 70 additions and 46 deletions.
116 changes: 70 additions & 46 deletions cpp/include/raft/sparse/solver/detail/lobpcg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ void selectColsIf(const raft::handle_t& handle,
raft::linalg::map(
handle,
raft::make_const_mdspan(mask),
raft::make_const_mdspan(rangeVec.view()),
rangeVec.view(),
[] __device__(index_t mask_value, index_t idx) { return mask_value == 1 ? idx : -1; },
rangeVec.view());
[] __device__(index_t mask_value, index_t idx) { return mask_value == 1 ? idx : -1; });
thrust::sort(rmm::exec_policy(stream),
rangeVec.data_handle(),
rangeVec.data_handle() + rangeVec.size(),
Expand Down Expand Up @@ -172,11 +172,11 @@ void truncEig(
}
if (eigVectorTrunc.has_value() && ncols > eigVectorTrunc->extent(1))
raft::matrix::truncZeroOrigin(eigVectorin.data_handle(),
n_rows,
nrows,
eigVectorTrunc->data_handle(),
nrows,
eigVectorTrunc->extent(1),
stream);
handle.get_stream());
}

// C = A * B
Expand Down Expand Up @@ -447,7 +447,7 @@ bool eigh(const raft::handle_t& handle,

raft::linalg::eig_dc(handle, raft::make_const_mdspan(F.view()), Fvecs.view(), eigVals);
raft::linalg::gemm(handle, Ri.view(), Fvecs.view(), eigVecs);
return cho_success
return cho_success;
}

/**
Expand Down Expand Up @@ -604,8 +604,10 @@ void lobpcg(
auto eigVectorBuffer = rmm::device_uvector<value_t>(size_x * size_x, stream); // rmm because of resize
auto eigVectorView = raft::make_device_matrix_view<value_t, index_t, raft::col_major>(eigVectorBuffer.data(), size_x, size_x);
auto eigLambda = raft::make_device_vector<value_t, index_t>(handle, size_x);
eigh(handle, gramXAX.view(), eigVectorView, eigLambda.view());
truncEig(handle, eigVectorView, eigLambda.view(), size_x, largest);
std::optional<raft::device_matrix_view<value_t, index_t, raft::col_major>> empty_matrix_opt = std::nullopt;
eigh(handle, gramXAX.view(), empty_matrix_opt, eigVectorView, eigLambda.view());

truncEig(handle, eigVectorView, empty_matrix_opt, eigLambda.view(), largest);
// Slice not needed for first eigh
// raft::matrix::slice(handle, eigVectorFull, eigVector, raft::matrix::slice_coordinates(0, 0,
// eigVectorFull.extent(0), size_x));
Expand All @@ -623,6 +625,9 @@ void lobpcg(
auto identView = raft::make_device_matrix_view<value_t, index_t, raft::col_major>(
ident.data(), size_x, size_x);
raft::matrix::eye(handle, identView);
auto identSizeX = raft::make_device_matrix<value_t, index_t, raft::col_major>(
handle, size_x, size_x);
raft::matrix::eye(handle, identSizeX.view());

auto Pbuffer = rmm::device_uvector<value_t>(0, stream);
auto APbuffer = rmm::device_uvector<value_t>(0, stream);
Expand All @@ -646,6 +651,8 @@ void lobpcg(

auto aux = raft::make_device_matrix<value_t, index_t, raft::col_major>(
handle, n, size_x);
//auto aux_sum = raft::make_device_vector<value_t, index_t>(handle, size_x);
auto residual_norms = raft::make_device_vector<value_t, index_t>(handle, size_x);
std::int32_t iteration_number = -1;
bool restart = true;
bool explicitGramFlag = false;
Expand All @@ -664,9 +671,8 @@ void lobpcg(
raft::linalg::subtract(
handle, raft::make_const_mdspan(AX.view()), raft::make_const_mdspan(aux.view()), R.view());

auto aux_sum = raft::make_device_vector<value_t, index_t>(handle, size_x);
raft::linalg::reduce(
aux_sum.data_handle(),
residual_norms.data_handle(),
R.data_handle(),
size_x,
n,
Expand All @@ -677,8 +683,7 @@ void lobpcg(
false,
raft::sq_op());

auto residual_norms = raft::make_device_vector<value_t, index_t>(handle, size_x);
raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view());
// TODO check sqop of reduce raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view());

// cupy where & active_mask
raft::linalg::unary_op(handle,
Expand Down Expand Up @@ -720,7 +725,7 @@ void lobpcg(
selectColsIf(handle, APView, active_mask.view(), activeAPView);
if (B_opt.has_value()) {
activeBPView = raft::make_device_matrix_view<value_t, index_t, col_major>(activeBPbuffer.data(), n, currentBlockSize);
selectColsIf(handle, BPbuffer.view(), active_mask.view(), activeBPView);
selectColsIf(handle, BPView, active_mask.view(), activeBPView);
}
}
if (M_opt.has_value()) {
Expand Down Expand Up @@ -823,7 +828,7 @@ void lobpcg(

if (!B_opt.has_value()) {
// Shared memory assignments to simplify the code
BXView = X.view();
BXView = X;
activeBRView = activeR.view();
if (!restart)
activeBPView = activePView;
Expand Down Expand Up @@ -906,9 +911,9 @@ void lobpcg(
auto gramB = raft::make_device_matrix<value_t, index_t, col_major>(handle, gramDim, gramDim);
auto gramAView = gramA.view();
auto gramBView = gramB.view();
auto eigLambdaTemp = raft::make_device_vector_view<value_t, index_t>(handle, gramDim);
auto eigLambdaTemp = raft::make_device_vector<value_t, index_t>(handle, gramDim);
auto eigVectorTemp =
raft::make_device_matrix_view<value_t, index_t, raft::col_major>(handle, gramDim, gramDim);
raft::make_device_matrix<value_t, index_t, raft::col_major>(handle, gramDim, gramDim);
auto eigLambdaTempView = eigLambdaTemp.view();
auto eigVectorTempView = eigVectorTemp.view();
eigVectorBuffer.resize(gramDim * size_x, stream);
Expand All @@ -927,19 +932,19 @@ void lobpcg(
handle, currentBlockSize, currentBlockSize);
// create transpose mat
auto gramXAPT = raft::make_device_matrix<value_t, index_t, col_major>(
handle, gramXAPT.extent(1), gramXAPT.extent(0));
handle, gramXAP.extent(1), gramXAP.extent(0));
auto gramXART = raft::make_device_matrix<value_t, index_t, col_major>(
handle, gramXART.extent(1), gramXART.extent(0));
handle, gramXAR.extent(1), gramXAR.extent(0));
auto gramRAPT = raft::make_device_matrix<value_t, index_t, col_major>(
handle, gramRAPT.extent(1), gramRAPT.extent(0));
handle, gramRAP.extent(1), gramRAP.extent(0));
auto gramXBPT = raft::make_device_matrix<value_t, index_t, col_major>(
handle, gramXBPT.extent(1), gramXBPT.extent(0));
handle, gramXBP.extent(1), gramXBP.extent(0));
auto gramXBRT = raft::make_device_matrix<value_t, index_t, col_major>(
handle, gramXBRT.extent(1), gramXBRT.extent(0));
handle, gramXBR.extent(1), gramXBR.extent(0));
auto gramRBPT = raft::make_device_matrix<value_t, index_t, col_major>(
handle, gramRBPT.extent(1), gramRBPT.extent(0));
handle, gramRBP.extent(1), gramRBP.extent(0));
raft::linalg::transpose(handle, gramXAR.view(), gramXART.view());
raft::linalg::transpose(handle, gramXVR.view(), gramXBRT.view());
raft::linalg::transpose(handle, gramXBR.view(), gramXBRT.view());

if (!restart) {
raft::linalg::gemm(handle,
Expand Down Expand Up @@ -1005,19 +1010,19 @@ void lobpcg(
gramBView =
raft::make_device_matrix_view<value_t, index_t, col_major>(gramB.data_handle(), n, n);

bmat(handle, gramAView, A_blocks);
bmat(handle, gramBView, B_blocks);
bmat(handle, gramAView, A_blocks, 3);
bmat(handle, gramBView, B_blocks, 3);

bool eig_sucess =
eigh(handle, gramA, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView);
eigh(handle, gramAView, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView);
if (!eig_sucess) restart = true;
}
if (restart) {
gramDim = gramXAX.extent(1) + gramXAR.extent(1);
std::vector<raft::device_matrix_view<value_t, index_t, col_major>> A_blocks = {
gramXAX, gramXAR, gramXART, gramRAR};
gramXAX.view(), gramXAR.view(), gramXART.view(), gramRAR.view()};
std::vector<raft::device_matrix_view<value_t, index_t, col_major>> B_blocks = {
gramXBX, gramXBR, gramXBRT, gramRBR};
gramXBX.view(), gramXBR.view(), gramXBRT.view(), gramRBR.view()};
gramAView = raft::make_device_matrix_view<value_t, index_t, col_major>(
gramA.data_handle(), gramDim, gramDim);
gramBView = raft::make_device_matrix_view<value_t, index_t, col_major>(
Expand All @@ -1026,8 +1031,8 @@ void lobpcg(
raft::make_device_vector_view<value_t, index_t>(eigLambdaTempView.data_handle(), gramDim);
eigVectorTempView = raft::make_device_matrix_view<value_t, index_t, col_major>(
eigVectorTempView.data_handle(), gramDim, gramDim);
bmat(handle, gramAView, A_blocks);
bmat(handle, gramBView, B_blocks);
bmat(handle, gramAView, A_blocks, 2);
bmat(handle, gramBView, B_blocks, 2);
bool eig_sucess = eigh(
handle, gramAView, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView);
ASSERT(eig_sucess, "lobpcg: eigh has failed in lobpcg iterations");
Expand All @@ -1048,20 +1053,20 @@ void lobpcg(
auto app = raft::make_device_matrix<value_t, index_t, raft::col_major>(handle, n, size_x);
if (B_opt.has_value()) {
auto bpp = raft::make_device_matrix<value_t, index_t, raft::col_major>(handle, n, size_x);
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorX.view(),
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorX.view(),
raft::matrix::slice_coordinates<index_t>(0, 0, size_x, size_x));
if (!restart) {
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(),
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(),
raft::matrix::slice_coordinates<index_t>(size_x, 0, size_x + currentBlockSize, size_x));
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorP.view(),
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorP.view(),
raft::matrix::slice_coordinates<index_t>(size_x + currentBlockSize, 0, gramDim, size_x));
} else {
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(),
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(),
raft::matrix::slice_coordinates<index_t>(size_x, 0, gramDim, size_x));
}

raft::linalg::gemm(handle, activeRView, eigBlockVectorR.view(), pp.view());
raft::linalg::gemm(handle, activeARView, eigBlockVectorR.view(), app.view());
raft::linalg::gemm(handle, activeR.view(), eigBlockVectorR.view(), pp.view());
raft::linalg::gemm(handle, activeAR.view(), eigBlockVectorR.view(), app.view());
raft::linalg::gemm(handle, activeBRView, eigBlockVectorR.view(), bpp.view());
if (!restart) {
raft::linalg::gemm(handle, activePView, eigBlockVectorP.view(), pp.view(), one, one);
Expand All @@ -1087,20 +1092,20 @@ void lobpcg(
raft::copy(AX.data_handle(), app.data_handle(), app.size(), stream);
raft::copy(BXView.data_handle(), bpp.data_handle(), bpp.size(), stream);
} else {
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorX.view(),
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorX.view(),
raft::matrix::slice_coordinates<index_t>(0, 0, size_x, size_x));
if (!restart) {
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(),
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(),
raft::matrix::slice_coordinates<index_t>(size_x, 0, size_x + currentBlockSize, size_x));
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorP.view(),
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorP.view(),
raft::matrix::slice_coordinates<index_t>(size_x + currentBlockSize, 0, gramDim, size_x));
} else {
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(),
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(),
raft::matrix::slice_coordinates<index_t>(size_x, 0, gramDim, size_x));
}

raft::linalg::gemm(handle, activeRView, eigBlockVectorR.view(), pp.view());
raft::linalg::gemm(handle, activeARView, eigBlockVectorR.view(), app.view());
raft::linalg::gemm(handle, activeR.view(), eigBlockVectorR.view(), pp.view());
raft::linalg::gemm(handle, activeAR.view(), eigBlockVectorR.view(), app.view());
if (!restart) {
raft::linalg::gemm(handle, activePView, eigBlockVectorP.view(), pp.view(), one, one);
raft::linalg::gemm(handle, activeAPView, eigBlockVectorP.view(), app.view(), one, one);
Expand All @@ -1121,12 +1126,31 @@ void lobpcg(
}
}

if (B_opt.has_value()) { // Using blockVectorR instead of aux
raft::copy(R.data_handle(), BXView.data_handle(), BXView.size(), stream);
if (B_opt.has_value()) {
raft::copy(aux.data_handle(), BXView.data_handle(), BXView.size(), stream);
} else {
raft::copy(R.data_handle(), X.data_handle(), X.size(), stream);
raft::copy(aux.data_handle(), X.data_handle(), X.size(), stream);
}
raft::linalg::binary_mult_skip_zero(handle, aux.view(), make_const_mdspan(eigLambda.view()), raft::linalg::Apply::ALONG_ROWS);

raft::linalg::subtract(
handle, raft::make_const_mdspan(AX.view()), raft::make_const_mdspan(aux.view()), R.view());

raft::linalg::reduce(
residual_norms.data_handle(),
R.data_handle(),
size_x,
n,
value_t(0),
false,
true,
stream,
false,
raft::sq_op());
// TODO check reduce sqrt postop raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view());

if (verbosityLevel > 0) {
/// TODO add verb
}
raft::linalg::binary_mult_skip_zero(handle, R.view(), make_const_mdspan(eigLambda.view()), linalg::Apply::ALONG_ROWS);
raft::linalg::gemm(handle, AX.view(),)
}
}; // namespace raft::sparse::solver::detail

0 comments on commit 6694ae4

Please sign in to comment.