From f99de7db56ba4c42c113bd76dc732c0e63214512 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 14 Jan 2025 14:17:40 +0900 Subject: [PATCH 1/2] implement batched serial getrs Signed-off-by: Yuuichi Asahi --- .../impl/KokkosBatched_Getrs_Serial_Impl.hpp | 76 ++++ .../KokkosBatched_Getrs_Serial_Internal.hpp | 79 ++++ batched/dense/src/KokkosBatched_Getrs.hpp | 51 +++ .../dense/unit_test/Test_Batched_Dense.hpp | 1 + .../unit_test/Test_Batched_SerialGetrs.hpp | 359 ++++++++++++++++++ blas/impl/KokkosBlas_util.hpp | 1 + 6 files changed, 567 insertions(+) create mode 100644 batched/dense/impl/KokkosBatched_Getrs_Serial_Impl.hpp create mode 100644 batched/dense/impl/KokkosBatched_Getrs_Serial_Internal.hpp create mode 100644 batched/dense/src/KokkosBatched_Getrs.hpp create mode 100644 batched/dense/unit_test/Test_Batched_SerialGetrs.hpp diff --git a/batched/dense/impl/KokkosBatched_Getrs_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Getrs_Serial_Impl.hpp new file mode 100644 index 0000000000..c8efad10e8 --- /dev/null +++ b/batched/dense/impl/KokkosBatched_Getrs_Serial_Impl.hpp @@ -0,0 +1,76 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#ifndef KOKKOSBATCHED_GETRS_SERIAL_IMPL_HPP_ +#define KOKKOSBATCHED_GETRS_SERIAL_IMPL_HPP_ + +#include +#include "KokkosBatched_Getrs_Serial_Internal.hpp" + +namespace KokkosBatched { +namespace Impl { +template +KOKKOS_INLINE_FUNCTION static int checkGetrsInput([[maybe_unused]] const AViewType &A, + [[maybe_unused]] const BViewType &b) { + static_assert(Kokkos::is_view_v, "KokkosBatched::getrs: AViewType is not a Kokkos::View."); + static_assert(Kokkos::is_view_v, "KokkosBatched::getrs: BViewType is not a Kokkos::View."); + static_assert(AViewType::rank == 2, "KokkosBatched::getrs: AViewType must have rank 2."); + static_assert(BViewType::rank == 1, "KokkosBatched::getrs: BViewType must have rank 1."); +#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) + const int lda = A.extent(0), n = A.extent(1); + if (lda < Kokkos::max(1, n)) { + Kokkos::printf( + "KokkosBatched::getrs: the leading dimension of the array A must " + "satisfy lda >= max(1, n): A: " + "%d " + "x %d \n", + lda, n); + return 1; + } + + const int ldb = b.extent(0); + if (ldb < Kokkos::max(1, n)) { + Kokkos::printf( + "KokkosBatched::getrs: the leading dimension of the array b must " + "satisfy ldb >= max(1, n): b: %d, A: " + "%d " + "x %d \n", + ldb, lda, n); + return 1; + } +#endif + return 0; +} + +} // namespace Impl + +template +struct SerialGetrs { + template + KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b) { + // quick return if possible + if (A.extent(1) == 0) return 0; + + auto info = Impl::checkGetrsInput(A, b); + if (info) return info; + + return Impl::SerialGetrsInternal::invoke(A, piv, b); + } +}; + +} // namespace KokkosBatched + +#endif // KOKKOSBATCHED_GETRF_SERIAL_IMPL_HPP_ diff --git a/batched/dense/impl/KokkosBatched_Getrs_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Getrs_Serial_Internal.hpp new file mode 100644 index 0000000000..bc6a981fc9 --- /dev/null +++ b/batched/dense/impl/KokkosBatched_Getrs_Serial_Internal.hpp @@ -0,0 +1,79 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#ifndef KOKKOSBATCHED_GETRS_SERIAL_INTERNAL_HPP_ +#define KOKKOSBATCHED_GETRS_SERIAL_INTERNAL_HPP_ + +#include + +namespace KokkosBatched { +namespace Impl { + +template +struct SerialGetrsInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b); +}; + +//// Non-transpose //// +template <> +struct SerialGetrsInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b) { + KokkosBatched::SerialLaswp::invoke(piv, b); + KokkosBatched::SerialTrsm::invoke( + 1.0, A, b); + KokkosBatched::SerialTrsm::invoke(1.0, A, b); + + return 0; + } +}; + +//// Transpose //// +template <> +struct SerialGetrsInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b) { + KokkosBatched::SerialTrsm::invoke( + 1.0, A, b); + KokkosBatched::SerialTrsm::invoke( + 1.0, A, b); + KokkosBatched::SerialLaswp::invoke(piv, b); + + return 0; + } +}; + +//// Conj-Transpose //// +template <> +struct SerialGetrsInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b) { + KokkosBatched::SerialTrsm::invoke(1.0, A, b); + KokkosBatched::SerialTrsm::invoke( + 1.0, A, b); + KokkosBatched::SerialLaswp::invoke(piv, b); + + return 0; + } +}; + +} // namespace Impl +} // namespace KokkosBatched + +#endif // KOKKOSBATCHED_GETRS_SERIAL_INTERNAL_HPP_ diff --git a/batched/dense/src/KokkosBatched_Getrs.hpp b/batched/dense/src/KokkosBatched_Getrs.hpp new file mode 100644 index 0000000000..3495347f48 --- /dev/null +++ b/batched/dense/src/KokkosBatched_Getrs.hpp @@ -0,0 +1,51 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +#ifndef KOKKOSBATCHED_GETRS_HPP_ +#define KOKKOSBATCHED_GETRS_HPP_ + +#include + +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) + +namespace KokkosBatched { + +/// \brief Serial Batched Getrs: +/// Solve a system of linear equations +/// A * x = b or A**T * x = b +/// with a general N-by-N matrix A using LU factorization computed +/// by Getrf. +/// \tparam AViewType: Input type for the matrix, needs to be a 2D view +/// \tparam PivViewType: Input type for the pivot indices, needs to be a 1D view +/// \tparam BViewType: Input type for the right-hand side and the solution, +/// needs to be a 1D view +/// +/// \param A [in]: A is a m by n general matrix, a rank 2 view +/// \param piv [in]: On exit, the pivot indices, a rank 1 view +/// \param B [inout]: right-hand side and the solution, a rank 1 view +/// +/// No nested parallel_for is used inside of the function. +/// + +template +struct SerialGetrs { + template + KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b); +}; +} // namespace KokkosBatched + +#include "KokkosBatched_Getrs_Serial_Impl.hpp" + +#endif // KOKKOSBATCHED_GETRF_HPP_ diff --git a/batched/dense/unit_test/Test_Batched_Dense.hpp b/batched/dense/unit_test/Test_Batched_Dense.hpp index 37673e1a5e..551f200101 100644 --- a/batched/dense/unit_test/Test_Batched_Dense.hpp +++ b/batched/dense/unit_test/Test_Batched_Dense.hpp @@ -64,6 +64,7 @@ #include "Test_Batched_SerialLaswp.hpp" #include "Test_Batched_SerialIamax.hpp" #include "Test_Batched_SerialGetrf.hpp" +#include "Test_Batched_SerialGetrs.hpp" // Team Kernels #include "Test_Batched_TeamAxpy.hpp" diff --git a/batched/dense/unit_test/Test_Batched_SerialGetrs.hpp b/batched/dense/unit_test/Test_Batched_SerialGetrs.hpp new file mode 100644 index 0000000000..da499139b5 --- /dev/null +++ b/batched/dense/unit_test/Test_Batched_SerialGetrs.hpp @@ -0,0 +1,359 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) +#include +#include +#include +#include +#include +#include +#include +#include "Test_Batched_DenseUtils.hpp" + +using namespace KokkosBatched; + +namespace Test { +namespace Getrs { + +template +struct ParamTag { + using trans = T; +}; + +template +struct Functor_BatchedSerialGetrf { + using execution_space = typename DeviceType::execution_space; + AViewType m_a; + PivViewType m_ipiv; + + KOKKOS_INLINE_FUNCTION + Functor_BatchedSerialGetrf(const AViewType &a, const PivViewType &ipiv) : m_a(a), m_ipiv(ipiv) {} + + KOKKOS_INLINE_FUNCTION + void operator()(const int k) const { + auto aa = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto ipiv = Kokkos::subview(m_ipiv, k, Kokkos::ALL()); + + KokkosBatched::SerialGetrf::invoke(aa, ipiv); + } + + inline void run() { + using value_type = typename AViewType::non_const_value_type; + std::string name_region("KokkosBatched::Test::SerialGetrs"); + const std::string name_value_type = Test::value_type_name(); + std::string name = name_region + name_value_type; + Kokkos::RangePolicy policy(0, m_a.extent(0)); + Kokkos::parallel_for(name.c_str(), policy, *this); + } +}; + +template +struct Functor_BatchedSerialGetrs { + using execution_space = typename DeviceType::execution_space; + AViewType m_a; + BViewType m_b; + PivViewType m_ipiv; + + KOKKOS_INLINE_FUNCTION + Functor_BatchedSerialGetrs(const AViewType &a, const PivViewType &ipiv, const BViewType &b) + : m_a(a), m_b(b), m_ipiv(ipiv) {} + + KOKKOS_INLINE_FUNCTION + void operator()(const ParamTagType &, const int k, int &info) const { + auto aa = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto ipiv = Kokkos::subview(m_ipiv, k, Kokkos::ALL()); + auto bb = Kokkos::subview(m_b, k, Kokkos::ALL()); + + info += KokkosBatched::SerialGetrs::invoke(aa, ipiv, bb); + } + + inline int run() { + using value_type = typename AViewType::non_const_value_type; + std::string name_region("KokkosBatched::Test::SerialGetrs"); + const std::string name_value_type = Test::value_type_name(); + std::string name = name_region + name_value_type; + int info_sum = 0; + Kokkos::Profiling::pushRegion(name.c_str()); + Kokkos::RangePolicy policy(0, m_b.extent(0)); + Kokkos::parallel_reduce(name.c_str(), policy, *this, info_sum); + Kokkos::Profiling::popRegion(); + return info_sum; + } +}; + +template +struct Functor_BatchedSerialGemv { + using execution_space = typename DeviceType::execution_space; + AViewType m_a; + xViewType m_x; + yViewType m_y; + ScalarType m_alpha, m_beta; + + KOKKOS_INLINE_FUNCTION + Functor_BatchedSerialGemv(const ScalarType alpha, const AViewType &a, const xViewType &x, const ScalarType beta, + const yViewType &y) + : m_a(a), m_x(x), m_y(y), m_alpha(alpha), m_beta(beta) {} + + KOKKOS_INLINE_FUNCTION + void operator()(const ParamTagType &, const int k) const { + auto aa = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto xx = Kokkos::subview(m_x, k, Kokkos::ALL()); + auto yy = Kokkos::subview(m_y, k, Kokkos::ALL()); + + KokkosBlas::SerialGemv::invoke(m_alpha, aa, xx, m_beta, yy); + } + + inline void run() { + using value_type = typename AViewType::non_const_value_type; + std::string name_region("KokkosBatched::Test::SerialGetrs"); + const std::string name_value_type = Test::value_type_name(); + std::string name = name_region + name_value_type; + Kokkos::RangePolicy policy(0, m_x.extent(0)); + Kokkos::parallel_for(name.c_str(), policy, *this); + } +}; + +/// \brief Implementation details of batched getrs test +/// Confirm A * x = b, where +/// A: [[1, 1], +/// [1, -1]] +/// b: [2, 0] +/// x: [1, 1] +/// This corresponds to the following system of equations: +/// x0 + x1 = 2 +/// x0 - x1 = 0 +/// +/// \param N [in] Batch size of RHS (banded matrix can also be batched matrix) +/// \param k [in] Number of superdiagonals or subdiagonals of matrix A +/// \param BlkSize [in] Block size of matrix A +template +void impl_test_batched_getrs_analytical(const int N) { + using ats = typename Kokkos::ArithTraits; + using RealType = typename ats::mag_type; + using View2DType = Kokkos::View; + using View3DType = Kokkos::View; + using PivView2DType = Kokkos::View; + + const int BlkSize = 2; + View3DType A("A", N, BlkSize, BlkSize), ref("Ref", N, BlkSize, BlkSize); + View3DType lu("lu", N, BlkSize, BlkSize); // Factorized + View2DType x("x", N, BlkSize), y("y", N, BlkSize), x_ref("x_ref", N, BlkSize); // Solutions + PivView2DType ipiv("ipiv", N, BlkSize); + + auto h_A = Kokkos::create_mirror_view(A); + auto h_x = Kokkos::create_mirror_view(x); + auto h_x_ref = Kokkos::create_mirror_view(x_ref); + Kokkos::deep_copy(h_A, 1.0); + for (int ib = 0; ib < N; ib++) { + h_A(ib, 1, 1) = -1.0; + + h_x(ib, 0) = 2; + h_x(ib, 1) = 0; + h_x_ref(ib, 0) = 1; + h_x_ref(ib, 1) = 1; + } + + Kokkos::fence(); + + Kokkos::deep_copy(A, h_A); + Kokkos::deep_copy(x, h_x); + + // getrf to factorize matrix A = P * L * U + Functor_BatchedSerialGetrf(A, ipiv).run(); + + // getrs (Note, LU is a factorized matrix of A) + auto info = Functor_BatchedSerialGetrs( + A, ipiv, x) + .run(); + + Kokkos::fence(); + EXPECT_EQ(info, 0); + + // this eps is about 10^-14 + RealType eps = 1.0e3 * ats::epsilon(); + + // Check if x = [1, 1] + Kokkos::deep_copy(h_x, x); + for (int ib = 0; ib < N; ib++) { + for (int i = 0; i < BlkSize; i++) { + EXPECT_NEAR_KK(h_x(ib, i), h_x_ref(ib, i), eps); + } + } +} + +/// \brief Implementation details of batched getrs test +/// +/// \param N [in] Batch size of RHS (banded matrix can also be batched matrix) +/// \param k [in] Number of superdiagonals or subdiagonals of matrix A +/// \param BlkSize [in] Block size of matrix A +template +void impl_test_batched_getrs(const int N, const int BlkSize) { + using ats = typename Kokkos::ArithTraits; + using RealType = typename ats::mag_type; + using View2DType = Kokkos::View; + using View3DType = Kokkos::View; + using PivView2DType = Kokkos::View; + + View3DType A("A", N, BlkSize, BlkSize), ref("Ref", N, BlkSize, BlkSize); + View3DType LU("LU", N, BlkSize, BlkSize); // Factorized + View2DType x("x", N, BlkSize), y("y", N, BlkSize), b("b", N, BlkSize); // Solutions + PivView2DType ipiv("ipiv", N, BlkSize); + + using execution_space = typename DeviceType::execution_space; + Kokkos::Random_XorShift64_Pool rand_pool(13718); + ScalarType randStart, randEnd; + + // Initialize A_reconst with random matrix + KokkosKernels::Impl::getRandomBounds(1.0, randStart, randEnd); + Kokkos::fill_random(A, rand_pool, randStart, randEnd); + Kokkos::fill_random(x, rand_pool, randStart, randEnd); + Kokkos::deep_copy(LU, A); + Kokkos::deep_copy(b, x); + + // getrf to factorize matrix A = P * L * U + Functor_BatchedSerialGetrf(LU, ipiv).run(); + + // getrs (Note, LU is a factorized matrix of A) + auto info = Functor_BatchedSerialGetrs( + LU, ipiv, x) + .run(); + + Kokkos::fence(); + EXPECT_EQ(info, 0); + + // Gemv to compute A*x, this should be identical to b + Functor_BatchedSerialGemv(1.0, A, x, 0.0, y) + .run(); + + Kokkos::fence(); + + // this eps is about 10^-14 + RealType eps = 1.0e3 * ats::epsilon(); + + // Check if A * x = b + auto h_y = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), y); + auto h_b = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), b); + for (int ib = 0; ib < N; ib++) { + for (int i = 0; i < BlkSize; i++) { + EXPECT_NEAR_KK(h_y(ib, i), h_b(ib, i), eps); + } + } +} + +} // namespace Getrs +} // namespace Test + +template +int test_batched_getrs() { +#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) + { + using LayoutType = Kokkos::LayoutLeft; + Test::Getrs::impl_test_batched_getrs_analytical(1); + Test::Getrs::impl_test_batched_getrs_analytical(2); + for (int i = 0; i < 10; i++) { + Test::Getrs::impl_test_batched_getrs(1, i); + Test::Getrs::impl_test_batched_getrs(2, i); + } + } +#endif +#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) + { + using LayoutType = Kokkos::LayoutRight; + Test::Getrs::impl_test_batched_getrs_analytical(1); + Test::Getrs::impl_test_batched_getrs_analytical(2); + for (int i = 0; i < 10; i++) { + Test::Getrs::impl_test_batched_getrs(1, i); + Test::Getrs::impl_test_batched_getrs(2, i); + } + } +#endif + + return 0; +} + +#if defined(KOKKOSKERNELS_INST_FLOAT) +TEST_F(TestCategory, test_batched_getrs_nt_float) { + using param_tag_type = ::Test::Getrs::ParamTag; + using algo_tag_type = typename Algo::Getrs::Unblocked; + + test_batched_getrs(); +} +TEST_F(TestCategory, test_batched_getrs_t_float) { + using param_tag_type = ::Test::Getrs::ParamTag; + using algo_tag_type = typename Algo::Getrs::Unblocked; + + test_batched_getrs(); +} +#endif + +#if defined(KOKKOSKERNELS_INST_DOUBLE) +TEST_F(TestCategory, test_batched_getrs_nt_double) { + using param_tag_type = ::Test::Getrs::ParamTag; + using algo_tag_type = typename Algo::Getrs::Unblocked; + + test_batched_getrs(); +} +TEST_F(TestCategory, test_batched_getrs_t_double) { + using param_tag_type = ::Test::Getrs::ParamTag; + using algo_tag_type = typename Algo::Getrs::Unblocked; + + test_batched_getrs(); +} +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) +TEST_F(TestCategory, test_batched_getrs_nt_cfloat) { + using param_tag_type = ::Test::Getrs::ParamTag; + using algo_tag_type = typename Algo::Getrs::Unblocked; + + test_batched_getrs, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, test_batched_getrs_t_cfloat) { + using param_tag_type = ::Test::Getrs::ParamTag; + using algo_tag_type = typename Algo::Getrs::Unblocked; + + test_batched_getrs, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, test_batched_getrs_c_cfloat) { + using param_tag_type = ::Test::Getrs::ParamTag; + using algo_tag_type = typename Algo::Getrs::Unblocked; + + test_batched_getrs, param_tag_type, algo_tag_type>(); +} +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) +TEST_F(TestCategory, test_batched_getrs_nt_dfloat) { + using param_tag_type = ::Test::Getrs::ParamTag; + using algo_tag_type = typename Algo::Getrs::Unblocked; + + test_batched_getrs, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, test_batched_getrs_t_dfloat) { + using param_tag_type = ::Test::Getrs::ParamTag; + using algo_tag_type = typename Algo::Getrs::Unblocked; + + test_batched_getrs, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, test_batched_getrs_c_dfloat) { + using param_tag_type = ::Test::Getrs::ParamTag; + using algo_tag_type = typename Algo::Getrs::Unblocked; + + test_batched_getrs, param_tag_type, algo_tag_type>(); +} +#endif diff --git a/blas/impl/KokkosBlas_util.hpp b/blas/impl/KokkosBlas_util.hpp index 657e52cf77..00f5764a57 100644 --- a/blas/impl/KokkosBlas_util.hpp +++ b/blas/impl/KokkosBlas_util.hpp @@ -104,6 +104,7 @@ struct Algo { using Pttrf = Level3; using Pttrs = Level3; using Getrf = Level3; + using Getrs = Level3; struct Level2 { struct Unblocked {}; From 9a81eeec73c14c429850113a5547203ed86d6cce Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 21 Jan 2025 00:45:38 +0900 Subject: [PATCH 2/2] unuse getrf in the getrs analytical test Signed-off-by: Yuuichi Asahi --- .../unit_test/Test_Batched_SerialGetrs.hpp | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/batched/dense/unit_test/Test_Batched_SerialGetrs.hpp b/batched/dense/unit_test/Test_Batched_SerialGetrs.hpp index da499139b5..22f4ff58a2 100644 --- a/batched/dense/unit_test/Test_Batched_SerialGetrs.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialGetrs.hpp @@ -33,7 +33,7 @@ struct ParamTag { using trans = T; }; -template +template struct Functor_BatchedSerialGetrf { using execution_space = typename DeviceType::execution_space; AViewType m_a; @@ -47,7 +47,7 @@ struct Functor_BatchedSerialGetrf { auto aa = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL()); auto ipiv = Kokkos::subview(m_ipiv, k, Kokkos::ALL()); - KokkosBatched::SerialGetrf::invoke(aa, ipiv); + KokkosBatched::SerialGetrf::invoke(aa, ipiv); } inline void run() { @@ -137,6 +137,10 @@ struct Functor_BatchedSerialGemv { /// This corresponds to the following system of equations: /// x0 + x1 = 2 /// x0 - x1 = 0 +/// We confirm this with the factorized matrix LU and pivot given by +/// LU: [[1, 1], +/// [1, -2]] +/// piv: [0, 1] /// /// \param N [in] Batch size of RHS (banded matrix can also be batched matrix) /// \param k [in] Number of superdiagonals or subdiagonals of matrix A @@ -150,17 +154,19 @@ void impl_test_batched_getrs_analytical(const int N) { using PivView2DType = Kokkos::View; const int BlkSize = 2; - View3DType A("A", N, BlkSize, BlkSize), ref("Ref", N, BlkSize, BlkSize); - View3DType lu("lu", N, BlkSize, BlkSize); // Factorized - View2DType x("x", N, BlkSize), y("y", N, BlkSize), x_ref("x_ref", N, BlkSize); // Solutions + View3DType LU("LU", N, BlkSize, BlkSize); // Factorized matrix of A + View2DType x("x", N, BlkSize), x_ref("x_ref", N, BlkSize); // Solutions PivView2DType ipiv("ipiv", N, BlkSize); - auto h_A = Kokkos::create_mirror_view(A); + auto h_LU = Kokkos::create_mirror_view(LU); + auto h_ipiv = Kokkos::create_mirror_view(ipiv); auto h_x = Kokkos::create_mirror_view(x); auto h_x_ref = Kokkos::create_mirror_view(x_ref); - Kokkos::deep_copy(h_A, 1.0); + Kokkos::deep_copy(h_LU, 1.0); for (int ib = 0; ib < N; ib++) { - h_A(ib, 1, 1) = -1.0; + h_LU(ib, 1, 1) = -2.0; + h_ipiv(ib, 0) = 0; + h_ipiv(ib, 1) = 1; h_x(ib, 0) = 2; h_x(ib, 1) = 0; @@ -168,17 +174,13 @@ void impl_test_batched_getrs_analytical(const int N) { h_x_ref(ib, 1) = 1; } - Kokkos::fence(); - - Kokkos::deep_copy(A, h_A); + Kokkos::deep_copy(LU, h_LU); + Kokkos::deep_copy(ipiv, h_ipiv); Kokkos::deep_copy(x, h_x); - // getrf to factorize matrix A = P * L * U - Functor_BatchedSerialGetrf(A, ipiv).run(); - // getrs (Note, LU is a factorized matrix of A) auto info = Functor_BatchedSerialGetrs( - A, ipiv, x) + LU, ipiv, x) .run(); Kokkos::fence(); @@ -226,7 +228,7 @@ void impl_test_batched_getrs(const int N, const int BlkSize) { Kokkos::deep_copy(b, x); // getrf to factorize matrix A = P * L * U - Functor_BatchedSerialGetrf(LU, ipiv).run(); + Functor_BatchedSerialGetrf(LU, ipiv).run(); // getrs (Note, LU is a factorized matrix of A) auto info = Functor_BatchedSerialGetrs( @@ -317,19 +319,19 @@ TEST_F(TestCategory, test_batched_getrs_t_double) { #endif #if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) -TEST_F(TestCategory, test_batched_getrs_nt_cfloat) { +TEST_F(TestCategory, test_batched_getrs_nt_fcomplex) { using param_tag_type = ::Test::Getrs::ParamTag; using algo_tag_type = typename Algo::Getrs::Unblocked; test_batched_getrs, param_tag_type, algo_tag_type>(); } -TEST_F(TestCategory, test_batched_getrs_t_cfloat) { +TEST_F(TestCategory, test_batched_getrs_t_fcomplex) { using param_tag_type = ::Test::Getrs::ParamTag; using algo_tag_type = typename Algo::Getrs::Unblocked; test_batched_getrs, param_tag_type, algo_tag_type>(); } -TEST_F(TestCategory, test_batched_getrs_c_cfloat) { +TEST_F(TestCategory, test_batched_getrs_c_fcomplex) { using param_tag_type = ::Test::Getrs::ParamTag; using algo_tag_type = typename Algo::Getrs::Unblocked; @@ -338,19 +340,19 @@ TEST_F(TestCategory, test_batched_getrs_c_cfloat) { #endif #if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) -TEST_F(TestCategory, test_batched_getrs_nt_dfloat) { +TEST_F(TestCategory, test_batched_getrs_nt_dcomplex) { using param_tag_type = ::Test::Getrs::ParamTag; using algo_tag_type = typename Algo::Getrs::Unblocked; test_batched_getrs, param_tag_type, algo_tag_type>(); } -TEST_F(TestCategory, test_batched_getrs_t_dfloat) { +TEST_F(TestCategory, test_batched_getrs_t_dcomplex) { using param_tag_type = ::Test::Getrs::ParamTag; using algo_tag_type = typename Algo::Getrs::Unblocked; test_batched_getrs, param_tag_type, algo_tag_type>(); } -TEST_F(TestCategory, test_batched_getrs_c_dfloat) { +TEST_F(TestCategory, test_batched_getrs_c_dcomplex) { using param_tag_type = ::Test::Getrs::ParamTag; using algo_tag_type = typename Algo::Getrs::Unblocked;