Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement batched serial iamax #2399

Merged
merged 5 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions batched/dense/impl/KokkosBatched_Iamax_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_IAMAX_SERIAL_IMPL_HPP_
#define KOKKOSBATCHED_IAMAX_SERIAL_IMPL_HPP_

/// \author Yuuichi Asahi ([email protected])

#include "KokkosBatched_Iamax_Serial_Internal.hpp"

namespace KokkosBatched {

template <typename XViewType>
KOKKOS_INLINE_FUNCTION typename XViewType::size_type SerialIamax::invoke(const XViewType &x) {
static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::iamax: XViewType is not a Kokkos::View.");
if (x.extent(0) <= 1) return 0;
using size_type = typename XViewType::size_type;
using value_type = typename XViewType::non_const_value_type;
return KokkosBatched::Impl::SerialIamaxInternal::invoke<size_type, value_type>(x.extent(0), x.data(), x.stride(0));
}

} // namespace KokkosBatched

#endif // KOKKOSBATCHED_IAMAX_SERIAL_IMPL_HPP_
60 changes: 60 additions & 0 deletions batched/dense/impl/KokkosBatched_Iamax_Serial_Internal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_IAMAX_SERIAL_INTERNAL_HPP_
#define KOKKOSBATCHED_IAMAX_SERIAL_INTERNAL_HPP_

/// \author Yuuichi Asahi ([email protected])

#include <Kokkos_Core.hpp>
#include "KokkosBatched_Util.hpp"

namespace KokkosBatched {
namespace Impl {

///
/// Serial Internal Impl
/// ========================

struct SerialIamaxInternal {
template <typename IndexType, typename ValueType>
KOKKOS_INLINE_FUNCTION static IndexType invoke(const int n, const ValueType *KOKKOS_RESTRICT x, const int xs0);
};

template <typename IndexType, typename ValueType>
KOKKOS_INLINE_FUNCTION IndexType SerialIamaxInternal::invoke(const int n, const ValueType *KOKKOS_RESTRICT x,
const int xs0) {
using ats = typename Kokkos::ArithTraits<ValueType>;
using RealType = typename ats::mag_type;

RealType amax = Kokkos::abs(x[0 * xs0]);
IndexType imax = 0;

for (IndexType i = 1; i < static_cast<IndexType>(n); ++i) {
const RealType abs_x_i = Kokkos::abs(x[i * xs0]);
if (abs_x_i > amax) {
amax = abs_x_i;
imax = i;
}
}

return imax;
};

} // namespace Impl
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_IAMAX_SERIAL_INTERNAL_HPP_
43 changes: 43 additions & 0 deletions batched/dense/src/KokkosBatched_Iamax.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_IAMAX_HPP_
#define KOKKOSBATCHED_IAMAX_HPP_

/// \author Yuuichi Asahi ([email protected])

namespace KokkosBatched {

/// \brief Serial Batched Iamax:
/// Iamax finds the index of the first element having maximum absolute value.
///
/// \tparam XViewType: Input view type, needs to be a 1D view
///
/// \param X [in]: Input view type
///
/// \return The index of the first element having maximum absolute value
/// As well as Blas, this returns 0 (0 in Fortran) for an empty vector
/// No nested parallel_for is used inside of the function.
///

struct SerialIamax {
template <typename XViewType>
KOKKOS_INLINE_FUNCTION static typename XViewType::size_type invoke(const XViewType &x);
};
} // namespace KokkosBatched

#include "KokkosBatched_Iamax_Serial_Impl.hpp"

#endif // KOKKOSBATCHED_IAMAX_HPP_
1 change: 1 addition & 0 deletions batched/dense/unit_test/Test_Batched_Dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
#include "Test_Batched_SerialPbtrs_Real.hpp"
#include "Test_Batched_SerialPbtrs_Complex.hpp"
#include "Test_Batched_SerialLaswp.hpp"
#include "Test_Batched_SerialIamax.hpp"

// Team Kernels
#include "Test_Batched_TeamAxpy.hpp"
Expand Down
Loading
Loading