diff --git a/snippets/sgemm/gemm.deit.cc b/snippets/sgemm/gemm.deit.cc index f161a1f..e2a5f49 100644 --- a/snippets/sgemm/gemm.deit.cc +++ b/snippets/sgemm/gemm.deit.cc @@ -11,19 +11,19 @@ #include static const size_t N = 4000; -static const size_t ThreadBlockSizeI = 250; // 5周(×16スレッド) -static const size_t ThreadBlockSizeK = 4000; // 50周 -static const size_t ThreadBlockSizeJ = 4000; // 50周 +static const size_t ThreadBlockSizeI = N/16; // 250; // 5周(×16スレッド) +static const size_t ThreadBlockSizeK = N; // 50周 +static const size_t ThreadBlockSizeJ = N; // 50周 static const size_t L3CacheBlockSizeI = 50; // 1周 static const size_t L3CacheBlockSizeK = 80; // 2周 static const size_t L3CacheBlockSizeJ = 80; // 2周 static const size_t L1DCacheBlockSizeI = 50; // 10周 static const size_t L1DCacheBlockSizeK = 40; // 10周 -static const size_t L1DCacheBlockSizeJ = 40; // SIMD方向8要素×5周 +static const size_t L1DCacheBlockSizeJ = 80; // SIMD方向8要素×5周 static const size_t RegisterBlockSizeI = 5; // 5レジスタ並列に static const size_t RegisterBlockSizeK = 4; // fma連鎖4回 -void mm( const double *__restrict__ a, const double *__restrict__ b, double *__restrict__ c ) { +void mm( const float *__restrict__ a, const float *__restrict__ b, float *__restrict__ c ) { for( int i1 = 0; i1 < ThreadBlockSizeI; i1 += L3CacheBlockSizeI ) for( int k1 = 0; k1 < ThreadBlockSizeK; k1 += L3CacheBlockSizeK ) for( int j1 = 0; j1 < ThreadBlockSizeJ; j1 += L3CacheBlockSizeJ ) @@ -44,9 +44,9 @@ void mm( const double *__restrict__ a, const double *__restrict__ b, double *__r } } -alignas(64) double ah[N*N]; -alignas(64) double bh[N*N]; -alignas(64) double ch[N*N], ch_cblas[N*N]; +alignas(64) float ah[N*N]; +alignas(64) float bh[N*N]; +alignas(64) float ch[N*N], ch_cblas[N*N]; int main() { std::mt19937_64 mt; @@ -73,15 +73,16 @@ int main() { const auto finish = std::chrono::system_clock::now(); - cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, N, N, N, 1.0, ah, N, bh, N, 1.0, ch_cblas, N); + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, N, N, N, 1.0, ah, N, bh, N, 1.0, ch_cblas, N); for( int i = 0; i < N*N; ++i ) - if( std::abs( ch[i] - ch_cblas[i] ) > 1e-6 ) + if( std::abs( ch[i] - ch_cblas[i] ) > 1e-2 ) { + std::cerr << "mismatch at " << i << ": " << ch[i] << " != " << ch_cblas[i] << std::endl; exit(EXIT_FAILURE); + } const double s = std::chrono::duration_cast( finish - start ).count() * 1e-9; static constexpr double flop_per_fma = 2.0; - static constexpr double insn_per_fma = 1.0 / 8.0; std::cout << s << " seconds, " << N*N*N*flop_per_fma/s * 1e-9 << " GFLOPS" << std::endl; } }