From 0e5d2ca224f365d360f3f14278772e083d784004 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Fri, 13 Dec 2024 14:59:05 +0900 Subject: [PATCH 1/7] Add ConjTrans to Serial Gemm Signed-off-by: Yuuichi Asahi --- .../impl/KokkosBatched_Gemm_Serial_Impl.hpp | 551 ++++++++++++++++-- .../KokkosBatched_Gemm_Serial_Internal.hpp | 34 +- .../impl/KokkosBatched_Gemm_Team_Internal.hpp | 4 +- ...okkosBatched_InnerGemmFixC_Serial_Impl.hpp | 492 ++++++++-------- .../impl/KokkosBatched_LU_Serial_Internal.hpp | 5 +- .../src/KokkosBatched_InnerGemmFixC_Decl.hpp | 12 +- .../unit_test/Test_Batched_SerialGemm.hpp | 184 +++--- .../Test_Batched_SerialGemm_Complex.hpp | 249 ++++++-- .../Test_Batched_SerialGemm_Real.hpp | 65 ++- ...osBlas2_serial_gemv_inner_multiple_dot.hpp | 17 +- blas/impl/KokkosBlas_util.hpp | 16 + common/src/KokkosKernels_BlockUtils.hpp | 6 +- sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp | 12 +- .../impl/KokkosSparse_spmv_bsrmatrix_impl.hpp | 11 +- 14 files changed, 1142 insertions(+), 516 deletions(-) diff --git a/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp index fae44c8f83..c539ffe4a2 100644 --- a/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp @@ -16,22 +16,63 @@ #ifndef KOKKOSBATCHED_GEMM_SERIAL_IMPL_HPP #define KOKKOSBATCHED_GEMM_SERIAL_IMPL_HPP +#include "KokkosBlas_util.hpp" #include "KokkosBatched_Util.hpp" #include "KokkosBatched_Gemm_Serial_Internal.hpp" namespace KokkosBatched { +namespace Impl { +template +KOKKOS_INLINE_FUNCTION static int checkGemmInput([[maybe_unused]] const AViewType &A, + [[maybe_unused]] const BViewType &B, + [[maybe_unused]] const CViewType &C) { + static_assert(Kokkos::is_view_v, "KokkosBatched::gemm: AViewType is not a Kokkos::View."); + static_assert(Kokkos::is_view_v, "KokkosBatched::gemm: BViewType is not a Kokkos::View."); + static_assert(Kokkos::is_view_v, "KokkosBatched::gemm: CViewType is not 
a Kokkos::View."); + /* + static_assert(AViewType::rank == 1 || AViewType::rank == 2, + "KokkosBatched::gemm: AViewType must have rank 1 or 2."); + static_assert(BViewType::rank == 1 || BViewType::rank == 2, + "KokkosBatched::gemm: BViewType must have rank 1 or 2."); + static_assert(CViewType::rank == 1 || CViewType::rank == 2, + "KokkosBatched::gemm: CViewType must have rank 1 or 2."); + */ + +#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) + const int m = C.extent(0), n = C.extent(1); + const int lda = A.extent(0), k = A.extent(1); + const int ldb = B.extent(0); + + const int nrowa = std::is_same_v ? m : k; + const int nrowb = std::is_same_v ? k : n; + + if (lda < Kokkos::max(1, nrowa)) { + Kokkos::printf( + "KokkosBatched::gemm: leading dimension of A must not be smaller than " + "max(1, nrowa): " + "lda = %d, nrowa = %d\n", + lda, nrowa); + return 1; + } + if (ldb < Kokkos::max(1, nrowb)) { + Kokkos::printf( + "KokkosBatched::gemm: leading dimension of B must not be smaller than " + "max(1, nrowb): " + "ldb = %d, nrowb = %d\n", + ldb, nrowb); + return 1; + } + +#endif + + return 0; +} +} // namespace Impl + /// /// Serial Impl /// =========== -/// -/// Implemented: -/// NT/NT, T/NT, NT/T, T/T -/// -/// Not yet immplemented (ConjTranspose): -/// CT/NT, NT/CT, CT/CT -/// - /// /// NT/NT /// @@ -73,22 +114,36 @@ template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + // C = beta C + alpha A B // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_0(), - 
B.stride_1(), beta, C.data(), C.stride_0(), C.stride_1()); + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), + A.stride(0), A.stride(1), B.data(), B.stride(0), B.stride(1), beta, C.data(), C.stride(0), C.stride(1)); } template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + // C = beta C + alpha A B // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_0(), - B.stride_1(), beta, C.data(), C.stride_0(), C.stride_1()); + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), + A.stride(0), A.stride(1), B.data(), B.stride(0), B.stride(1), beta, C.data(), C.stride(0), C.stride(1)); } /// @@ -132,22 +187,109 @@ template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_0(), - B.stride_1(), beta, C.data(), C.stride_0(), C.stride_1()); + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 
0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^T B + // C (m x n), A(k x m), B(k x n) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(0), B.stride(1), beta, C.data(), C.stride(0), C.stride(1)); } template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_0(), - B.stride_1(), beta, C.data(), C.stride_0(), C.stride_1()); + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^T B + // C (m x n), A(k x m), B(k x n) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(0), B.stride(1), beta, C.data(), C.stride(0), C.stride(1)); +} + +/// +/// C/NT +/// + +#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ + defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + typedef typename CViewType::value_type vector_type; + // typedef typename vector_type::value_type value_type; + + 
const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + + static_assert(is_vector::value, "value type is not vector type"); + static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, + "AVX, AVX2 and AVX512 is supported"); + const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; + + // no error check + int r_val = 0; + if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { + mkl_dgemm_compact(MKL_COL_MAJOR, MKL_CONJTRANS, MKL_NOTRANS, m, n, k, alpha, (const double *)A.data(), A.stride_1(), + (const double *)B.data(), B.stride_1(), beta, (double *)C.data(), C.stride_1(), format, + (MKL_INT)vector_type::vector_length); + } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { + mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_CONJTRANS, MKL_NOTRANS, m, n, k, alpha, (const double *)A.data(), A.stride_0(), + (const double *)B.data(), B.stride_0(), beta, (double *)C.data(), C.stride_0(), format, + (MKL_INT)vector_type::vector_length); + } else { + r_val = -1; + } + return r_val; +} +#endif + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^H B + // C (m x n), A(k x m), B(k x n) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(0), B.stride(1), beta, C.data(), C.stride(0), C.stride(1)); +} + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const 
ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^H B + // C (m x n), A(k x m), B(k x n) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(0), B.stride(1), beta, C.data(), C.stride(0), C.stride(1)); } /// @@ -191,22 +333,36 @@ template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_1(), - B.stride_0(), beta, C.data(), C.stride_0(), C.stride_1()); + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A B^T + // C (m x n), A(m x k), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), + A.stride(0), A.stride(1), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); } template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) 
{ - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_1(), - B.stride_0(), beta, C.data(), C.stride_0(), C.stride_1()); + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A B^T + // C (m x n), A(m x k), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), + A.stride(0), A.stride(1), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); } /// @@ -250,23 +406,330 @@ template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_1(), - B.stride_0(), beta, C.data(), C.stride_0(), C.stride_1()); + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^T B^T + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); } 
template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_1(), - B.stride_0(), beta, C.data(), C.stride_0(), C.stride_1()); + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^T B^T + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +/// +/// C/T +/// + +#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ + defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + typedef typename CViewType::value_type vector_type; + // typedef typename vector_type::value_type value_type; + + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + + static_assert(is_vector::value, "value type is not vector type"); + static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, + "AVX, AVX2 and AVX512 is supported"); + const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? 
MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; + + // no error check + int r_val = 0; + if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { + mkl_dgemm_compact(MKL_COL_MAJOR, MKL_CONJTRANS, MKL_TRANS, m, n, k, alpha, (const double *)A.data(), A.stride_1(), + (const double *)B.data(), B.stride_1(), beta, (double *)C.data(), C.stride_1(), format, + (MKL_INT)vector_type::vector_length); + } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { + mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_CONJTRANS, MKL_TRANS, m, n, k, alpha, (const double *)A.data(), A.stride_0(), + (const double *)B.data(), B.stride_0(), beta, (double *)C.data(), C.stride_0(), format, + (MKL_INT)vector_type::vector_length); + } else { + r_val = -1; + } + return r_val; +} +#endif + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^H B^T + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = 
KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^H B^T + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); } + +/// +/// NT/C +/// + +#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ + defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + typedef typename CViewType::value_type vector_type; + // typedef typename vector_type::value_type value_type; + + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + + static_assert(is_vector::value, "value type is not vector type"); + static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, + "AVX, AVX2 and AVX512 is supported"); + const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? 
MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; + + // no error check + int r_val = 0; + if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { + mkl_dgemm_compact(MKL_COL_MAJOR, MKL_NOTRANS, MKL_CONJTRANS, m, n, k, alpha, (const double *)A.data(), A.stride_1(), + (const double *)B.data(), B.stride_1(), beta, (double *)C.data(), C.stride_1(), format, + (MKL_INT)vector_type::vector_length); + } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { + mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_NOTRANS, MKL_CONJTRANS, m, n, k, alpha, (const double *)A.data(), A.stride_0(), + (const double *)B.data(), B.stride_0(), beta, (double *)C.data(), C.stride_0(), format, + (MKL_INT)vector_type::vector_length); + } else { + r_val = -1; + } + return r_val; +} +#endif + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A B^H + // C (m x n), A(m x k), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpConj(), C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), + A.stride(0), A.stride(1), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = 
KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A B^H + // C (m x n), A(m x k), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpConj(), C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), + A.stride(0), A.stride(1), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +/// +/// T/C +/// + +#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ + defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + typedef typename CViewType::value_type vector_type; + // typedef typename vector_type::value_type value_type; + + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + + static_assert(is_vector::value, "value type is not vector type"); + static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, + "AVX, AVX2 and AVX512 is supported"); + const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? 
MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; + + // no error check + int r_val = 0; + if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { + mkl_dgemm_compact(MKL_COL_MAJOR, MKL_TRANS, MKL_CONJTRANS, m, n, k, alpha, (const double *)A.data(), A.stride_1(), + (const double *)B.data(), B.stride_1(), beta, (double *)C.data(), C.stride_1(), format, + (MKL_INT)vector_type::vector_length); + } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { + mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_TRANS, MKL_CONJTRANS, m, n, k, alpha, (const double *)A.data(), A.stride_0(), + (const double *)B.data(), B.stride_0(), beta, (double *)C.data(), C.stride_0(), format, + (MKL_INT)vector_type::vector_length); + } else { + r_val = -1; + } + return r_val; +} +#endif + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^T B^H + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpConj(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = 
KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^T B^H + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpConj(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +/// +/// C/C +/// + +#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ + defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + typedef typename CViewType::value_type vector_type; + // typedef typename vector_type::value_type value_type; + + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + + static_assert(is_vector::value, "value type is not vector type"); + static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, + "AVX, AVX2 and AVX512 is supported"); + const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? 
MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; + + // no error check + int r_val = 0; + if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { + mkl_dgemm_compact(MKL_COL_MAJOR, MKL_CONJTRANS, MKL_CONJTRANS, m, n, k, alpha, (const double *)A.data(), + A.stride_1(), (const double *)B.data(), B.stride_1(), beta, (double *)C.data(), C.stride_1(), + format, (MKL_INT)vector_type::vector_length); + } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { + mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_CONJTRANS, MKL_CONJTRANS, m, n, k, alpha, (const double *)A.data(), + A.stride_0(), (const double *)B.data(), B.stride_0(), beta, (double *)C.data(), C.stride_0(), + format, (MKL_INT)vector_type::vector_length); + } else { + r_val = -1; + } + return r_val; +} +#endif + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^H B^H + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpConj(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = 
KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^H B^H + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpConj(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + } // namespace KokkosBatched #endif diff --git a/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp index 1a83a27112..a66506e633 100644 --- a/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp @@ -26,6 +26,7 @@ #include "KokkosBatched_InnerGemmFixC_Serial_Impl.hpp" namespace KokkosBatched { +namespace Impl { /// /// Serial Internal Impl @@ -33,19 +34,20 @@ namespace KokkosBatched { template struct SerialGemmInternal { - template - KOKKOS_INLINE_FUNCTION static int invoke(const int m, const int n, const int k, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, - const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, - const ScalarType beta, + template + KOKKOS_INLINE_FUNCTION static int invoke(OpA opA, OpB opB, const int m, const int n, const int k, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, + const int bs1, const ScalarType beta, /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1); }; template <> -template +template KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( - const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, const ScalarType beta, + OpA opA, OpB opB, const int m, const int n, const int k, 
const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, + const ScalarType beta, /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { // C = beta C + alpha A B // C (m x n), A(m x k), B(k x n) @@ -58,17 +60,15 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1); if (alpha != zero) { - if (m <= 0 || n <= 0 || k <= 0) return 0; - ValueType *KOKKOS_RESTRICT pC = C; for (int p = 0; p < k; ++p) { const ValueType *KOKKOS_RESTRICT pA = A + p * as1, *KOKKOS_RESTRICT pB = B + p * bs0; for (int i = 0; i < m; ++i) { - const ValueType tA(alpha * pA[i * as0]); + const ValueType tA(alpha * opA(pA[i * as0])); #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) #pragma unroll #endif - for (int j = 0; j < n; ++j) pC[i * cs0 + j * cs1] += tA * pB[j * bs1]; + for (int j = 0; j < n; ++j) pC[i * cs0 + j * cs1] += tA * opB(pB[j * bs1]); } } } @@ -76,10 +76,11 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( - const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, const ScalarType beta, + OpA opA, OpB opB, const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, + const ScalarType beta, /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { // C = beta C + alpha A B // C (m x n), A(m x k), B(k x n) @@ -105,7 +106,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( const int mb = mbAlgo, nb = nbAlgo; for (int i = 0; i < ib; i += mb) for (int j = 0; j < jb; j += nb) - inner.serial_invoke(alpha_value, AA + i * as0, BB + 
j * bs1, (i + mb) > ib ? (ib - i) : mb, + inner.serial_invoke(opA, opB, alpha_value, AA + i * as0, BB + j * bs1, (i + mb) > ib ? (ib - i) : mb, (j + nb) > jb ? (jb - j) : nb, pb, CC + i * cs0 + j * cs1); }; @@ -138,6 +139,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( return 0; } +} // namespace Impl } // namespace KokkosBatched #endif diff --git a/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp b/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp index b8647f5205..70c1ce3a03 100644 --- a/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp @@ -122,8 +122,8 @@ KOKKOS_INLINE_FUNCTION int TeamGemmInternal::invoke( i = ij / nq * mb; j = ij % nq * nb; } - inner.serial_invoke(alpha, AA + i * as0, BB + j * bs1, (i + mb) > ib ? mp : mb, (j + nb) > jb ? np : nb, pb, - CC + i * cs0 + j * cs1); + inner.serial_invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), alpha, AA + i * as0, BB + j * bs1, + (i + mb) > ib ? mp : mb, (j + nb) > jb ? 
np : nb, pb, CC + i * cs0 + j * cs1); }); }; diff --git a/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp index e090ce57bd..328fdbcf3d 100644 --- a/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp @@ -28,8 +28,8 @@ namespace KokkosBatched { /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -47,16 +47,16 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - b_p3 = B[p * _bs0 + j3]; - a_4p = A[i4 + p * _as1]; - b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + b_p3 = opB(B[p * _bs0 + j3]); + a_4p = opA(A[i4 + p * _as1]); + b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -115,8 +115,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType 
*KOKKOS_RESTRICT C) { @@ -133,15 +133,15 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 4>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - b_p3 = B[p * _bs0 + j3]; - a_4p = A[i4 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + b_p3 = opB(B[p * _bs0 + j3]); + a_4p = opA(A[i4 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -190,8 +190,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 4>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -207,14 +207,14 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 3>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - a_4p = A[i4 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + a_4p = opA(A[i4 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -253,8 +253,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 3>::serial_invoke(const ScalarType a } template <> -template 
-KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -269,13 +269,13 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 2>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - a_3p = A[i3 + p * _as1]; - a_4p = A[i4 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + a_3p = opA(A[i3 + p * _as1]); + a_4p = opA(A[i4 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -304,8 +304,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 2>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -319,12 +319,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 1>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - a_2p = A[i2 + p * _as1]; - a_3p = A[i3 + p * _as1]; - a_4p = A[i4 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + a_2p = opA(A[i2 + p * _as1]); + a_3p = opA(A[i3 + p * _as1]); + a_4p = opA(A[i4 + p * _as1]); c_00 += a_0p * b_p0; c_10 += a_1p * b_p0; @@ -343,8 +343,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 
1>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -362,15 +362,15 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 5>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - b_p3 = B[p * _bs0 + j3]; - /**/ b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + b_p3 = opB(B[p * _bs0 + j3]); + /**/ b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -419,8 +419,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 5>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -438,14 +438,14 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 5>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; - /**/ b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = 
opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); + /**/ b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -484,8 +484,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 5>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -503,13 +503,13 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 5>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; - /**/ b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); + /**/ b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -538,8 +538,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 5>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -557,12 +557,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 5>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p 
* _bs0 + j0]; - /**/ b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; - /**/ b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + /**/ b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); + /**/ b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -584,8 +584,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 5>::serial_invoke(const ScalarType a /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -603,14 +603,14 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - b_p3 = B[p * _bs0 + j3]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + b_p3 = opB(B[p * _bs0 + j3]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -651,8 +651,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -668,13 
+668,13 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 3>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -707,8 +707,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 3>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -723,12 +723,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 2>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - a_3p = A[i3 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + a_3p = opA(A[i3 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -753,8 +753,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 2>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, 
const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -769,11 +769,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 1>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - a_2p = A[i2 + p * _as1]; - a_3p = A[i3 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + a_2p = opA(A[i2 + p * _as1]); + a_3p = opA(A[i3 + p * _as1]); c_00 += a_0p * b_p0; c_10 += a_1p * b_p0; @@ -790,8 +790,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 1>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -808,13 +808,13 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 4>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -847,8 +847,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 4>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType 
*KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -865,12 +865,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 4>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -895,8 +895,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 4>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -913,11 +913,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 4>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - /**/ b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + /**/ b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -938,8 +938,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 4>::serial_invoke(const ScalarType a /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, 
const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -955,12 +955,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -987,8 +987,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1003,11 +1003,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 2>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -1028,8 +1028,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 2>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ 
ValueType *KOKKOS_RESTRICT C) { @@ -1043,10 +1043,10 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 1>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - a_2p = A[i2 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + a_2p = opA(A[i2 + p * _as1]); c_00 += a_0p * b_p0; c_10 += a_1p * b_p0; @@ -1061,8 +1061,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 1>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1078,11 +1078,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 3>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -1102,8 +1102,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 3>::serial_invoke(const ScalarType a return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1119,10 +1119,10 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 3>::serial_invoke(const ScalarType a #pragma 
unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - /**/ b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + /**/ b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -1141,8 +1141,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 3>::serial_invoke(const ScalarType a /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1156,10 +1156,10 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -1176,8 +1176,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1191,9 +1191,9 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 1>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); 
+ a_1p = opA(A[i1 + p * _as1]); c_00 += a_0p * b_p0; c_10 += a_1p * b_p0; @@ -1206,8 +1206,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 1>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1221,9 +1221,9 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 2>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - /* */ b_p1 = B[p * _bs0 + j1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + /* */ b_p1 = opB(B[p * _bs0 + j1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; } @@ -1239,8 +1239,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 2>::serial_invoke(const ScalarType a /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1254,8 +1254,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); c_00 += a_0p * b_p0; } C[0 * _cs0 + 0 * _cs1] += alpha * c_00; @@ -1264,8 +1264,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<0, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<0, 
1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int k, @@ -1275,27 +1275,27 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<0, 1>::serial_invoke(const ScalarType a switch (m) { case 5: { InnerGemmFixC<5, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 4: { InnerGemmFixC<4, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 3: { InnerGemmFixC<3, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 2: { InnerGemmFixC<2, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 1: { InnerGemmFixC<1, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } default: { @@ -1307,8 +1307,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<0, 1>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -1319,52 +1319,52 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(const ScalarType a switch (m * 10 + n) { case 55: { InnerGemmFixC<5, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 54: { InnerGemmFixC<5, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - 
inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 53: { InnerGemmFixC<5, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 52: { InnerGemmFixC<5, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 51: { InnerGemmFixC<5, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 45: { InnerGemmFixC<4, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 35: { InnerGemmFixC<3, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 25: { InnerGemmFixC<2, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 15: { InnerGemmFixC<1, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } default: { InnerGemmFixC<4, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, m, n, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, m, n, k, C); break; } } @@ -1372,8 +1372,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -1384,42 +1384,42 @@ 
KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(const ScalarType a switch (m * 10 + n) { case 44: { InnerGemmFixC<4, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 43: { InnerGemmFixC<4, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 42: { InnerGemmFixC<4, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 41: { InnerGemmFixC<4, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 34: { InnerGemmFixC<3, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 24: { InnerGemmFixC<2, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 14: { InnerGemmFixC<1, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } default: { InnerGemmFixC<3, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, m, n, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, m, n, k, C); break; } } @@ -1427,8 +1427,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -1439,32 +1439,32 @@ 
KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(const ScalarType a switch (m * 10 + n) { case 33: { InnerGemmFixC<3, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 32: { InnerGemmFixC<3, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 31: { InnerGemmFixC<3, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 23: { InnerGemmFixC<2, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 13: { InnerGemmFixC<1, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } default: { InnerGemmFixC<2, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, m, n, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, m, n, k, C); break; } } @@ -1472,8 +1472,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -1484,22 +1484,22 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(const ScalarType a switch (m * 10 + n) { case 22: { InnerGemmFixC<2, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 21: { InnerGemmFixC<2, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - 
inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 12: { InnerGemmFixC<1, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 11: { InnerGemmFixC<1, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } } @@ -1507,8 +1507,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -1516,7 +1516,7 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(const ScalarType a if (m <= 0 || n <= 0 || k <= 0) return 0; if (!(m <= 1 && n <= 1)) Kokkos::abort("InnerGemmFixC<1,1>::serial_invoke, assert failure (m<=1 && n<=1)"); - return serial_invoke(alpha, A, B, k, C); + return serial_invoke(opA, opB, alpha, A, B, k, C); ; } diff --git a/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp index 52002ad473..4b7166f0b9 100644 --- a/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp @@ -115,8 +115,9 @@ KOKKOS_INLINE_FUNCTION int SerialLU_Internal::invoke( trsm_run.serial_invoke(Ap, pb, m_abr, Ap + mb * as0); // gemm update - SerialGemmInternal::invoke(m_abr, n_abr, pb, minus_one, Ap + mb * as0, as0, as1, - Ap + mb * as1, as0, as1, one, Ap + mb * as0 + mb * as1, as0, as1); + Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), m_abr, n_abr, pb, minus_one, Ap + mb * as0, as0, as1, + Ap + mb * as1, as0, 
as1, one, Ap + mb * as0 + mb * as1, as0, as1); } }; diff --git a/batched/dense/src/KokkosBatched_InnerGemmFixC_Decl.hpp b/batched/dense/src/KokkosBatched_InnerGemmFixC_Decl.hpp index 31ba2a03d9..ca55816fe4 100644 --- a/batched/dense/src/KokkosBatched_InnerGemmFixC_Decl.hpp +++ b/batched/dense/src/KokkosBatched_InnerGemmFixC_Decl.hpp @@ -29,20 +29,20 @@ struct InnerGemmFixC { : _as0(as0), _as1(as1), _bs0(bs0), _bs1(bs1), _cs0(cs0), _cs1(cs1) {} // serial rank update - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C); // serial rank update for remainder - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int k, /**/ ValueType *KOKKOS_RESTRICT C); // serial rank update for remainder - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, /**/ ValueType *KOKKOS_RESTRICT C); diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm.hpp b/batched/dense/unit_test/Test_Batched_SerialGemm.hpp index 0b2ed4a162..776d83afe1 100644 --- a/batched/dense/unit_test/Test_Batched_SerialGemm.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialGemm.hpp @@ -34,95 +34,98 @@ namespace Gemm { template struct ParamTag { - typedef TA transA; - typedef TB transB; + using transA = TA; + using transB = TB; }; 
template struct Functor_TestBatchedSerialGemm { using execution_space = typename DeviceType::execution_space; - ViewType _a, _b, _c; - - ScalarType _alpha, _beta; + ViewType m_a, m_b, m_c; + ScalarType m_alpha, m_beta; KOKKOS_INLINE_FUNCTION Functor_TestBatchedSerialGemm(const ScalarType alpha, const ViewType &a, const ViewType &b, const ScalarType beta, const ViewType &c) - : _a(a), _b(b), _c(c), _alpha(alpha), _beta(beta) {} + : m_a(a), m_b(b), m_c(c), m_alpha(alpha), m_beta(beta) {} KOKKOS_INLINE_FUNCTION void operator()(const ParamTagType &, const int k) const { - auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); - auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL()); - auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL()); + auto aa = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto bb = Kokkos::subview(m_b, k, Kokkos::ALL(), Kokkos::ALL()); + auto cc = Kokkos::subview(m_c, k, Kokkos::ALL(), Kokkos::ALL()); - SerialGemm::invoke(_alpha, aa, bb, _beta, - cc); + SerialGemm::invoke(m_alpha, aa, bb, + m_beta, cc); } inline void run() { - typedef typename ViewType::value_type value_type; + using value_type = typename ViewType::non_const_value_type; std::string name_region("KokkosBatched::Test::SerialGemm"); const std::string name_value_type = Test::value_type_name(); std::string name = name_region + name_value_type; Kokkos::Profiling::pushRegion(name.c_str()); - Kokkos::RangePolicy policy(0, _c.extent(0)); + Kokkos::RangePolicy policy(0, m_c.extent(0)); Kokkos::parallel_for(name.c_str(), policy, *this); Kokkos::Profiling::popRegion(); } }; -template +/// \brief Implementation details of batched trsm analytical test + +/// \brief Implementation details of batched gemm test +/// \param N [in] Batch size of matrices +/// \param matAdim1 [in] Number of rows of matrix A +/// \param matAdim2 [in] Number of columns of matrix A +/// \param matBdim1 [in] Number of rows of matrix B +/// \param matBdim2 [in] Number of columns of 
matrix B +/// \param matCdim1 [in] Number of rows of matrix C +/// \param matCdim2 [in] Number of columns of matrix C +template void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, const int matBdim1, const int matBdim2, const int matCdim1, const int matCdim2) { using execution_space = typename DeviceType::execution_space; using transA = typename ParamTagType::transA; using transB = typename ParamTagType::transB; - using value_type = typename ViewType::value_type; - using ats = Kokkos::ArithTraits; + using ats = Kokkos::ArithTraits; + using ViewType = Kokkos::View; /// randomized input testing views ScalarType alpha = ScalarType(1.5); ScalarType beta = ScalarType(3.0); - ViewType a_expected("a_expected", N, matAdim1, matAdim2), a_actual("a_actual", N, matAdim1, matAdim2), - b_expected("b_expected", N, matBdim1, matBdim2), b_actual("b_actual", N, matBdim1, matBdim2), - c_expected("c_expected", N, matCdim1, matCdim2), c_actual("c_actual", N, matCdim1, matCdim2); + ViewType A("A", N, matAdim1, matAdim2), B("B", N, matBdim1, matBdim2), C("C", N, matCdim1, matCdim2), + C_ref("C_ref", N, matCdim1, matCdim2); - Kokkos::Random_XorShift64_Pool random(13718); + Kokkos::Random_XorShift64_Pool rand_pool(13718); - Kokkos::fill_random(a_expected, random, value_type(1.0)); - Kokkos::fill_random(b_expected, random, value_type(1.0)); - Kokkos::fill_random(c_expected, random, value_type(1.0)); + ScalarType randStart, randEnd; + KokkosKernels::Impl::getRandomBounds(1.0, randStart, randEnd); + Kokkos::fill_random(A, rand_pool, randStart, randEnd); + Kokkos::fill_random(B, rand_pool, randStart, randEnd); + Kokkos::fill_random(C, rand_pool, randStart, randEnd); - Kokkos::fence(); - - Kokkos::deep_copy(a_actual, a_expected); - Kokkos::deep_copy(b_actual, b_expected); - Kokkos::deep_copy(c_actual, c_expected); + Kokkos::deep_copy(C_ref, C); Functor_BatchedVanillaGEMM vgemm; - vgemm.A_t = std::is_same::value; - vgemm.B_t = std::is_same::value; - vgemm.A_c = 
vgemm.B_c = false; - vgemm.A = a_expected; - vgemm.B = b_expected; - vgemm.C = c_expected; - vgemm.alpha = alpha; - vgemm.beta = beta; - vgemm.run(); // Compute c_expected - Functor_TestBatchedSerialGemm(alpha, a_actual, b_actual, - beta, c_actual) + vgemm.A_t = !std::is_same_v; + vgemm.B_t = !std::is_same_v; + vgemm.A_c = std::is_same_v; + vgemm.B_c = std::is_same_v; + vgemm.A = A; + vgemm.B = B; + vgemm.C = C_ref; + vgemm.alpha = alpha; + vgemm.beta = beta; + vgemm.run(); // Compute C_ref + + // Compute using gemm API + Functor_TestBatchedSerialGemm(alpha, A, B, beta, C) .run(); - typename ViewType::HostMirror c_expected_host = Kokkos::create_mirror_view(c_expected); - typename ViewType::HostMirror c_actual_host = Kokkos::create_mirror_view(c_actual); - - // Copy to host for comparison - Kokkos::deep_copy(c_expected_host, c_expected); - Kokkos::deep_copy(c_actual_host, c_actual); - - Kokkos::fence(); + auto h_C = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), C); + auto h_C_ref = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), C_ref); // check c_expected = c_actual // std::conditional<, float, @@ -130,19 +133,18 @@ void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, mag_type sum(1), diff(0); mag_type eps = ats::epsilon(); - - eps *= std::is_same::value || - std::is_same::value + eps *= std::is_same::value || + std::is_same::value ? 
4 : 1e3; - for (int k = 0; k < N; ++k) - for (int i = 0; i < matCdim1; ++i) + for (int k = 0; k < N; ++k) { + for (int i = 0; i < matCdim1; ++i) { for (int j = 0; j < matCdim2; ++j) { - sum += ats::abs(c_expected_host(k, i, j)); - diff += ats::abs(c_expected_host(k, i, j) - c_actual_host(k, i, j)); + EXPECT_NEAR_KK(h_C(k, i, j), h_C_ref(k, i, j), eps); } - EXPECT_NEAR_KK(diff / sum, 0, eps); + } + } } } // namespace Gemm } // namespace Test @@ -151,37 +153,35 @@ template ViewType; - Test::Gemm::impl_test_batched_gemm(0, 10, 10, 10, 10, - 10, 10); + using LayoutType = Kokkos::LayoutLeft; + Test::Gemm::impl_test_batched_gemm( + 0, 10, 10, 10, 10, 10, 10); for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); - Test::Gemm::impl_test_batched_gemm(1024, i, i, i, i, - i, i); + Test::Gemm::impl_test_batched_gemm( + 1024, i, i, i, i, i, i); } for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); int dimM = i; int dimN = 2 * i; int dimK = 3 * i; - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((std::is_same_v)&&( + std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimM, dimK, dimK, dimN, dimM, dimN); } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((std::is_same_v)&&( + !std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimM, dimK, dimN, dimK, dimM, dimN); } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((!std::is_same_v)&&( + std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimK, dimM, dimK, dimN, dimM, dimN); } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((!std::is_same_v)&&( + !std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimK, dimM, dimN, dimK, dimM, dimN); } } @@ -189,37 +189,35 @@ int test_batched_gemm() { #endif #if 
defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) { - typedef Kokkos::View ViewType; - Test::Gemm::impl_test_batched_gemm(0, 10, 10, 10, 10, - 10, 10); + using LayoutType = Kokkos::LayoutRight; + Test::Gemm::impl_test_batched_gemm( + 0, 10, 10, 10, 10, 10, 10); for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutRight, Blksize %d\n", i); - Test::Gemm::impl_test_batched_gemm(1024, i, i, i, i, - i, i); + Test::Gemm::impl_test_batched_gemm( + 1024, i, i, i, i, i, i); } for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); int dimM = i; int dimN = 2 * i; int dimK = 3 * i; - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((std::is_same_v)&&( + std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimM, dimK, dimK, dimN, dimM, dimN); } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((std::is_same_v)&&( + !std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimM, dimK, dimN, dimK, dimM, dimN); } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((!std::is_same_v)&&( + std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimK, dimM, dimK, dimN, dimM, dimN); } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((!std::is_same_v)&&( + !std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimK, dimM, dimN, dimK, dimM, dimN); } } diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp b/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp index f785965602..0d751869ad 100644 --- a/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp @@ -13,72 +13,227 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //@HEADER +#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) + +/// fcomplex, fcomplex + 
+TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_nt_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_t_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_c_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + 
Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_c_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_c_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); +} + +/// fcomplex, float +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_nt_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, 
batched_scalar_serial_gemm_c_t_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_c_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_c_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_c_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); +} + +#endif + #if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) /// dcomplex, dcomplex TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_dcomplex_dcomplex) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, Kokkos::complex, param_tag_type, algo_tag_type>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_dcomplex_dcomplex) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, Kokkos::complex, param_tag_type, algo_tag_type>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + 
Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_nt_dcomplex_dcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_dcomplex_dcomplex) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, Kokkos::complex, param_tag_type, algo_tag_type>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_dcomplex_dcomplex) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, Kokkos::complex, param_tag_type, algo_tag_type>(); -} -// TEST_F( TestCategory, batched_scalar_serial_gemm_ct_nt_dcomplex_dcomplex ) { -// typedef ::Test::Gemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_gemm,Kokkos::complex,param_tag_type,algo_tag_type>(); -// } -// TEST_F( TestCategory, batched_scalar_serial_gemm_nt_ct_dcomplex_dcomplex ) { -// typedef ::Test::Gemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_gemm,Kokkos::complex,param_tag_type,algo_tag_type>(); -// } + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_t_dcomplex_dcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + 
Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_c_dcomplex_dcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_c_dcomplex_dcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_c_dcomplex_dcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + Algo::Gemm::Unblocked>(); +} /// dcomplex, double - TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_dcomplex_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, double, param_tag_type, algo_tag_type>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_dcomplex_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, double, param_tag_type, algo_tag_type>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_nt_dcomplex_double) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, 
Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_dcomplex_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, double, param_tag_type, algo_tag_type>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_dcomplex_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, double, param_tag_type, algo_tag_type>(); -} -// TEST_F( TestCategory, batched_scalar_serial_gemm_ct_nt_dcomplex_double ) { -// typedef ::Test::Gemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_gemm,double,param_tag_type,algo_tag_type>(); -// } -// TEST_F( TestCategory, batched_scalar_serial_gemm_nt_ct_dcomplex_double ) { -// typedef ::Test::Gemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_gemm,double,param_tag_type,algo_tag_type>(); -// } + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_t_dcomplex_double) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_c_dcomplex_double) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_c_dcomplex_double) { + using param_tag_type = 
::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_c_dcomplex_double) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); +} #endif diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp b/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp index afe5744688..f2028775ad 100644 --- a/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp @@ -15,7 +15,7 @@ //@HEADER #if defined(KOKKOS_BHALF_T_IS_FLOAT) TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_bhalf_bhalf) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); @@ -23,7 +23,7 @@ TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_bhalf_bhalf) { Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_bhalf_bhalf) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); @@ -31,7 +31,7 @@ TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_bhalf_bhalf) { Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_bhalf_bhalf) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); @@ -39,7 +39,7 @@ TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_bhalf_bhalf) { Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_bhalf_bhalf) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); @@ -50,28 +50,28 @@ TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_bhalf_bhalf) { #if 
defined(KOKKOS_HALF_T_IS_FLOAT) TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_half_half) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_half_half) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_half_half) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_half_half) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); test_batched_gemm param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_float_float) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_float_float) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_float_float) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } #endif #if defined(KOKKOSKERNELS_INST_DOUBLE) TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_double_double) 
{ - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } + TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_double_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_double_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_double_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } #endif diff --git a/blas/impl/KokkosBlas2_serial_gemv_inner_multiple_dot.hpp b/blas/impl/KokkosBlas2_serial_gemv_inner_multiple_dot.hpp index 1a41ff4db3..2b9789cc02 100644 --- a/blas/impl/KokkosBlas2_serial_gemv_inner_multiple_dot.hpp +++ b/blas/impl/KokkosBlas2_serial_gemv_inner_multiple_dot.hpp @@ -16,26 +16,13 @@ #ifndef KOKKOSBLAS_INNER_MULTIPLE_DOT_PRODUCT_SERIAL_IMPL_HPP #define KOKKOSBLAS_INNER_MULTIPLE_DOT_PRODUCT_SERIAL_IMPL_HPP +#include "KokkosBlas_util.hpp" + /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace KokkosBlas { namespace Impl { -struct OpID { - template - KOKKOS_INLINE_FUNCTION ValueType operator()(ValueType v) const { - return v; - } -}; - -struct OpConj { - template - KOKKOS_INLINE_FUNCTION ValueType operator()(ValueType v) const { - using KAT = Kokkos::ArithTraits; - return KAT::conj(v); - } -}; - template struct InnerMultipleDotProduct { const int _as0, _as1, _xs0, _ys0; diff --git 
a/blas/impl/KokkosBlas_util.hpp b/blas/impl/KokkosBlas_util.hpp index c0777ac9ea..35661caef7 100644 --- a/blas/impl/KokkosBlas_util.hpp +++ b/blas/impl/KokkosBlas_util.hpp @@ -20,6 +20,22 @@ #include "Kokkos_ArithTraits.hpp" namespace KokkosBlas { +namespace Impl { +struct OpID { + template + KOKKOS_INLINE_FUNCTION ValueType operator()(ValueType v) const { + return v; + } +}; + +struct OpConj { + template + KOKKOS_INLINE_FUNCTION ValueType operator()(ValueType v) const { + using KAT = Kokkos::ArithTraits; + return KAT::conj(v); + } +}; +} // namespace Impl //////// Tags for BLAS //////// diff --git a/common/src/KokkosKernels_BlockUtils.hpp b/common/src/KokkosKernels_BlockUtils.hpp index 64309372ac..26a0baac67 100644 --- a/common/src/KokkosKernels_BlockUtils.hpp +++ b/common/src/KokkosKernels_BlockUtils.hpp @@ -52,13 +52,13 @@ KOKKOS_INLINE_FUNCTION void kk_block_add(const size_type block_dim, value_type * // Note: block is assumed to be row-major, dense matrix (no extra padding) // Note: set clear=true to set C = 0 before increment template > + typename DGEMM = KokkosBatched::Impl::SerialGemmInternal> KOKKOS_INLINE_FUNCTION void kk_block_dgemm(const size_type block_dim, value_type *dst, const value_type *valA, const value_type *valB, const bool clear = false) { const auto ZERO = static_cast(0); const auto ONE = static_cast(1); - DGEMM::invoke(block_dim, block_dim, block_dim, ONE, valA, block_dim, 1, valB, block_dim, 1, clear ? ZERO : ONE, dst, - block_dim, 1); + DGEMM::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), block_dim, block_dim, block_dim, ONE, valA, + block_dim, 1, valB, block_dim, 1, clear ? 
ZERO : ONE, dst, block_dim, 1); } // dgemm: C = A * B diff --git a/sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp b/sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp index 98501a5814..b3c870e8cc 100644 --- a/sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp +++ b/sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp @@ -64,10 +64,10 @@ void bspgemm_debug_numeric(KernelHandle* /* handle */, typename KernelHandle::nn typename cscalar_nnz_view_t_::HostMirror h_valc = Kokkos::create_mirror_view(valuesC); Kokkos::fence(); - typedef typename KernelHandle::nnz_lno_t lno_t; - typedef typename KernelHandle::size_type size_type; - typedef typename KernelHandle::nnz_scalar_t scalar_t; - typedef KokkosBatched::SerialGemmInternal GEMM; + using lno_t = typename KernelHandle::nnz_lno_t; + using size_type = typename KernelHandle::size_type; + using scalar_t = typename KernelHandle::nnz_scalar_t; + using GEMM = KokkosBatched::Impl::SerialGemmInternal; const auto block_size = block_dim * block_dim; const auto ZERO = static_cast(0); @@ -106,8 +106,8 @@ void bspgemm_debug_numeric(KernelHandle* /* handle */, typename KernelHandle::nn } // accumulator(b_col) += a_val * b_val auto acc = get_block(accumulator, b_col, block_size); - GEMM::invoke(block_dim, block_dim, block_dim, ONE, a_val, block_dim, 1, b_val, block_dim, 1, ONE, acc.data(), - block_dim, 1); + GEMM::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), block_dim, block_dim, block_dim, ONE, a_val, + block_dim, 1, b_val, block_dim, 1, ONE, acc.data(), block_dim, 1); } } diff --git a/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp b/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp index d9702af900..3fae741c94 100644 --- a/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp +++ b/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp @@ -19,6 +19,7 @@ #include "KokkosKernels_Error.hpp" #include "KokkosKernels_ExecSpaceUtils.hpp" +#include "KokkosBlas_util.hpp" #if defined(KOKKOS_ENABLE_CUDA) && (defined(KOKKOS_ARCH_VOLTA) || 
defined(KOKKOS_ARCH_AMPERE)) @@ -1028,10 +1029,12 @@ struct BSR_GEMM_Functor { for (ordinal_type ic = 0; ic < count; ++ic) { const auto Aview = row.block(ic); const auto xstart = row.block_colidx(ic) * block_dim; - KokkosBatched::SerialGemmInternal::invoke( - static_cast(block_dim), static_cast(num_rhs), - static_cast(block_dim), alpha, Aview.data(), Aview.stride_0(), Aview.stride_1(), - &m_x(xstart, 0), m_x.stride_0(), ldx, beta1, &m_y(ystart, 0), m_y.stride_0(), ldy); + KokkosBatched::Impl::SerialGemmInternal::invoke< + KokkosBlas::Impl::OpID, KokkosBlas::Impl::OpID, value_type, value_type>( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), static_cast(block_dim), + static_cast(num_rhs), static_cast(block_dim), alpha, Aview.data(), + Aview.stride_0(), Aview.stride_1(), &m_x(xstart, 0), m_x.stride_0(), ldx, beta1, &m_y(ystart, 0), + m_y.stride_0(), ldy); } } } From 67a9246adb8e5aa1a86362754e49da4f930600dc Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 17 Dec 2024 10:58:49 +0900 Subject: [PATCH 2/7] improve checks in serial Gemm Signed-off-by: Yuuichi Asahi --- .../impl/KokkosBatched_Gemm_Serial_Impl.hpp | 29 ++++++++++++------- .../KokkosBatched_Gemm_Serial_Internal.hpp | 14 +++++++++ .../unit_test/Test_Batched_SerialGemm.hpp | 17 +++++------ 3 files changed, 40 insertions(+), 20 deletions(-) diff --git a/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp index c539ffe4a2..266cc5bb33 100644 --- a/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp @@ -29,22 +29,29 @@ KOKKOS_INLINE_FUNCTION static int checkGemmInput([[maybe_unused]] const AViewTyp static_assert(Kokkos::is_view_v, "KokkosBatched::gemm: AViewType is not a Kokkos::View."); static_assert(Kokkos::is_view_v, "KokkosBatched::gemm: BViewType is not a Kokkos::View."); static_assert(Kokkos::is_view_v, "KokkosBatched::gemm: CViewType is not a Kokkos::View."); - /* - 
static_assert(AViewType::rank == 1 || AViewType::rank == 2, - "KokkosBatched::gemm: AViewType must have rank 1 or 2."); - static_assert(BViewType::rank == 1 || BViewType::rank == 2, - "KokkosBatched::gemm: BViewType must have rank 1 or 2."); - static_assert(CViewType::rank == 1 || CViewType::rank == 2, - "KokkosBatched::gemm: CViewType must have rank 1 or 2."); - */ + + static_assert(AViewType::rank <= 2, "KokkosBatched::gemm: AViewType must have rank 0, 1 or 2."); + static_assert(BViewType::rank <= 2, "KokkosBatched::gemm: BViewType must have rank 0, 1 or 2."); + static_assert(CViewType::rank <= 2, "KokkosBatched::gemm: CViewType must have rank 0, 1 or 2."); #if (KOKKOSKERNELS_DEBUG_LEVEL > 0) const int m = C.extent(0), n = C.extent(1); - const int lda = A.extent(0), k = A.extent(1); + const int lda = A.extent(0); const int ldb = B.extent(0); - const int nrowa = std::is_same_v ? m : k; - const int nrowb = std::is_same_v ? k : n; + const int ka = std::is_same_v ? A.extent(1) : A.extent(0); + const int kb = std::is_same_v ? B.extent(0) : B.extent(1); + + if (ka != kb) { + Kokkos::printf( + "KokkosBatched::gemm: Dimensions of A and B do not match: A: %d x %d, " + "B: %d x %d\n", + A.extent(0), A.extent(1), B.extent(0), B.extent(1)); + return 1; + } + + const int nrowa = std::is_same_v ? m : ka; + const int nrowb = std::is_same_v ? 
kb : n; if (lda < Kokkos::max(1, nrowa)) { Kokkos::printf( diff --git a/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp index a66506e633..09ce343fa6 100644 --- a/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp @@ -140,6 +140,20 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( } } // namespace Impl + +template +struct [[deprecated("Use KokkosBatched::SerialGemm instead")]] SerialGemmInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const int m, const int n, const int k, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, + const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, + const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { + return Impl::SerialGemmInternal::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), m, n, k, alpha, + A, as0, as1, B, bs0, bs1, beta, C, cs0, cs1); + } +}; + } // namespace KokkosBatched #endif diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm.hpp b/batched/dense/unit_test/Test_Batched_SerialGemm.hpp index 776d83afe1..8ce5b8fd20 100644 --- a/batched/dense/unit_test/Test_Batched_SerialGemm.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialGemm.hpp @@ -127,24 +127,23 @@ void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, auto h_C = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), C); auto h_C_ref = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), C_ref); - // check c_expected = c_actual - // std::conditional<, float, + // check C = C_ref using mag_type = typename ats::mag_type; mag_type sum(1), diff(0); mag_type eps = ats::epsilon(); - eps *= std::is_same::value || - std::is_same::value + eps *= std::is_same_v || + std::is_same_v ? 
4 : 1e3; - for (int k = 0; k < N; ++k) { - for (int i = 0; i < matCdim1; ++i) { + for (int k = 0; k < N; ++k) + for (int i = 0; i < matCdim1; ++i) for (int j = 0; j < matCdim2; ++j) { - EXPECT_NEAR_KK(h_C(k, i, j), h_C_ref(k, i, j), eps); + sum += ats::abs(h_C_ref(k, i, j)); + diff += ats::abs(h_C_ref(k, i, j) - h_C(k, i, j)); } - } - } + EXPECT_NEAR_KK(diff / sum, 0, eps); } } // namespace Gemm } // namespace Test From d4d754b13923cc4d450341da49fed007bf5324d7 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Thu, 19 Dec 2024 15:23:59 +0900 Subject: [PATCH 3/7] improve selective interface of batched gemm Signed-off-by: Yuuichi Asahi --- batched/dense/src/KokkosBatched_Gemm_Decl.hpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/batched/dense/src/KokkosBatched_Gemm_Decl.hpp b/batched/dense/src/KokkosBatched_Gemm_Decl.hpp index eabd5c42c2..1f3ba6095d 100644 --- a/batched/dense/src/KokkosBatched_Gemm_Decl.hpp +++ b/batched/dense/src/KokkosBatched_Gemm_Decl.hpp @@ -61,10 +61,12 @@ struct Gemm { KOKKOS_FORCEINLINE_FUNCTION static int invoke(const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { int r_val = 0; - if (std::is_same::value) { + if constexpr (std::is_same_v) { r_val = SerialGemm::invoke(alpha, A, B, beta, C); - } else if (std::is_same::value) { + } else if constexpr (std::is_same_v) { r_val = TeamGemm::invoke(member, alpha, A, B, beta, C); + } else if constexpr (std::is_same_v) { + r_val = TeamVectorGemm::invoke(member, alpha, A, B, beta, C); } return r_val; } From b82c9f6b2fd3813c900136a772e4394612050d67 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Thu, 19 Dec 2024 16:08:25 +0900 Subject: [PATCH 4/7] check info in serial gemm testing Signed-off-by: Yuuichi Asahi --- .../unit_test/Test_Batched_SerialGemm.hpp | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm.hpp 
b/batched/dense/unit_test/Test_Batched_SerialGemm.hpp index 8ce5b8fd20..32f58675b5 100644 --- a/batched/dense/unit_test/Test_Batched_SerialGemm.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialGemm.hpp @@ -14,13 +14,11 @@ // //@HEADER /// \author Kyungjoo Kim (kyukim@sandia.gov) +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) #include "gtest/gtest.h" #include "Kokkos_Core.hpp" #include "Kokkos_Random.hpp" - -// #include "KokkosBatched_Vector.hpp" - #include "KokkosBatched_Gemm_Decl.hpp" #include "KokkosBatched_Gemm_Serial_Impl.hpp" @@ -50,29 +48,29 @@ struct Functor_TestBatchedSerialGemm { : m_a(a), m_b(b), m_c(c), m_alpha(alpha), m_beta(beta) {} KOKKOS_INLINE_FUNCTION - void operator()(const ParamTagType &, const int k) const { + void operator()(const ParamTagType &, const int k, int &info) const { auto aa = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL()); auto bb = Kokkos::subview(m_b, k, Kokkos::ALL(), Kokkos::ALL()); auto cc = Kokkos::subview(m_c, k, Kokkos::ALL(), Kokkos::ALL()); - SerialGemm::invoke(m_alpha, aa, bb, - m_beta, cc); + info += SerialGemm::invoke( + m_alpha, aa, bb, m_beta, cc); } - inline void run() { + inline int run() { using value_type = typename ViewType::non_const_value_type; std::string name_region("KokkosBatched::Test::SerialGemm"); const std::string name_value_type = Test::value_type_name(); std::string name = name_region + name_value_type; + int info_sum = 0; Kokkos::Profiling::pushRegion(name.c_str()); Kokkos::RangePolicy policy(0, m_c.extent(0)); - Kokkos::parallel_for(name.c_str(), policy, *this); + Kokkos::parallel_reduce(name.c_str(), policy, *this, info_sum); Kokkos::Profiling::popRegion(); + return info_sum; } }; -/// \brief Implementation details of batched trsm analytical test - /// \brief Implementation details of batched gemm test /// \param N [in] Batch size of matrices /// \param matAdim1 [in] Number of rows of matrix A @@ -121,8 +119,10 @@ void impl_test_batched_gemm(const int N, const int matAdim1, const int 
matAdim2, vgemm.run(); // Compute C_ref // Compute using gemm API - Functor_TestBatchedSerialGemm(alpha, A, B, beta, C) - .run(); + auto info = + Functor_TestBatchedSerialGemm(alpha, A, B, beta, C) + .run(); + EXPECT_EQ(info, 0); auto h_C = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), C); auto h_C_ref = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), C_ref); From db254c2b8fa20e80c8021ab152c8f3ebba036c0c Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Thu, 19 Dec 2024 16:12:29 +0900 Subject: [PATCH 5/7] fix: op type of serial invoke Signed-off-by: Yuuichi Asahi --- batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp b/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp index 70c1ce3a03..ff4882b548 100644 --- a/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp @@ -122,7 +122,7 @@ KOKKOS_INLINE_FUNCTION int TeamGemmInternal::invoke( i = ij / nq * mb; j = ij % nq * nb; } - inner.serial_invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpConj(), alpha, AA + i * as0, BB + j * bs1, + inner.serial_invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), alpha, AA + i * as0, BB + j * bs1, (i + mb) > ib ? mp : mb, (j + nb) > jb ? 
np : nb, pb, CC + i * cs0 + j * cs1); }); }; From 7e27f6265f84065affd502d97125600b432d5f59 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Thu, 19 Dec 2024 16:20:11 +0900 Subject: [PATCH 6/7] format Signed-off-by: Yuuichi Asahi --- ...okkosBatched_InnerGemmFixC_Serial_Impl.hpp | 76 +++++++++---------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp index 328fdbcf3d..b31bc895e2 100644 --- a/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp @@ -362,14 +362,14 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 5>::serial_invoke(OpA opA, OpB opB, #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = opA(A[i0 + p * _as1]); - b_p0 = opB(B[p * _bs0 + j0]); - a_1p = opA(A[i1 + p * _as1]); - b_p1 = opB(B[p * _bs0 + j1]); - a_2p = opA(A[i2 + p * _as1]); - b_p2 = opB(B[p * _bs0 + j2]); - a_3p = opA(A[i3 + p * _as1]); - b_p3 = opB(B[p * _bs0 + j3]); + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + b_p3 = opB(B[p * _bs0 + j3]); /**/ b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; @@ -438,12 +438,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 5>::serial_invoke(OpA opA, OpB opB, #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = opA(A[i0 + p * _as1]); - b_p0 = opB(B[p * _bs0 + j0]); - a_1p = opA(A[i1 + p * _as1]); - b_p1 = opB(B[p * _bs0 + j1]); - a_2p = opA(A[i2 + p * _as1]); - b_p2 = opB(B[p * _bs0 + j2]); + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); /**/ b_p3 = opB(B[p * _bs0 + j3]); /**/ b_p4 = opB(B[p * _bs0 + 
j4]); @@ -503,10 +503,10 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 5>::serial_invoke(OpA opA, OpB opB, #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = opA(A[i0 + p * _as1]); - b_p0 = opB(B[p * _bs0 + j0]); - a_1p = opA(A[i1 + p * _as1]); - b_p1 = opB(B[p * _bs0 + j1]); + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); /**/ b_p2 = opB(B[p * _bs0 + j2]); /**/ b_p3 = opB(B[p * _bs0 + j3]); /**/ b_p4 = opB(B[p * _bs0 + j4]); @@ -557,8 +557,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 5>::serial_invoke(OpA opA, OpB opB, #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = opA(A[i0 + p * _as1]); - b_p0 = opB(B[p * _bs0 + j0]); + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); /**/ b_p1 = opB(B[p * _bs0 + j1]); /**/ b_p2 = opB(B[p * _bs0 + j2]); /**/ b_p3 = opB(B[p * _bs0 + j3]); @@ -808,12 +808,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 4>::serial_invoke(OpA opA, OpB opB, #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = opA(A[i0 + p * _as1]); - b_p0 = opB(B[p * _bs0 + j0]); - a_1p = opA(A[i1 + p * _as1]); - b_p1 = opB(B[p * _bs0 + j1]); - a_2p = opA(A[i2 + p * _as1]); - b_p2 = opB(B[p * _bs0 + j2]); + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); /**/ b_p3 = opB(B[p * _bs0 + j3]); c_00 += a_0p * b_p0; @@ -865,10 +865,10 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 4>::serial_invoke(OpA opA, OpB opB, #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = opA(A[i0 + p * _as1]); - b_p0 = opB(B[p * _bs0 + j0]); - a_1p = opA(A[i1 + p * _as1]); - b_p1 = opB(B[p * _bs0 + j1]); + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); /**/ b_p2 = opB(B[p * _bs0 + j2]); /**/ b_p3 = opB(B[p * _bs0 + j3]); @@ -913,8 +913,8 @@ 
KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 4>::serial_invoke(OpA opA, OpB opB, #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = opA(A[i0 + p * _as1]); - b_p0 = opB(B[p * _bs0 + j0]); + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); /**/ b_p1 = opB(B[p * _bs0 + j1]); /**/ b_p2 = opB(B[p * _bs0 + j2]); /**/ b_p3 = opB(B[p * _bs0 + j3]); @@ -1078,10 +1078,10 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 3>::serial_invoke(OpA opA, OpB opB, #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = opA(A[i0 + p * _as1]); - b_p0 = opB(B[p * _bs0 + j0]); - a_1p = opA(A[i1 + p * _as1]); - b_p1 = opB(B[p * _bs0 + j1]); + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); /**/ b_p2 = opB(B[p * _bs0 + j2]); c_00 += a_0p * b_p0; @@ -1119,8 +1119,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 3>::serial_invoke(OpA opA, OpB opB, #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = opA(A[i0 + p * _as1]); - b_p0 = opB(B[p * _bs0 + j0]); + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); /**/ b_p1 = opB(B[p * _bs0 + j1]); /**/ b_p2 = opB(B[p * _bs0 + j2]); From 198a7af13095f14bce93b61806c86e42d80bf330 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 7 Jan 2025 16:38:43 +0900 Subject: [PATCH 7/7] remove the global namespace Signed-off-by: Yuuichi Asahi --- .../unit_test/Test_Batched_SerialGemm.hpp | 15 +- .../Test_Batched_SerialGemm_Complex.hpp | 238 ++++++++++-------- .../Test_Batched_SerialGemm_Real.hpp | 100 ++++---- 3 files changed, 189 insertions(+), 164 deletions(-) diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm.hpp b/batched/dense/unit_test/Test_Batched_SerialGemm.hpp index 32f58675b5..adc9a51aac 100644 --- a/batched/dense/unit_test/Test_Batched_SerialGemm.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialGemm.hpp @@ -25,8 +25,6 @@ #include "KokkosKernels_TestUtils.hpp" #include "KokkosKernels_TestVanilla.hpp" -using 
namespace KokkosBatched; - namespace Test { namespace Gemm { @@ -53,8 +51,9 @@ struct Functor_TestBatchedSerialGemm { auto bb = Kokkos::subview(m_b, k, Kokkos::ALL(), Kokkos::ALL()); auto cc = Kokkos::subview(m_c, k, Kokkos::ALL(), Kokkos::ALL()); - info += SerialGemm::invoke( - m_alpha, aa, bb, m_beta, cc); + info += + KokkosBatched::SerialGemm::invoke( + m_alpha, aa, bb, m_beta, cc); } inline int run() { @@ -107,10 +106,10 @@ void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, Kokkos::deep_copy(C_ref, C); Functor_BatchedVanillaGEMM vgemm; - vgemm.A_t = !std::is_same_v; - vgemm.B_t = !std::is_same_v; - vgemm.A_c = std::is_same_v; - vgemm.B_c = std::is_same_v; + vgemm.A_t = !std::is_same_v; + vgemm.B_t = !std::is_same_v; + vgemm.A_c = std::is_same_v; + vgemm.B_c = std::is_same_v; vgemm.A = A; vgemm.B = B; vgemm.C = C_ref; diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp b/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp index 0d751869ad..ab97732238 100644 --- a/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp @@ -18,105 +18,116 @@ /// fcomplex, fcomplex TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_fcomplex_fcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_fcomplex_fcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + 
KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_c_nt_fcomplex_fcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_fcomplex_fcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_fcomplex_fcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_c_t_fcomplex_fcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + 
KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_c_fcomplex_fcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_c_fcomplex_fcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_c_c_fcomplex_fcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, Kokkos::complex, param_tag_type, Algo::Gemm::Blocked>(); + using param_tag_type = + ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } /// fcomplex, float TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_fcomplex_float) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, 
batched_scalar_serial_gemm_t_nt_fcomplex_float) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_c_nt_fcomplex_float) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_fcomplex_float) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_fcomplex_float) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_c_t_fcomplex_float) { - using param_tag_type = 
::Test::Gemm::ParamTag; - test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_c_fcomplex_float) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_c_fcomplex_float) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_c_c_fcomplex_float) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, float, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, float, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = + ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); } #endif @@ -126,114 +137,125 @@ TEST_F(TestCategory, batched_scalar_serial_gemm_c_c_fcomplex_float) { /// dcomplex, dcomplex TEST_F(TestCategory, 
batched_scalar_serial_gemm_nt_nt_dcomplex_dcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Blocked>(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_dcomplex_dcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Blocked>(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_c_nt_dcomplex_dcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Blocked>(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_dcomplex_dcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Blocked>(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_dcomplex_dcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Blocked>(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - 
Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_c_t_dcomplex_dcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Blocked>(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_c_dcomplex_dcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Blocked>(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_c_dcomplex_dcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Blocked>(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_c_c_dcomplex_dcomplex) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = + ::Test::Gemm::ParamTag; test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Blocked>(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm, Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); + KokkosBatched::Algo::Gemm::Unblocked>(); } /// dcomplex, double TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_dcomplex_double) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, double, param_tag_type, 
Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_dcomplex_double) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_c_nt_dcomplex_double) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_dcomplex_double) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_dcomplex_double) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = 
::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_c_t_dcomplex_double) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_c_dcomplex_double) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_c_dcomplex_double) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_c_c_dcomplex_double) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm, double, param_tag_type, Algo::Gemm::Blocked>(); - test_batched_gemm, double, param_tag_type, Algo::Gemm::Unblocked>(); + using param_tag_type = + ::Test::Gemm::ParamTag; + test_batched_gemm, double, 
param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } #endif diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp b/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp index f2028775ad..0192b61b0f 100644 --- a/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp @@ -15,113 +15,117 @@ //@HEADER #if defined(KOKKOS_BHALF_T_IS_FLOAT) TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_bhalf_bhalf) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_bhalf_bhalf) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_bhalf_bhalf) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_bhalf_bhalf) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Unblocked>(); } #endif // KOKKOS_BHALF_T_IS_FLOAT #if defined(KOKKOS_HALF_T_IS_FLOAT) TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_half_half) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); test_batched_gemm(); + 
KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_half_half) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_half_half) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_half_half) { - using param_tag_type = ::Test::Gemm::ParamTag; + using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm(); } #endif // KOKKOS_HALF_T_IS_FLOAT #if defined(KOKKOSKERNELS_INST_FLOAT) TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_float_float) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_float_float) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_float_float) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_float_float) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } 
#endif #if defined(KOKKOSKERNELS_INST_DOUBLE) TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_double_double) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_double_double) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_double_double) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_double_double) { - using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } #endif