Skip to content

Commit

Permalink
fix: use view size_type as a return type of iamax
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
Yuuichi Asahi committed Oct 28, 2024
1 parent 5b53a2a commit b0ab297
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 16 deletions.
9 changes: 5 additions & 4 deletions batched/dense/impl/KokkosBatched_Iamax_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
namespace KokkosBatched {

template <typename XViewType>
KOKKOS_INLINE_FUNCTION int SerialIamax::invoke(const XViewType &x) {
KOKKOS_INLINE_FUNCTION typename XViewType::size_type SerialIamax::invoke(const XViewType &x) {
static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::iamax: XViewType is not a Kokkos::View.");
if (x.extent(0) <= 0) return -1;
if (x.extent(0) == 1) return 0;
return KokkosBatched::Impl::SerialIamaxInternal::invoke(x.extent(0), x.data(), x.stride(0));
if (x.extent(0) <= 1) return 0;
using size_type = typename XViewType::size_type;
using value_type = typename XViewType::non_const_value_type;
return KokkosBatched::Impl::SerialIamaxInternal::invoke<size_type, value_type>(x.extent(0), x.data(), x.stride(0));
}

} // namespace KokkosBatched
Expand Down
15 changes: 8 additions & 7 deletions batched/dense/impl/KokkosBatched_Iamax_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,20 @@ namespace Impl {
/// ========================

struct SerialIamaxInternal {
template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int n, const ValueType *KOKKOS_RESTRICT x, const int xs0);
template <typename IndexType, typename ValueType>
KOKKOS_INLINE_FUNCTION static IndexType invoke(const int n, const ValueType *KOKKOS_RESTRICT x, const int xs0);
};

template <typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialIamaxInternal::invoke(const int n, const ValueType *KOKKOS_RESTRICT x, const int xs0) {
template <typename IndexType, typename ValueType>
KOKKOS_INLINE_FUNCTION IndexType SerialIamaxInternal::invoke(const int n, const ValueType *KOKKOS_RESTRICT x,
const int xs0) {
using ats = typename Kokkos::ArithTraits<ValueType>;
using RealType = typename ats::mag_type;

RealType amax = Kokkos::abs(x[0 * xs0]);
int imax = 0;
RealType amax = Kokkos::abs(x[0 * xs0]);
IndexType imax = 0;

for (int i = 1; i < n; ++i) {
for (IndexType i = 1; i < static_cast<IndexType>(n); ++i) {
const RealType abs_x_i = Kokkos::abs(x[i * xs0]);
if (abs_x_i > amax) {
amax = abs_x_i;
Expand Down
4 changes: 2 additions & 2 deletions batched/dense/src/KokkosBatched_Iamax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ namespace KokkosBatched {
/// \param X [in]: Input view type
///
/// \return The index of the first element having maximum absolute value
/// As well as Blas, this returns -1 (0 in Fortran) for an empty vector
/// Like BLAS (which returns 0 in Fortran), this returns 0 for an empty vector
/// No nested parallel_for is used inside of the function.
///

struct SerialIamax {
template <typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const XViewType &x);
KOKKOS_INLINE_FUNCTION static typename XViewType::size_type invoke(const XViewType &x);
};
} // namespace KokkosBatched

Expand Down
6 changes: 3 additions & 3 deletions batched/dense/unit_test/Test_Batched_SerialIamax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct Functor_BatchedSerialIamax {
void operator()(const int k) const {
auto sub_x = Kokkos::subview(m_x, k, Kokkos::ALL());
auto iamax = KokkosBatched::SerialIamax::invoke(sub_x);
m_r(k) = iamax;
m_r(k) = static_cast<int>(iamax);
}

inline void run() {
Expand Down Expand Up @@ -198,9 +198,9 @@ void impl_test_batched_iamax(const std::size_t N, const std::size_t BlkSize) {
// Reference
auto h_iamax_ref = Kokkos::create_mirror_view(iamax_ref);
if (BlkSize == 0) {
// As well as blas, we store -1 (0 in Fortran) for empty matrix
// As in BLAS, we store 0 (0 in Fortran) for an empty matrix
for (std::size_t k = 0; k < N; k++) {
h_iamax_ref(k) = -1;
h_iamax_ref(k) = 0;
}
} else {
auto h_A = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A);
Expand Down

0 comments on commit b0ab297

Please sign in to comment.