Skip to content

Commit

Permalink
implement batched serial getrf (#2331)
Browse files Browse the repository at this point in the history
* fix: conflicts

Signed-off-by: Yuuichi Asahi <[email protected]>

* fix: gpu version

Signed-off-by: Yuuichi Asahi <[email protected]>

* fix: docstring for getrf

Signed-off-by: Yuuichi Asahi <[email protected]>

* fix: conflicts

Signed-off-by: Yuuichi Asahi <[email protected]>

* format

Signed-off-by: Yuuichi Asahi <[email protected]>

* fix: errors from code style

Signed-off-by: Yuuichi Asahi <[email protected]>

* fix: format

Signed-off-by: Yuuichi Asahi <[email protected]>

* fix: conflicts

Signed-off-by: Yuuichi Asahi <[email protected]>

* Improve implementation details of getrf

Signed-off-by: Yuuichi Asahi <[email protected]>

* format

Signed-off-by: Yuuichi Asahi <[email protected]>

* Update create_triangular_matrix function

Signed-off-by: Yuuichi Asahi <[email protected]>

* Merging Test_Batched_SerialGetrf.hpp and Test_Batched_SerialGetrf_Real.hpp

Signed-off-by: Yuuichi Asahi <[email protected]>

* remove the global namespace

Signed-off-by: Yuuichi Asahi <[email protected]>

* Add missing maybe_unused for checkGetrfInput

Signed-off-by: Yuuichi Asahi <[email protected]>

* Improve docstrings and comments to describe getrf algo

Signed-off-by: Yuuichi Asahi <[email protected]>

* Add a complicated analytical test based on review

Signed-off-by: Yuuichi Asahi <[email protected]>

---------

Signed-off-by: Yuuichi Asahi <[email protected]>
Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Jan 13, 2025
1 parent 89814c7 commit 226d4ac
Show file tree
Hide file tree
Showing 6 changed files with 907 additions and 0 deletions.
67 changes: 67 additions & 0 deletions batched/dense/impl/KokkosBatched_Getrf_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_GETRF_SERIAL_IMPL_HPP_
#define KOKKOSBATCHED_GETRF_SERIAL_IMPL_HPP_

#include <KokkosBatched_Util.hpp>
#include "KokkosBatched_Getrf_Serial_Internal.hpp"

namespace KokkosBatched {
namespace Impl {
template <typename AViewType, typename PivViewType>
KOKKOS_INLINE_FUNCTION static int checkGetrfInput([[maybe_unused]] const AViewType &A,
[[maybe_unused]] const PivViewType &ipiv) {
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::getrf: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<PivViewType>, "KokkosBatched::getrf: PivViewType is not a Kokkos::View.");
static_assert(AViewType::rank == 2, "KokkosBatched::getrf: AViewType must have rank 2.");
static_assert(PivViewType::rank == 1, "KokkosBatched::getrf: PivViewType must have rank 1.");
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
const int m = A.extent(0), n = A.extent(1);
const int npiv = ipiv.extent(0);
if (npiv != Kokkos::min(m, n)) {
Kokkos::printf(
"KokkosBatched::getrf: the dimension of the ipiv array must "
"satisfy ipiv.extent(0) == max(m, n): ipiv: %d, A: "
"%d "
"x %d \n",
npiv, m, n);
return 1;
}

#endif
return 0;
}
} // namespace Impl

template <>
struct SerialGetrf<Algo::Getrf::Unblocked> {
template <typename AViewType, typename PivViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &ipiv) {
// Quick return if possible
if (A.extent(0) == 0 || A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkGetrfInput(A, ipiv);
if (info) return info;
KOKKOS_IF_ON_HOST((return KokkosBatched::Impl::SerialGetrfInternalHost<Algo::Getrf::Unblocked>::invoke(A, ipiv);))
KOKKOS_IF_ON_DEVICE(
(return KokkosBatched::Impl::SerialGetrfInternalDevice<Algo::Getrf::Unblocked>::invoke(A, ipiv);))
}
};

} // namespace KokkosBatched

#endif // KOKKOSBATCHED_GETRF_SERIAL_IMPL_HPP_
315 changes: 315 additions & 0 deletions batched/dense/impl/KokkosBatched_Getrf_Serial_Internal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_GETRF_SERIAL_INTERNAL_HPP_
#define KOKKOSBATCHED_GETRF_SERIAL_INTERNAL_HPP_

#include <KokkosBatched_Util.hpp>
#include <KokkosBlas1_scal.hpp>
#include <KokkosBatched_Trsm_Decl.hpp>
#include <KokkosBatched_Gemm_Decl.hpp>
#include <KokkosBatched_Iamax.hpp>
#include <KokkosBatched_Laswp.hpp>

namespace KokkosBatched {
namespace Impl {

struct Stack {
private:
constexpr static int STACK_SIZE = 48;

// (state, m_start, n_start, piv_start, m_size, n_size, piv_size)
int m_stack[7][STACK_SIZE];
int m_top;

public:
KOKKOS_FUNCTION
Stack() : m_top(-1) {} // Initialize top to -1, indicating the stack is empty

KOKKOS_INLINE_FUNCTION
void push(int values[]) {
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
if (m_top >= STACK_SIZE - 1) {
Kokkos::printf("Stack overflow: Cannot push, the stack is full.\n");
return;
}
#endif
++m_top;
for (int i = 0; i < 7; i++) {
// Increment top and add value
m_stack[i][m_top] = values[i];
}
}

KOKKOS_INLINE_FUNCTION
void pop(int values[]) {
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
if (m_top < 0) {
// Check if the stack is empty
Kokkos::printf("Stack underflow: Cannot pop, the stack is empty.");
return;
}
#endif
for (int i = 0; i < 7; i++) {
// Return the top value and decrement top
values[i] = m_stack[i][m_top];
}
m_top--;
}

KOKKOS_INLINE_FUNCTION
bool isEmpty() const { return m_top == -1; }
};

// Host only implementation with recursive algorithm
template <typename AlgoType>
struct SerialGetrfInternalHost {
template <typename AViewType, typename PivViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &ipiv);
};

template <>
template <typename AViewType, typename PivViewType>
KOKKOS_INLINE_FUNCTION int SerialGetrfInternalHost<Algo::Getrf::Unblocked>::invoke(const AViewType &A,
const PivViewType &ipiv) {
using ScalarType = typename AViewType::non_const_value_type;

const int m = A.extent(0), n = A.extent(1);

// Quick return if possible
if (m <= 0 || n <= 0) return 0;

int info = 0;

// Use unblocked code for one row case
// Just need to handle ipiv and info
if (m == 1) {
ipiv(0) = 0;
if (A(0, 0) == 0) return 1;

return 0;
} else if (n == 1) {
// Use unblocked code for one column case
// Compute machine safe minimum
auto col_A = Kokkos::subview(A, Kokkos::ALL, 0);

int i = SerialIamax::invoke(col_A);
ipiv(0) = i;

if (A(i, 0) == 0) return 1;

// Apply the interchange
if (i != 0) {
Kokkos::kokkos_swap(A(i, 0), A(0, 0));
}

// Compute elements
const ScalarType alpha = 1.0 / A(0, 0);
auto sub_col_A = Kokkos::subview(A, Kokkos::pair<int, int>(1, m), 0);
[[maybe_unused]] auto info_scal = KokkosBlas::SerialScale::invoke(alpha, sub_col_A);

return 0;
} else {
// Use recursive code
auto n1 = Kokkos::min(m, n) / 2;

// Factor A0 = [[A00],
// [A10]]

// split A into two submatrices A = [A0, A1]
auto A0 = Kokkos::subview(A, Kokkos::ALL, Kokkos::pair<int, int>(0, n1));
auto A1 = Kokkos::subview(A, Kokkos::ALL, Kokkos::pair<int, int>(n1, n));
auto ipiv0 = Kokkos::subview(ipiv, Kokkos::pair<int, int>(0, n1));
auto iinfo = invoke(A0, ipiv0);

if (info == 0 && iinfo > 0) info = iinfo;

// Apply interchanges to A1 = [[A01],
// [A11]]

[[maybe_unused]] auto info_laswp = KokkosBatched::SerialLaswp<Direct::Forward>::invoke(ipiv0, A1);

// split A into four submatrices
// A = [[A00, A01],
// [A10, A11]]
auto A00 = Kokkos::subview(A, Kokkos::pair<int, int>(0, n1), Kokkos::pair<int, int>(0, n1));
auto A01 = Kokkos::subview(A, Kokkos::pair<int, int>(0, n1), Kokkos::pair<int, int>(n1, n));
auto A10 = Kokkos::subview(A, Kokkos::pair<int, int>(n1, m), Kokkos::pair<int, int>(0, n1));
auto A11 = Kokkos::subview(A, Kokkos::pair<int, int>(n1, m), Kokkos::pair<int, int>(n1, n));

// Solve A00 * X = A01
[[maybe_unused]] auto info_trsm = KokkosBatched::SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit,
Algo::Trsm::Unblocked>::invoke(1.0, A00, A01);

// Update A11 = A11 - A10 * A01
[[maybe_unused]] auto info_gemm =
KokkosBatched::SerialGemm<Trans::NoTranspose, Trans::NoTranspose, Algo::Gemm::Unblocked>::invoke(-1.0, A10, A01,
1.0, A11);

// Factor A11
auto ipiv1 = Kokkos::subview(ipiv, Kokkos::pair<int, int>(n1, Kokkos::min(m, n)));
iinfo = invoke(A11, ipiv1);

if (info == 0 && iinfo > 0) info = iinfo + n1;

// Apply interchanges to A10
info_laswp = KokkosBatched::SerialLaswp<Direct::Forward>::invoke(ipiv1, A10);

// Pivot indices
for (int i = n1; i < Kokkos::min(m, n); i++) {
ipiv(i) += n1;
}

return info;
}
}

// Device only implementation with recursive algorithm
template <typename AlgoType>
struct SerialGetrfInternalDevice {
template <typename AViewType, typename PivViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &ipiv);
};

template <>
template <typename AViewType, typename PivViewType>
KOKKOS_INLINE_FUNCTION int SerialGetrfInternalDevice<Algo::Getrf::Unblocked>::invoke(const AViewType &A,
const PivViewType &ipiv) {
using ScalarType = typename AViewType::non_const_value_type;

const int m = A.extent(0), n = A.extent(1), init_piv_size = ipiv.extent(0);

Stack stack;
int initial[7] = {0, 0, 0, 0, m, n, init_piv_size};
stack.push(initial);

// Quick return if possible
if (m <= 0 || n <= 0) return 0;

while (!stack.isEmpty()) {
// Firstly, make a subview based on the current state
int current[7];
stack.pop(current);

int state = current[0], m_start = current[1], n_start = current[2], piv_start = current[3], m_size = current[4],
n_size = current[5], piv_size = current[6];

// Quick return if possible
if (m_size <= 0 || n_size <= 0) continue;

auto A_current = Kokkos::subview(A, Kokkos::pair<int, int>(m_start, m_start + m_size),
Kokkos::pair<int, int>(n_start, n_start + n_size));

auto ipiv_current = Kokkos::subview(ipiv, Kokkos::pair<int, int>(piv_start, piv_start + piv_size));
auto n1 = Kokkos::min(m_size, n_size) / 2;

// split A into two submatrices A = [A0, A1]
auto A0 = Kokkos::subview(A_current, Kokkos::ALL, Kokkos::pair<int, int>(0, n1));
auto A1 = Kokkos::subview(A_current, Kokkos::ALL, Kokkos::pair<int, int>(n1, n_size));
auto ipiv0 = Kokkos::subview(ipiv_current, Kokkos::pair<int, int>(0, n1));
auto ipiv1 = Kokkos::subview(ipiv_current, Kokkos::pair<int, int>(n1, Kokkos::min(m_size, n_size)));

// split A into four submatrices
// A = [[A00, A01],
// [A10, A11]]
auto A00 = Kokkos::subview(A_current, Kokkos::pair<int, int>(0, n1), Kokkos::pair<int, int>(0, n1));
auto A01 = Kokkos::subview(A_current, Kokkos::pair<int, int>(0, n1), Kokkos::pair<int, int>(n1, n_size));
auto A10 = Kokkos::subview(A_current, Kokkos::pair<int, int>(n1, m_size), Kokkos::pair<int, int>(0, n1));
auto A11 = Kokkos::subview(A_current, Kokkos::pair<int, int>(n1, m_size), Kokkos::pair<int, int>(n1, n_size));

if (state == 0) {
// start state
if (m_size == 1) {
ipiv_current(0) = 0;
if (A_current(0, 0) == 0) return 1;
continue;
} else if (n_size == 1) {
// Use unblocked code for one column case
// Compute machine safe minimum
auto col_A = Kokkos::subview(A_current, Kokkos::ALL, 0);

int i = SerialIamax::invoke(col_A);
ipiv_current(0) = i;

if (A_current(i, 0) == 0) return 1;

// Apply the interchange
if (i != 0) {
Kokkos::kokkos_swap(A_current(i, 0), A_current(0, 0));
}

// Compute elements
const ScalarType alpha = 1.0 / A_current(0, 0);
auto sub_col_A = Kokkos::subview(A_current, Kokkos::pair<int, int>(1, m_size), 0);
[[maybe_unused]] auto info_scal = KokkosBlas::SerialScale::invoke(alpha, sub_col_A);
continue;
}

// Push states onto the stack in reverse order of how they are executed
// in the recursive version
int after_second[7] = {2, m_start, n_start, piv_start, m_size, n_size, piv_size};
int second[7] = {0,
m_start + n1,
n_start + n1,
piv_start + n1,
m_size - n1,
n_size - n1,
static_cast<int>(Kokkos::min(m_size, n_size)) - n1};
int after_first[7] = {1, m_start, n_start, piv_start, m_size, n_size, piv_size};
int first[7] = {0, m_start, n_start, piv_start, m_size, n1, n1};

stack.push(after_second);
stack.push(second);
stack.push(after_first);
stack.push(first);

} else if (state == 1) {
// after first recursive call
// Factor A0 = [[A00],
// [A10]]

// Apply interchanges to A1 = [[A01],
// [A11]]
KokkosBatched::SerialLaswp<Direct::Forward>::invoke(ipiv0, A1);

// Solve A00 * X = A01
[[maybe_unused]] auto info_trsm =
KokkosBatched::SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit,
Algo::Trsm::Unblocked>::invoke(1.0, A00, A01);

// Update A11 = A11 - A10 * A01
[[maybe_unused]] auto info_gemm =
KokkosBatched::SerialGemm<Trans::NoTranspose, Trans::NoTranspose, Algo::Gemm::Unblocked>::invoke(
-1.0, A10, A01, 1.0, A11);

} else if (state == 2) {
// after second recursive call
// Apply interchanges to A10
KokkosBatched::SerialLaswp<Direct::Forward>::invoke(ipiv1, A10);

// Pivot indices
for (int i = n1; i < Kokkos::min(m_size, n_size); i++) {
ipiv_current(i) += n1;
}
}
}
return 0;
}

} // namespace Impl
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_GETRF_SERIAL_INTERNAL_HPP_
Loading

0 comments on commit 226d4ac

Please sign in to comment.