Skip to content

Commit

Permalink
snippets: fp32 gemm.deit.cc
Browse files Browse the repository at this point in the history
  • Loading branch information
maekawatoshiki committed Oct 5, 2024
1 parent 0a725ee commit 493eaa3
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions snippets/sgemm/gemm.deit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@
#include <blis.h>

// Problem size and blocking parameters for the matrix multiply.
// mm() iterates up to ThreadBlockSize* in steps of L3CacheBlockSize*; the L1D
// and register block sizes are presumably consumed by inner loops not visible
// here. Trailing comments are the author's per-level iteration counts.
// NOTE(review): several constants appear twice with different initializers
// (older value first, then its replacement) — confirm which definition is live.
static const size_t N = 4000;
static const size_t ThreadBlockSizeI = 250; // 5 iterations (x16 threads)
static const size_t ThreadBlockSizeK = 4000; // 50 iterations
static const size_t ThreadBlockSizeJ = 4000; // 50 iterations
static const size_t ThreadBlockSizeI = N/16; // 250; // 5 iterations (x16 threads)
static const size_t ThreadBlockSizeK = N; // 50 iterations
static const size_t ThreadBlockSizeJ = N; // 50 iterations
static const size_t L3CacheBlockSizeI = 50; // 1 iteration
static const size_t L3CacheBlockSizeK = 80; // 2 iterations
static const size_t L3CacheBlockSizeJ = 80; // 2 iterations
static const size_t L1DCacheBlockSizeI = 50; // 10 iterations
static const size_t L1DCacheBlockSizeK = 40; // 10 iterations
static const size_t L1DCacheBlockSizeJ = 40; // 8 elements along the SIMD direction x 5 iterations
static const size_t L1DCacheBlockSizeJ = 80; // 8 elements along the SIMD direction x 5 iterations (NOTE(review): 8 x 5 = 40, not 80 — comment looks stale)
static const size_t RegisterBlockSizeI = 5; // 5 registers in parallel
static const size_t RegisterBlockSizeK = 4; // chain of 4 FMAs

void mm( const double *__restrict__ a, const double *__restrict__ b, double *__restrict__ c ) {
void mm( const float *__restrict__ a, const float *__restrict__ b, float *__restrict__ c ) {
for( int i1 = 0; i1 < ThreadBlockSizeI; i1 += L3CacheBlockSizeI )
for( int k1 = 0; k1 < ThreadBlockSizeK; k1 += L3CacheBlockSizeK )
for( int j1 = 0; j1 < ThreadBlockSizeJ; j1 += L3CacheBlockSizeJ )
Expand All @@ -44,9 +44,9 @@ void mm( const double *__restrict__ a, const double *__restrict__ b, double *__r
}
}

// Global N*N matrix buffers, each aligned to 64 bytes.
// ah and bh are passed as the two input matrices to the CBLAS reference call;
// ch_cblas receives the CBLAS result and is compared element-wise against ch.
// NOTE(review): each buffer is declared twice (double, then float) — the
// double declarations look like the superseded versions; confirm.
alignas(64) double ah[N*N];
alignas(64) double bh[N*N];
alignas(64) double ch[N*N], ch_cblas[N*N];
alignas(64) float ah[N*N];
alignas(64) float bh[N*N];
alignas(64) float ch[N*N], ch_cblas[N*N];

int main() {
std::mt19937_64 mt;
Expand All @@ -73,15 +73,16 @@ int main() {

const auto finish = std::chrono::system_clock::now();

cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, N, N, N, 1.0, ah, N, bh, N, 1.0, ch_cblas, N);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, N, N, N, 1.0, ah, N, bh, N, 1.0, ch_cblas, N);

for( int i = 0; i < N*N; ++i )
if( std::abs( ch[i] - ch_cblas[i] ) > 1e-6 )
if( std::abs( ch[i] - ch_cblas[i] ) > 1e-2 ) {
std::cerr << "mismatch at " << i << ": " << ch[i] << " != " << ch_cblas[i] << std::endl;
exit(EXIT_FAILURE);
}

const double s = std::chrono::duration_cast<std::chrono::nanoseconds>( finish - start ).count() * 1e-9;
static constexpr double flop_per_fma = 2.0;
static constexpr double insn_per_fma = 1.0 / 8.0;
std::cout << s << " seconds, " << N*N*N*flop_per_fma/s * 1e-9 << " GFLOPS" << std::endl;
}
}
Expand Down

0 comments on commit 493eaa3

Please sign in to comment.