-
Notifications
You must be signed in to change notification settings - Fork 16
/
gemm_naive.cu
35 lines (32 loc) · 1.17 KB
/
gemm_naive.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include "util.cuh"
namespace {
/**
 * Naive GEMM kernel: C = alpha * (A x B) + beta * C, one output element per
 * thread.
 *
 * Launch layout: 2D grid of 2D blocks. The x dimension
 * (blockIdx.x * blockDim.x + threadIdx.x) selects the row index of C and the
 * y dimension selects the column index. No shared memory is used; every
 * operand is read directly through the A/B/C pointers.
 *
 * Threads whose (row, col) is rejected by Tensor2D::validOffset do nothing,
 * so the grid may over-cover the M x N output.
 *
 * When beta == 0, C is written without being read, so its prior contents
 * never influence the result.
 *
 * NOTE(review): assumes openmlsys::Tensor2D{ptr, rows, cols} views a
 * rows x cols matrix, that operator()(i, j) addresses element (i, j), and
 * that validOffset(i, j) is the bounds check; all are defined in util.cuh,
 * so confirm there. If the layout is row-major, neighbouring threadIdx.x
 * lanes touch A and C one row apart (uncoalesced), as expected for this
 * naive baseline.
 */
__global__ void gemmKernel(const float *__restrict__ A,
                           const float *__restrict__ B, float *__restrict__ C,
                           float alpha, float beta, unsigned M, unsigned N,
                           unsigned K) {
  const unsigned int row = blockIdx.x * blockDim.x + threadIdx.x;
  const unsigned int col = blockIdx.y * blockDim.y + threadIdx.y;

  openmlsys::Tensor2D<const float> matA{A, M, K};
  openmlsys::Tensor2D<const float> matB{B, K, N};
  openmlsys::Tensor2D<float> matC{C, M, N};

  if (matC.validOffset(row, col)) {
    // Dot product of row `row` of A with column `col` of B.
    float acc = 0.0f;
    for (unsigned kk = 0; kk != K; ++kk) {
      acc += matA(row, kk) * matB(kk, col);
    }

    float out = acc * alpha;
    // Only read the existing C value when it actually contributes.
    if (beta != 0) {
      out = out + matC(row, col) * beta;
    }
    matC(row, col) = out;
  }
}
}  // namespace
/**
 * Host launcher for the naive GEMM: C = alpha * (A x B) + beta * C.
 *
 * deviceAPtr / deviceBPtr / deviceCPtr are passed straight to the kernel, so
 * they must be device-accessible pointers. They are viewed as M x K, K x N
 * and M x N matrices respectively (layout as interpreted by
 * openmlsys::Tensor2D in util.cuh).
 *
 * Launches gemmKernel with 16x16 thread blocks on the default stream and
 * returns without synchronizing. Launch errors are not checked here; they
 * remain observable by the caller via cudaGetLastError().
 *
 * Returns immediately, without launching, when the output is empty
 * (M == 0 or N == 0). A zero-sized grid is an invalid launch configuration
 * and would otherwise leave a spurious error for the caller.
 *
 * Limit: gridDim.y is ceil(N / 16), and CUDA caps gridDim.y at 65535, so N
 * larger than 65535 * 16 makes the launch fail.
 */
void gemmNaive(const float *deviceAPtr, const float *deviceBPtr,
               float *deviceCPtr, float alpha, float beta, unsigned M,
               unsigned N, unsigned K) {
  // Empty output: nothing to compute, and a 0-sized grid would be invalid.
  if (M == 0 || N == 0) return;

  dim3 block(16, 16);
  // x covers rows (M) and y covers columns (N), matching gemmKernel's
  // index mapping. (X - 1) / b + 1 is ceil(X / b) for X >= 1 and, unlike
  // (X + b - 1) / b, cannot wrap around for X near UINT_MAX.
  dim3 grid((M - 1) / block.x + 1, (N - 1) / block.y + 1);
  gemmKernel<<<grid, block>>>(deviceAPtr, deviceBPtr, deviceCPtr, alpha, beta,
                              M, N, K);
}