Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement batched serial laswp #2395

Merged
merged 4 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions batched/dense/impl/KokkosBatched_Laswp_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_LASWP_SERIAL_IMPL_HPP_
#define KOKKOSBATCHED_LASWP_SERIAL_IMPL_HPP_

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Laswp_Serial_Internal.hpp"

/// \author Yuuichi Asahi ([email protected])

namespace KokkosBatched {
namespace Impl {

template <typename PivViewType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int checkLaswpInput(const PivViewType &piv, const AViewType &A) {
static_assert(Kokkos::is_view_v<PivViewType>, "KokkosBatched::laswp: PivViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::laswp: AViewType is not a Kokkos::View.");
static_assert(AViewType::rank == 1 || AViewType::rank == 2, "KokkosBatched::laswp: AViewType must have rank 1 or 2.");
static_assert(PivViewType::rank == 1, "KokkosBatched::laswp: PivViewType must have rank 1.");

#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
const int npiv = piv.extent(0);
const int lda = A.extent(0);
if (npiv > lda) {
Kokkos::printf(
"KokkosBatched::laswp: the dimension of the ipiv array must "
"satisfy ipiv.extent(0) <= A.extent(0): ipiv: %d, A: "
"%d \n",
npiv, lda);
return 1;
}
#endif
return 0;
}
} // namespace Impl

///
/// Serial Internal Impl
/// ========================

///
//// Forward pivot apply
///

template <>
struct SerialLaswp<Direct::Forward> {
template <typename PivViewType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType &piv, const AViewType &A) {
auto info = KokkosBatched::Impl::checkLaswpInput(piv, A);
if (info) return info;

if constexpr (AViewType::rank == 1) {
const int plen = piv.extent(0), ps0 = piv.stride(0), as0 = A.stride(0);
return KokkosBatched::Impl::SerialLaswpVectorForwardInternal::invoke(plen, piv.data(), ps0, A.data(), as0);
} else if constexpr (AViewType::rank == 2) {
// row permutation
const int plen = piv.extent(0), ps0 = piv.stride(0), n = A.extent(1), as0 = A.stride(0), as1 = A.stride(1);
return KokkosBatched::Impl::SerialLaswpMatrixForwardInternal::invoke(n, plen, piv.data(), ps0, A.data(), as0,
as1);
}
return 0;
}
};

///
/// Backward pivot apply
///

template <>
struct SerialLaswp<Direct::Backward> {
template <typename PivViewType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType piv, const AViewType &A) {
auto info = KokkosBatched::Impl::checkLaswpInput(piv, A);
if (info) return info;

if constexpr (AViewType::rank == 1) {
const int plen = piv.extent(0), ps0 = piv.stride(0), as0 = A.stride(0);
return KokkosBatched::Impl::SerialLaswpVectorBackwardInternal::invoke(plen, piv.data(), ps0, A.data(), as0);
} else if constexpr (AViewType::rank == 2) {
// row permutation
const int plen = piv.extent(0), ps0 = piv.stride(0), n = A.extent(1), as0 = A.stride(0), as1 = A.stride(1);
return KokkosBatched::Impl::SerialLaswpMatrixBackwardInternal::invoke(n, plen, piv.data(), ps0, A.data(), as0,
as1);
}
return 0;
}
};
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_LASWP_SERIAL_IMPL_HPP_
150 changes: 150 additions & 0 deletions batched/dense/impl/KokkosBatched_Laswp_Serial_Internal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_LASWP_SERIAL_INTERNAL_HPP_
#define KOKKOSBATCHED_LASWP_SERIAL_INTERNAL_HPP_

/// \author Yuuichi Asahi ([email protected])

#include "KokkosBatched_Util.hpp"

namespace KokkosBatched {
namespace Impl {

///
/// Serial Internal Impl
/// ========================

///
//// Forward pivot apply
///

struct SerialLaswpVectorForwardInternal {
template <typename IntType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int plen, const IntType *KOKKOS_RESTRICT p, const int ps0,
/* */ ValueType *KOKKOS_RESTRICT A, const int as0) {
for (int i = 0; i < plen; ++i) {
const int piv = p[i * ps0];
if (piv != i) {
const int idx_i = i * as0, idx_p = piv * as0;
const ValueType tmp = A[idx_i];
A[idx_i] = A[idx_p];
A[idx_p] = tmp;
}
}
return 0;
}
};

struct SerialLaswpMatrixForwardInternal {
template <typename IntType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int n, const int plen, const IntType *KOKKOS_RESTRICT p, const int ps0,
/* */ ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) {
if (as0 <= as1) {
// LayoutLeft like
for (int j = 0; j < n; j++) {
ValueType *KOKKOS_RESTRICT A_at_j = A + j * as1;
for (int i = 0; i < plen; ++i) {
const int piv = p[i * ps0];
if (piv != i) {
const int idx_i = i * as0, idx_p = piv * as0;
const ValueType tmp = A_at_j[idx_i];
A_at_j[idx_i] = A_at_j[idx_p];
A_at_j[idx_p] = tmp;
}
}
}
} else {
// LayoutRight like
for (int i = 0; i < plen; ++i) {
const int piv = p[i * ps0];
if (piv != i) {
const int idx_i = i * as0, idx_p = piv * as0;
for (int j = 0; j < n; j++) {
ValueType *KOKKOS_RESTRICT A_at_j = A + j * as1;
const ValueType tmp = A_at_j[idx_i];
A_at_j[idx_i] = A_at_j[idx_p];
A_at_j[idx_p] = tmp;
}
}
}
}
return 0;
}
};

///
/// Backward pivot apply
///

struct SerialLaswpVectorBackwardInternal {
template <typename IntType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int plen, const IntType *KOKKOS_RESTRICT p, const int ps0,
/* */ ValueType *KOKKOS_RESTRICT A, const int as0) {
for (int i = (plen - 1); i >= 0; --i) {
const int piv = p[i * ps0];
if (piv != i) {
const int idx_i = i * as0, idx_p = piv * as0;
const ValueType tmp = A[idx_i];
A[idx_i] = A[idx_p];
A[idx_p] = tmp;
}
}
return 0;
}
};

struct SerialLaswpMatrixBackwardInternal {
template <typename IntType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int n, const int plen, const IntType *KOKKOS_RESTRICT p, const int ps0,
/* */ ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) {
if (as0 <= as1) {
// LayoutLeft like
for (int j = 0; j < n; j++) {
ValueType *KOKKOS_RESTRICT A_at_j = A + j * as1;
for (int i = (plen - 1); i >= 0; --i) {
const int piv = p[i * ps0];
if (piv != i) {
const int idx_i = i * as0, idx_p = piv * as0;
const ValueType tmp = A_at_j[idx_i];
A_at_j[idx_i] = A_at_j[idx_p];
A_at_j[idx_p] = tmp;
}
}
}
} else {
// LayoutRight like
for (int i = (plen - 1); i >= 0; --i) {
const int piv = p[i * ps0];
if (piv != i) {
const int idx_i = i * as0, idx_p = piv * as0;
for (int j = 0; j < n; j++) {
ValueType *KOKKOS_RESTRICT A_at_j = A + j * as1;
const ValueType tmp = A_at_j[idx_i];
A_at_j[idx_i] = A_at_j[idx_p];
A_at_j[idx_p] = tmp;
}
}
}
}
return 0;
}
};

} // namespace Impl
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_LASWP_SERIAL_INTERNAL_HPP_
52 changes: 52 additions & 0 deletions batched/dense/src/KokkosBatched_Laswp.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_LASWP_HPP_
#define KOKKOSBATCHED_LASWP_HPP_

#include <KokkosBatched_Util.hpp>

/// \author Yuuichi Asahi ([email protected])

namespace KokkosBatched {

/// \brief Serial Batched Laswp:
///
/// performs a series of row interchanges on the matrix A.
/// One row interchange is initiated for each of rows K1 through K2 of A.
///
/// \tparam PivViewType: Input type for the a superdiagonal matrix, needs to
/// be a 1D view
/// \tparam AViewType: Input type for the vector or matrix, needs to be a 1D or
/// 2D view
///
/// \param piv [in]: The pivot indices; for 0 <= i < N, row i of the
/// matrix was interchanged with row piv(i).
/// \param A [inout]: A is a lda by n matrix. The matrix of column dimension N
/// to which the row interchanges will be applied.
///
/// No nested parallel_for is used inside of the function.
///

template <typename ArgDirect>
struct SerialLaswp {
template <typename PivViewType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType &piv, const AViewType &A);
};
} // namespace KokkosBatched

#include "../impl/KokkosBatched_Laswp_Serial_Impl.hpp"

#endif // KOKKOSBATCHED_LASWP_HPP_
1 change: 1 addition & 0 deletions batched/dense/unit_test/Test_Batched_Dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
#include "Test_Batched_SerialPbtrf.hpp"
#include "Test_Batched_SerialPbtrf_Real.hpp"
#include "Test_Batched_SerialPbtrf_Complex.hpp"
#include "Test_Batched_SerialLaswp.hpp"

// Team Kernels
#include "Test_Batched_TeamAxpy.hpp"
Expand Down
Loading
Loading