Skip to content

Commit

Permalink
unuse getrf in the getrs analytical test
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
Yuuichi Asahi committed Jan 20, 2025
1 parent f99de7d commit 9a81eee
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions batched/dense/unit_test/Test_Batched_SerialGetrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ struct ParamTag {
using trans = T;
};

template <typename DeviceType, typename AViewType, typename PivViewType, typename AlgoTagType>
template <typename DeviceType, typename AViewType, typename PivViewType>
struct Functor_BatchedSerialGetrf {
using execution_space = typename DeviceType::execution_space;
AViewType m_a;
Expand All @@ -47,7 +47,7 @@ struct Functor_BatchedSerialGetrf {
auto aa = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL());
auto ipiv = Kokkos::subview(m_ipiv, k, Kokkos::ALL());

KokkosBatched::SerialGetrf<AlgoTagType>::invoke(aa, ipiv);
KokkosBatched::SerialGetrf<Algo::Getrf::Unblocked>::invoke(aa, ipiv);
}

inline void run() {
Expand Down Expand Up @@ -137,6 +137,10 @@ struct Functor_BatchedSerialGemv {
/// This corresponds to the following system of equations:
/// x0 + x1 = 2
/// x0 - x1 = 0
/// We confirm this with the factorized matrix LU and pivot given by
/// LU: [[1, 1],
/// [1, -2]]
/// piv: [0, 1]
///
/// \param N [in] Batch size of RHS (banded matrix can also be batched matrix)
/// \param k [in] Number of superdiagonals or subdiagonals of matrix A
Expand All @@ -150,35 +154,33 @@ void impl_test_batched_getrs_analytical(const int N) {
using PivView2DType = Kokkos::View<int **, LayoutType, DeviceType>;

const int BlkSize = 2;
View3DType A("A", N, BlkSize, BlkSize), ref("Ref", N, BlkSize, BlkSize);
View3DType lu("lu", N, BlkSize, BlkSize); // Factorized
View2DType x("x", N, BlkSize), y("y", N, BlkSize), x_ref("x_ref", N, BlkSize); // Solutions
View3DType LU("LU", N, BlkSize, BlkSize); // Factorized matrix of A
View2DType x("x", N, BlkSize), x_ref("x_ref", N, BlkSize); // Solutions
PivView2DType ipiv("ipiv", N, BlkSize);

auto h_A = Kokkos::create_mirror_view(A);
auto h_LU = Kokkos::create_mirror_view(LU);
auto h_ipiv = Kokkos::create_mirror_view(ipiv);
auto h_x = Kokkos::create_mirror_view(x);
auto h_x_ref = Kokkos::create_mirror_view(x_ref);
Kokkos::deep_copy(h_A, 1.0);
Kokkos::deep_copy(h_LU, 1.0);
for (int ib = 0; ib < N; ib++) {
h_A(ib, 1, 1) = -1.0;
h_LU(ib, 1, 1) = -2.0;
h_ipiv(ib, 0) = 0;
h_ipiv(ib, 1) = 1;

h_x(ib, 0) = 2;
h_x(ib, 1) = 0;
h_x_ref(ib, 0) = 1;
h_x_ref(ib, 1) = 1;
}

Kokkos::fence();

Kokkos::deep_copy(A, h_A);
Kokkos::deep_copy(LU, h_LU);
Kokkos::deep_copy(ipiv, h_ipiv);
Kokkos::deep_copy(x, h_x);

// getrf to factorize matrix A = P * L * U
Functor_BatchedSerialGetrf<DeviceType, View3DType, PivView2DType, AlgoTagType>(A, ipiv).run();

// getrs (Note, LU is a factorized matrix of A)
auto info = Functor_BatchedSerialGetrs<DeviceType, View3DType, PivView2DType, View2DType, ParamTagType, AlgoTagType>(
A, ipiv, x)
LU, ipiv, x)
.run();

Kokkos::fence();
Expand Down Expand Up @@ -226,7 +228,7 @@ void impl_test_batched_getrs(const int N, const int BlkSize) {
Kokkos::deep_copy(b, x);

// getrf to factorize matrix A = P * L * U
Functor_BatchedSerialGetrf<DeviceType, View3DType, PivView2DType, AlgoTagType>(LU, ipiv).run();
Functor_BatchedSerialGetrf<DeviceType, View3DType, PivView2DType>(LU, ipiv).run();

// getrs (Note, LU is a factorized matrix of A)
auto info = Functor_BatchedSerialGetrs<DeviceType, View3DType, PivView2DType, View2DType, ParamTagType, AlgoTagType>(
Expand Down Expand Up @@ -317,19 +319,19 @@ TEST_F(TestCategory, test_batched_getrs_t_double) {
#endif

#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT)
TEST_F(TestCategory, test_batched_getrs_nt_cfloat) {
TEST_F(TestCategory, test_batched_getrs_nt_fcomplex) {
using param_tag_type = ::Test::Getrs::ParamTag<Trans::NoTranspose>;
using algo_tag_type = typename Algo::Getrs::Unblocked;

test_batched_getrs<TestDevice, Kokkos::complex<float>, param_tag_type, algo_tag_type>();
}
TEST_F(TestCategory, test_batched_getrs_t_cfloat) {
TEST_F(TestCategory, test_batched_getrs_t_fcomplex) {
using param_tag_type = ::Test::Getrs::ParamTag<Trans::Transpose>;
using algo_tag_type = typename Algo::Getrs::Unblocked;

test_batched_getrs<TestDevice, Kokkos::complex<float>, param_tag_type, algo_tag_type>();
}
TEST_F(TestCategory, test_batched_getrs_c_cfloat) {
TEST_F(TestCategory, test_batched_getrs_c_fcomplex) {
using param_tag_type = ::Test::Getrs::ParamTag<Trans::ConjTranspose>;
using algo_tag_type = typename Algo::Getrs::Unblocked;

Expand All @@ -338,19 +340,19 @@ TEST_F(TestCategory, test_batched_getrs_c_cfloat) {
#endif

#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE)
TEST_F(TestCategory, test_batched_getrs_nt_dfloat) {
TEST_F(TestCategory, test_batched_getrs_nt_dcomplex) {
using param_tag_type = ::Test::Getrs::ParamTag<Trans::NoTranspose>;
using algo_tag_type = typename Algo::Getrs::Unblocked;

test_batched_getrs<TestDevice, Kokkos::complex<double>, param_tag_type, algo_tag_type>();
}
TEST_F(TestCategory, test_batched_getrs_t_dfloat) {
TEST_F(TestCategory, test_batched_getrs_t_dcomplex) {
using param_tag_type = ::Test::Getrs::ParamTag<Trans::Transpose>;
using algo_tag_type = typename Algo::Getrs::Unblocked;

test_batched_getrs<TestDevice, Kokkos::complex<double>, param_tag_type, algo_tag_type>();
}
TEST_F(TestCategory, test_batched_getrs_c_dfloat) {
TEST_F(TestCategory, test_batched_getrs_c_dcomplex) {
using param_tag_type = ::Test::Getrs::ParamTag<Trans::ConjTranspose>;
using algo_tag_type = typename Algo::Getrs::Unblocked;

Expand Down

0 comments on commit 9a81eee

Please sign in to comment.