From 47c4acb7c6c3a8dd529c00f1dcd99ea17b5ededb Mon Sep 17 00:00:00 2001 From: Evan Harvey Date: Mon, 22 Jun 2020 11:49:27 -0700 Subject: [PATCH] Intro-Full/Exercises/kokkoskernels: Move TeamGemm invocation to functor --- .../TeamGemm/Solution/teamgemm.cpp | 92 +++++++++++++------ 1 file changed, 65 insertions(+), 27 deletions(-) diff --git a/Intro-Full/Exercises/kokkoskernels/TeamGemm/Solution/teamgemm.cpp b/Intro-Full/Exercises/kokkoskernels/TeamGemm/Solution/teamgemm.cpp index 4167e135..65ab64bd 100755 --- a/Intro-Full/Exercises/kokkoskernels/TeamGemm/Solution/teamgemm.cpp +++ b/Intro-Full/Exercises/kokkoskernels/TeamGemm/Solution/teamgemm.cpp @@ -70,6 +70,63 @@ void checkDims(dims_t A, dims_t B) } } +template +struct functor_TeamGemm { + ScalarType alpha; + ViewType A; + ViewType B; + ScalarType beta; + ViewType C; + ViewTypeFilters filters; + const int team_size; + const int vector_size; + + functor_TeamGemm(ScalarType alpha_, + ViewType A_, ViewType B_, + ScalarType beta_, ViewType C_, + ViewTypeFilters filters_, + int team_size_, int vector_size_) : alpha(alpha_), + A(A_), + B(B_), + beta(beta_), + C(C_), + filters(filters_), + team_size(team_size_), + vector_size(vector_size_) + {} + + KOKKOS_INLINE_FUNCTION + void operator()(const TeamMemberType &member) const { + const int idx = member.league_rank(); + // Fetch 2D sub-matrices + auto a = Kokkos::subview(A, idx, Kokkos::ALL(), Kokkos::ALL()); + auto b = Kokkos::subview(B, idx, Kokkos::ALL(), Kokkos::ALL()); + auto c = Kokkos::subview(C, idx, Kokkos::ALL(), Kokkos::ALL()); + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, team_size), [&](const int &k0) { + // Fetch 1D column vectors + auto b_col_vec = Kokkos::subview(a, Kokkos::ALL(), k0); + auto c_col_vec = Kokkos::subview(c, Kokkos::ALL(), k0); + auto filter = Kokkos::subview(filters, Kokkos::ALL(), k0); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, vector_size), [&](const int &k1) { + // Filter each b column vector + b_col_vec(k1) *= filter(k1); + }); + // Calculate c_col_vec = beta*c_col_vec + alpha*a*b_col_vec + KokkosBatched::TeamGemm + ::invoke(member, alpha, a, b_col_vec, beta, c_col_vec); + }); + } +}; + + int main(int argc, char* argv[]) { dims_t A_dims, B_dims, C_dims; @@ -134,7 +191,7 @@ int main(int argc, char* argv[]) using LayoutType = Kokkos::LayoutRight; using DeviceType = Kokkos::Cuda; using ViewType = Kokkos::View; - + using FilterType = Kokkos::View; // Timer products struct timeval begin, end; @@ -147,7 +204,7 @@ int main(int argc, char* argv[]) ViewType A("A", N, A_dims.m, A_dims.n); ViewType B("B", N, B_dims.m, B_dims.n); ViewType C("C", N, C_dims.m, C_dims.n); - Kokkos::View filters("filters", B_dims.m, B_dims.n); + FilterType filters("filters", B_dims.m, B_dims.n); // Populate A, B, and C matrices with random numbers using ExecutionSpaceType = DeviceType::execution_space; @@ -163,37 +220,18 @@ int main(int argc, char* argv[]) // Invoke TeamGemm from Vector Loop const int num_leagues = N; /// N teams are formed - const int team_size = C_dims.n; /// Each team consists of C_dims.n kokkos threads - const int vector_size = C_dims.m; /// team_size * vector_size concurrent threads are associated within a team + int team_size = C_dims.n; /// Each team consists of C_dims.n kokkos threads + int vector_size = C_dims.m; /// team_size * vector_size concurrent threads are associated within a team using TeamMemberType = Kokkos::TeamPolicy::member_type; using ATransType = KokkosBatched::Trans::NoTranspose; using BTransType = KokkosBatched::Trans::NoTranspose; + using FunctorType = functor_TeamGemm; + + FunctorType functor(alpha, A, B, beta, C, filters, team_size, vector_size); Kokkos::TeamPolicy policy(num_leagues, team_size, vector_size); - Kokkos::parallel_for(policy, KOKKOS_LAMBDA(const TeamMemberType &member) { - const int idx = member.league_rank(); - // Fetch 2D sub-matrices - auto a = Kokkos::subview(A, idx, Kokkos::ALL(), Kokkos::ALL()); - auto b = Kokkos::subview(B, idx, Kokkos::ALL(), Kokkos::ALL()); - auto c = Kokkos::subview(C, idx, Kokkos::ALL(), Kokkos::ALL()); - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, team_size), [&](const int &k0) { - // Fetch 1D column vectors - auto b_col_vec = Kokkos::subview(a, Kokkos::ALL(), k0); - auto c_col_vec = Kokkos::subview(c, Kokkos::ALL(), k0); - auto filter = Kokkos::subview(filters, Kokkos::ALL(), k0); - Kokkos::parallel_for(Kokkos::TeamThreadRange(member, vector_size), [&](const int &k1) { - // Filter each b column vector - b_col_vec(k1) *= filter(k1); - }); - // Calculate c_col_vec = beta*c_col_vec + alpha*a*b_col_vec - KokkosBatched::TeamGemm - ::invoke(member, alpha, a, b_col_vec, beta, c_col_vec); - }); - }); + Kokkos::parallel_for(policy, functor); // Wait for the device to return control Kokkos::fence();