Skip to content

Commit

Permalink
Intro-Full/Exercises/kokkoskernels: Move TeamGemm invocation to functor
Browse files Browse the repository at this point in the history
  • Loading branch information
e10harvey committed Jun 22, 2020
1 parent f3869e0 commit 47c4acb
Showing 1 changed file with 65 additions and 27 deletions.
92 changes: 65 additions & 27 deletions Intro-Full/Exercises/kokkoskernels/TeamGemm/Solution/teamgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,63 @@ void checkDims(dims_t A, dims_t B)
}
}

/// Functor run once per team of a Kokkos::TeamPolicy.
///
/// The team's league rank selects one 2D slice from each of the rank-3 views
/// A, B and C. For every column index k0 the functor scales column k0 of the
/// B slice element-wise by column k0 of `filters`, then calls
/// KokkosBatched::TeamGemm so that, per the in-code comment,
/// c_col_vec = beta*c_col_vec + alpha*a*b_col_vec.
///
/// @tparam TeamMemberType   Team handle type passed to operator().
/// @tparam ScalarType       Type of the alpha / beta coefficients.
/// @tparam ViewType         Type shared by A, B and C (indexed with three indices).
/// @tparam ATransType       Transpose tag forwarded to TeamGemm for the A operand.
/// @tparam BTransType       Transpose tag forwarded to TeamGemm for the B operand.
/// @tparam ViewTypeFilters  Type of the filter view (indexed with two indices).
template <class TeamMemberType,
          class ScalarType,
          class ViewType,
          class ATransType,
          class BTransType,
          class ViewTypeFilters>
struct functor_TeamGemm {
  ScalarType alpha;         ///< Coefficient forwarded to TeamGemm (scales a*b_col_vec).
  ViewType A;               ///< Batch of left-hand matrices, one per league rank.
  ViewType B;               ///< Batch of right-hand matrices, one per league rank.
  ScalarType beta;          ///< Coefficient forwarded to TeamGemm (scales c_col_vec).
  ViewType C;               ///< Batch of result matrices, one per league rank.
  ViewTypeFilters filters;  ///< Column k0 is multiplied into column k0 of the B slice.
  const int team_size;      ///< Trip count of the outer (column, k0) loop.
  const int vector_size;    ///< Trip count of the inner (row, k1) filter loop.

  /// Stores the coefficients, the views and the two loop bounds.
  functor_TeamGemm(ScalarType alpha_,
                   ViewType A_, ViewType B_,
                   ScalarType beta_, ViewType C_,
                   ViewTypeFilters filters_,
                   int team_size_, int vector_size_)
      : alpha(alpha_),
        A(A_),
        B(B_),
        beta(beta_),
        C(C_),
        filters(filters_),
        team_size(team_size_),
        vector_size(vector_size_)
  {}

  /// Processes the slice of A, B and C selected by member.league_rank().
  ///
  /// @param member  Team handle supplied by Kokkos::parallel_for.
  KOKKOS_INLINE_FUNCTION
  void operator()(const TeamMemberType &member) const {
    const int idx = member.league_rank();
    // Fetch 2D sub-matrices
    auto a = Kokkos::subview(A, idx, Kokkos::ALL(), Kokkos::ALL());
    auto b = Kokkos::subview(B, idx, Kokkos::ALL(), Kokkos::ALL());
    auto c = Kokkos::subview(C, idx, Kokkos::ALL(), Kokkos::ALL());
    // NOTE(review): the Kokkos hierarchical-parallelism docs describe
    // TeamThreadRange as the outer level and ThreadVectorRange as the inner
    // one; the nesting here is the other way round -- confirm it is intended.
    Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, team_size), [&](const int &k0) {
      // Fetch 1D column vectors.
      // The B column must come from `b`: the previous code sliced `a` here,
      // which left `b` unused and multiplied `a` by one of its own columns.
      auto b_col_vec = Kokkos::subview(b, Kokkos::ALL(), k0);
      auto c_col_vec = Kokkos::subview(c, Kokkos::ALL(), k0);
      auto filter    = Kokkos::subview(filters, Kokkos::ALL(), k0);
      // NOTE(review): k1 runs to vector_size, not to the length of
      // b_col_vec / filter -- verify the caller passes a matching bound.
      Kokkos::parallel_for(Kokkos::TeamThreadRange(member, vector_size), [&](const int &k1) {
        // Filter each b column vector (in place, through the subview)
        b_col_vec(k1) *= filter(k1);
      });
      // Calculate c_col_vec = beta*c_col_vec + alpha*a*b_col_vec
      KokkosBatched::TeamGemm<TeamMemberType,
                              ATransType,
                              BTransType,
                              KokkosBatched::Algo::Gemm::Unblocked>
        ::invoke(member, alpha, a, b_col_vec, beta, c_col_vec);
    });
  }
};


int main(int argc, char* argv[])
{
dims_t A_dims, B_dims, C_dims;
Expand Down Expand Up @@ -134,7 +191,7 @@ int main(int argc, char* argv[])
using LayoutType = Kokkos::LayoutRight;
using DeviceType = Kokkos::Cuda;
using ViewType = Kokkos::View<ScalarType***, LayoutType, DeviceType>;

using FilterType = Kokkos::View<ScalarType**, LayoutType, DeviceType>;

// Timer products
struct timeval begin, end;
Expand All @@ -147,7 +204,7 @@ int main(int argc, char* argv[])
ViewType A("A", N, A_dims.m, A_dims.n);
ViewType B("B", N, B_dims.m, B_dims.n);
ViewType C("C", N, C_dims.m, C_dims.n);
Kokkos::View<ScalarType**, LayoutType, DeviceType> filters("filters", B_dims.m, B_dims.n);
FilterType filters("filters", B_dims.m, B_dims.n);

// Populate A, B, and C matrices with random numbers
using ExecutionSpaceType = DeviceType::execution_space;
Expand All @@ -163,37 +220,18 @@ int main(int argc, char* argv[])

// Invoke TeamGemm from Vector Loop
const int num_leagues = N; /// N teams are formed
const int team_size = C_dims.n; /// Each team consists of C_dims.n kokkos threads
const int vector_size = C_dims.m; /// team_size * vector_size concurrent threads are associated within a team
int team_size = C_dims.n; /// Each team consists of C_dims.n kokkos threads
int vector_size = C_dims.m; /// team_size * vector_size concurrent threads are associated within a team

using TeamMemberType = Kokkos::TeamPolicy<ExecutionSpaceType>::member_type;
using ATransType = KokkosBatched::Trans::NoTranspose;
using BTransType = KokkosBatched::Trans::NoTranspose;
using FunctorType = functor_TeamGemm<TeamMemberType, ScalarType, ViewType, ATransType, BTransType, FilterType>;

FunctorType functor(alpha, A, B, beta, C, filters, team_size, vector_size);
Kokkos::TeamPolicy<ExecutionSpaceType> policy(num_leagues, team_size, vector_size);

Kokkos::parallel_for(policy, KOKKOS_LAMBDA(const TeamMemberType &member) {
const int idx = member.league_rank();
// Fetch 2D sub-matrices
auto a = Kokkos::subview(A, idx, Kokkos::ALL(), Kokkos::ALL());
auto b = Kokkos::subview(B, idx, Kokkos::ALL(), Kokkos::ALL());
auto c = Kokkos::subview(C, idx, Kokkos::ALL(), Kokkos::ALL());
Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, team_size), [&](const int &k0) {
// Fetch 1D column vectors
auto b_col_vec = Kokkos::subview(a, Kokkos::ALL(), k0);
auto c_col_vec = Kokkos::subview(c, Kokkos::ALL(), k0);
auto filter = Kokkos::subview(filters, Kokkos::ALL(), k0);
Kokkos::parallel_for(Kokkos::TeamThreadRange(member, vector_size), [&](const int &k1) {
// Filter each b column vector
b_col_vec(k1) *= filter(k1);
});
// Calculate c_col_vec = beta*c_col_vec + alpha*a*b_col_vec
KokkosBatched::TeamGemm<TeamMemberType,
ATransType,
BTransType,
KokkosBatched::Algo::Gemm::Unblocked>
::invoke(member, alpha, a, b_col_vec, beta, c_col_vec);
});
});
Kokkos::parallel_for(policy, functor);

// Wait for the device to return control
Kokkos::fence();
Expand Down

0 comments on commit 47c4acb

Please sign in to comment.