Skip to content

Commit

Permalink
Merge pull request #2039 from eeprude/axpby_bug_fix
Browse files Browse the repository at this point in the history
Axpby bug fix (issue # 2015)
  • Loading branch information
ndellingwood authored Nov 24, 2023
2 parents 5df5171 + baab6f5 commit a80eb91
Show file tree
Hide file tree
Showing 4 changed files with 687 additions and 133 deletions.
29 changes: 25 additions & 4 deletions blas/impl/KokkosBlas1_axpby_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,13 @@ struct Axpby_Functor {
} else if constexpr (scalar_y == 1) {
// Nothing to do: m_y(i) = m_y(i);
} else if constexpr (scalar_y == 2) {
m_y(i) = m_b(0) * m_y(i);
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
m_y(i) =
Kokkos::ArithTraits<typename YV::non_const_value_type>::zero();
} else {
m_y(i) = m_b(0) * m_y(i);
}
}
}
// **************************************************************
Expand All @@ -137,7 +143,12 @@ struct Axpby_Functor {
} else if constexpr (scalar_y == 1) {
m_y(i) = -m_x(i) + m_y(i);
} else if constexpr (scalar_y == 2) {
m_y(i) = -m_x(i) + m_b(0) * m_y(i);
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
m_y(i) = -m_x(i);
} else {
m_y(i) = -m_x(i) + m_b(0) * m_y(i);
}
}
}
// **************************************************************
Expand All @@ -151,7 +162,12 @@ struct Axpby_Functor {
} else if constexpr (scalar_y == 1) {
m_y(i) = m_x(i) + m_y(i);
} else if constexpr (scalar_y == 2) {
m_y(i) = m_x(i) + m_b(0) * m_y(i);
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
m_y(i) = m_x(i);
} else {
m_y(i) = m_x(i) + m_b(0) * m_y(i);
}
}
}
// **************************************************************
Expand All @@ -165,7 +181,12 @@ struct Axpby_Functor {
} else if constexpr (scalar_y == 1) {
m_y(i) = m_a(0) * m_x(i) + m_y(i);
} else if constexpr (scalar_y == 2) {
m_y(i) = m_a(0) * m_x(i) + m_b(0) * m_y(i);
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
m_y(i) = m_a(0) * m_x(i);
} else {
m_y(i) = m_a(0) * m_x(i) + m_b(0) * m_y(i);
}
}
}
}
Expand Down
157 changes: 137 additions & 20 deletions blas/impl/KokkosBlas1_axpby_mv_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,28 @@ struct Axpby_MV_Functor {
// Nothing to do: Y(i,j) := Y(i,j)
} else if constexpr (scalar_y == 2) {
if (m_b.extent(0) == 1) {
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_b(0) * m_y(i, k);
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = Kokkos::ArithTraits<
typename YMV::non_const_value_type>::zero();
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
Expand Down Expand Up @@ -181,14 +195,27 @@ struct Axpby_MV_Functor {
}
} else if constexpr (scalar_y == 2) {
if (m_b.extent(0) == 1) {
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = -m_x(i, k) + m_b(0) * m_y(i, k);
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = -m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = -m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
Expand Down Expand Up @@ -239,14 +266,27 @@ struct Axpby_MV_Functor {
}
} else if constexpr (scalar_y == 2) {
if (m_b.extent(0) == 1) {
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_x(i, k) + m_b(0) * m_y(i, k);
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
Expand Down Expand Up @@ -334,14 +374,27 @@ struct Axpby_MV_Functor {
} else if constexpr (scalar_y == 2) {
if (m_a.extent(0) == 1) {
if (m_b.extent(0) == 1) {
if (m_b(0) == Kokkos::ArithTraits<
typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_a(0) * m_x(i, k) + m_b(0) * m_y(i, k);
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_a(0) * m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_a(0) * m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
Expand All @@ -356,14 +409,27 @@ struct Axpby_MV_Functor {
}
} else {
if (m_b.extent(0) == 1) {
if (m_b(0) == Kokkos::ArithTraits<
typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_a(k) * m_x(i, k) + m_b(0) * m_y(i, k);
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_a(k) * m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_a(k) * m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
Expand Down Expand Up @@ -715,11 +781,22 @@ struct Axpby_MV_Unroll_Functor {
// Nothing to do: Y(i,j) := Y(i,j)
} else if constexpr (scalar_y == 2) {
if (m_b.extent(0) == 1) {
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_b(0) * m_y(i, k);
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = Kokkos::ArithTraits<
typename YMV::non_const_value_type>::zero();
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
Expand Down Expand Up @@ -758,11 +835,21 @@ struct Axpby_MV_Unroll_Functor {
}
} else if constexpr (scalar_y == 2) {
if (m_b.extent(0) == 1) {
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = -m_x(i, k) + m_b(0) * m_y(i, k);
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = -m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = -m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
Expand Down Expand Up @@ -801,11 +888,21 @@ struct Axpby_MV_Unroll_Functor {
}
} else if constexpr (scalar_y == 2) {
if (m_b.extent(0) == 1) {
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_x(i, k) + m_b(0) * m_y(i, k);
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
Expand Down Expand Up @@ -872,11 +969,21 @@ struct Axpby_MV_Unroll_Functor {
} else if constexpr (scalar_y == 2) {
if (m_a.extent(0) == 1) {
if (m_b.extent(0) == 1) {
if (m_b(0) == Kokkos::ArithTraits<
typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_a(0) * m_x(i, k) + m_b(0) * m_y(i, k);
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_a(0) * m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_a(0) * m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
Expand All @@ -888,11 +995,21 @@ struct Axpby_MV_Unroll_Functor {
}
} else {
if (m_b.extent(0) == 1) {
if (m_b(0) == Kokkos::ArithTraits<
typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_a(k) * m_x(i, k) + m_b(0) * m_y(i, k);
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_a(k) * m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_a(k) * m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
Expand Down
Loading

0 comments on commit a80eb91

Please sign in to comment.