Skip to content

Commit

Permalink
implement batched serial getrs (#2483)
Browse files Browse the repository at this point in the history
* implement batched serial getrs

Signed-off-by: Yuuichi Asahi <[email protected]>

* unuse getrf in the getrs analytical test

Signed-off-by: Yuuichi Asahi <[email protected]>

---------

Signed-off-by: Yuuichi Asahi <[email protected]>
Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Jan 20, 2025
1 parent af0755b commit 834f202
Show file tree
Hide file tree
Showing 6 changed files with 569 additions and 0 deletions.
76 changes: 76 additions & 0 deletions batched/dense/impl/KokkosBatched_Getrs_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_GETRS_SERIAL_IMPL_HPP_
#define KOKKOSBATCHED_GETRS_SERIAL_IMPL_HPP_

#include <KokkosBatched_Util.hpp>
#include "KokkosBatched_Getrs_Serial_Internal.hpp"

namespace KokkosBatched {
namespace Impl {
/// \brief Validate the arguments handed to the serial getrs kernel.
///
/// The view-ness and rank requirements (A rank 2, b rank 1) are enforced at
/// compile time. The extent checks are compiled in only when
/// KOKKOSKERNELS_DEBUG_LEVEL > 0; otherwise the function always returns 0.
///
/// \param A [in]: matrix view; extent(0) is treated as lda, extent(1) as n
/// \param b [in]: right-hand side view; extent(0) is treated as ldb
/// \return 0 if the inputs are accepted, 1 (after printing a diagnostic) if
///         lda or ldb is smaller than max(1, n)
template <typename AViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int checkGetrsInput([[maybe_unused]] const AViewType &A,
                                                  [[maybe_unused]] const BViewType &b) {
  static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::getrs: AViewType is not a Kokkos::View.");
  static_assert(Kokkos::is_view_v<BViewType>, "KokkosBatched::getrs: BViewType is not a Kokkos::View.");
  static_assert(AViewType::rank == 2, "KokkosBatched::getrs: AViewType must have rank 2.");
  static_assert(BViewType::rank == 1, "KokkosBatched::getrs: BViewType must have rank 1.");
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
  const int lda = A.extent(0);
  const int n   = A.extent(1);
  const int ldb = b.extent(0);

  // Both leading dimensions share the same lower bound.
  const int min_ld = Kokkos::max(1, n);

  if (lda < min_ld) {
    Kokkos::printf(
        "KokkosBatched::getrs: the leading dimension of the array A must "
        "satisfy lda >= max(1, n): A: %d x %d \n",
        lda, n);
    return 1;
  }

  if (ldb < min_ld) {
    Kokkos::printf(
        "KokkosBatched::getrs: the leading dimension of the array b must "
        "satisfy ldb >= max(1, n): b: %d, A: %d x %d \n",
        ldb, lda, n);
    return 1;
  }
#endif
  return 0;
}

} // namespace Impl

/// \brief Unblocked serial getrs: validates the inputs and forwards to the
///        internal implementation selected by ArgTrans.
template <typename ArgTrans>
struct SerialGetrs<ArgTrans, Algo::Getrs::Unblocked> {
  /// \param A   [in]: rank-2 view passed through to the internal solver
  /// \param piv [in]: pivot view passed through to the internal solver
  /// \param b   [inout]: rank-1 view passed through to the internal solver
  /// \return 0 when A has no columns or on success; otherwise the nonzero
  ///         code reported by Impl::checkGetrsInput
  template <typename AViewType, typename PivViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b) {
    using InternalSolver = Impl::SerialGetrsInternal<ArgTrans, Algo::Getrs::Unblocked>;

    // Nothing to solve for an empty system.
    if (A.extent(1) == 0) return 0;

    // Bail out with the checker's error code if the inputs are rejected.
    if (const int info = Impl::checkGetrsInput(A, b); info != 0) return info;

    return InternalSolver::invoke(A, piv, b);
  }
};

} // namespace KokkosBatched

#endif // KOKKOSBATCHED_GETRS_SERIAL_IMPL_HPP_
79 changes: 79 additions & 0 deletions batched/dense/impl/KokkosBatched_Getrs_Serial_Internal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_GETRS_SERIAL_INTERNAL_HPP_
#define KOKKOSBATCHED_GETRS_SERIAL_INTERNAL_HPP_

#include <KokkosBatched_Util.hpp>

namespace KokkosBatched {
namespace Impl {

/// \brief Primary template of the internal serial getrs kernel.
///
/// Only the declaration lives here; the behavior is supplied by the explicit
/// specializations on (ArgTrans, ArgAlgo) that follow in this header.
///
/// \tparam ArgTrans: transpose mode tag (Trans::NoTranspose, Trans::Transpose
///                   or Trans::ConjTranspose are specialized below)
/// \tparam ArgAlgo: algorithm tag (Algo::Getrs::Unblocked is specialized below)
template <typename ArgTrans, typename ArgAlgo>
struct SerialGetrsInternal {
/// \param A   [in]: matrix view handed to the triangular solves
/// \param piv [in]: pivot view handed to the row-interchange kernel
/// \param b   [inout]: right-hand side view, updated in place by the kernels
/// \return integer status; the specializations in this header return 0
template <typename AViewType, typename PivViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b);
};

//// Non-transpose ////
/// \brief No-transpose solve: SerialLaswp (forward) on b, then a unit-lower
///        SerialTrsm, then a non-unit-upper SerialTrsm, both using A.
template <>
struct SerialGetrsInternal<Trans::NoTranspose, Algo::Getrs::Unblocked> {
  /// \param A   [in]: matrix view whose lower/upper triangles feed the solves
  /// \param piv [in]: pivot view handed to SerialLaswp
  /// \param b   [inout]: right-hand side view passed to every kernel
  /// \return always 0; the status values of the called kernels are not inspected
  template <typename AViewType, typename PivViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b) {
    using RowInterchange = KokkosBatched::SerialLaswp<Direct::Forward>;
    using UnitLowerSolve =
        KokkosBatched::SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit, Algo::Trsm::Unblocked>;
    using UpperSolve =
        KokkosBatched::SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit, Algo::Trsm::Unblocked>;

    // Order matters: interchange rows of b first, then lower solve, then upper solve.
    RowInterchange::invoke(piv, b);
    UnitLowerSolve::invoke(1.0, A, b);
    UpperSolve::invoke(1.0, A, b);

    return 0;
  }
};

//// Transpose ////
/// \brief Transpose solve: a non-unit-upper transposed SerialTrsm, then a
///        unit-lower transposed SerialTrsm, then SerialLaswp (backward) on b.
template <>
struct SerialGetrsInternal<Trans::Transpose, Algo::Getrs::Unblocked> {
  /// \param A   [in]: matrix view whose upper/lower triangles feed the solves
  /// \param piv [in]: pivot view handed to SerialLaswp
  /// \param b   [inout]: right-hand side view passed to every kernel
  /// \return always 0; the status values of the called kernels are not inspected
  template <typename AViewType, typename PivViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b) {
    using UpperSolve =
        KokkosBatched::SerialTrsm<Side::Left, Uplo::Upper, Trans::Transpose, Diag::NonUnit, Algo::Trsm::Unblocked>;
    using UnitLowerSolve =
        KokkosBatched::SerialTrsm<Side::Left, Uplo::Lower, Trans::Transpose, Diag::Unit, Algo::Trsm::Unblocked>;
    using RowInterchange = KokkosBatched::SerialLaswp<Direct::Backward>;

    // Order matters: upper solve, then lower solve, then interchange rows of b.
    UpperSolve::invoke(1.0, A, b);
    UnitLowerSolve::invoke(1.0, A, b);
    RowInterchange::invoke(piv, b);

    return 0;
  }
};

//// Conj-Transpose ////
/// \brief Conjugate-transpose solve: a non-unit-upper conjugate-transposed
///        SerialTrsm, then a unit-lower conjugate-transposed SerialTrsm, then
///        SerialLaswp (backward) on b.
template <>
struct SerialGetrsInternal<Trans::ConjTranspose, Algo::Getrs::Unblocked> {
  /// \param A   [in]: matrix view whose upper/lower triangles feed the solves
  /// \param piv [in]: pivot view handed to SerialLaswp
  /// \param b   [inout]: right-hand side view passed to every kernel
  /// \return always 0; the status values of the called kernels are not inspected
  template <typename AViewType, typename PivViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b) {
    using UpperSolve =
        KokkosBatched::SerialTrsm<Side::Left, Uplo::Upper, Trans::ConjTranspose, Diag::NonUnit, Algo::Trsm::Unblocked>;
    using UnitLowerSolve =
        KokkosBatched::SerialTrsm<Side::Left, Uplo::Lower, Trans::ConjTranspose, Diag::Unit, Algo::Trsm::Unblocked>;
    using RowInterchange = KokkosBatched::SerialLaswp<Direct::Backward>;

    // Order matters: upper solve, then lower solve, then interchange rows of b.
    UpperSolve::invoke(1.0, A, b);
    UnitLowerSolve::invoke(1.0, A, b);
    RowInterchange::invoke(piv, b);

    return 0;
  }
};

} // namespace Impl
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_GETRS_SERIAL_INTERNAL_HPP_
51 changes: 51 additions & 0 deletions batched/dense/src/KokkosBatched_Getrs.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_GETRS_HPP_
#define KOKKOSBATCHED_GETRS_HPP_

#include <KokkosBatched_Util.hpp>

/// \author Yuuichi Asahi ([email protected])

namespace KokkosBatched {

/// \brief Serial Batched Getrs:
/// Solve a system of linear equations
/// A * x = b, A**T * x = b or A**H * x = b
/// with a general N-by-N matrix A using the LU factorization computed
/// by Getrf.
/// \tparam ArgTrans: selects which system is solved (Trans::NoTranspose,
/// Trans::Transpose or Trans::ConjTranspose)
/// \tparam ArgAlgo: algorithm tag; Algo::Getrs::Unblocked is the variant
/// implemented in KokkosBatched_Getrs_Serial_Impl.hpp
/// \tparam AViewType: Input type for the matrix, needs to be a 2D view
/// \tparam PivViewType: Input type for the pivot indices, needs to be a 1D view
/// \tparam BViewType: Input type for the right-hand side and the solution,
/// needs to be a 1D view
///
/// \param A [in]: the LU factors of the matrix (unit lower triangle L and
/// upper triangle U stored together), a rank 2 view
/// \param piv [in]: the pivot indices that accompany the factorization,
/// a rank 1 view
/// \param b [inout]: right-hand side on entry and the solution on exit,
/// a rank 1 view
///
/// \return 0 on success; the Unblocked implementation returns 1 when its
/// extent checks (active for KOKKOSKERNELS_DEBUG_LEVEL > 0) reject the input
///
/// No nested parallel_for is used inside of the function.
///

template <typename ArgTrans, typename ArgAlgo>
struct SerialGetrs {
template <typename AViewType, typename PivViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b);
};
} // namespace KokkosBatched

#include "KokkosBatched_Getrs_Serial_Impl.hpp"

#endif // KOKKOSBATCHED_GETRS_HPP_
1 change: 1 addition & 0 deletions batched/dense/unit_test/Test_Batched_Dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
#include "Test_Batched_SerialLaswp.hpp"
#include "Test_Batched_SerialIamax.hpp"
#include "Test_Batched_SerialGetrf.hpp"
#include "Test_Batched_SerialGetrs.hpp"

// Team Kernels
#include "Test_Batched_TeamAxpy.hpp"
Expand Down
Loading

0 comments on commit 834f202

Please sign in to comment.