Add matmul with float16 #39

Merged · 1 commit · Aug 8, 2024
216 changes: 143 additions & 73 deletions examples/matmul/run.cpp
@@ -11,11 +11,35 @@
#include "utils/array_utils.h" // show, isclose, randn, randint
#include "utils/logging.h" // LOG
#include "experimental/wgsl.h" // loopUnrolling
#include "numeric_types/half.h"

using namespace gpu;

const std::string versionToStr(int version);

void matmulf16_forward_cpu(half* out,
                           const half* inp, const half* weight, const half* bias,
                           int B, int T, int C, int OC) {
  // OC is short for "output channels"
  // inp is (B,T,C), weight is (OC, C)
  // out will be (B,T,OC)
  #pragma omp parallel for collapse(2)
  for (int b = 0; b < B; b++) {
    for (int t = 0; t < T; t++) {
      half* out_bt = out + b * T * OC + t * OC;
      const half* inp_bt = inp + b * T * C + t * C;
      for (int o = 0; o < OC; o++) {
        float val = (bias != NULL) ? halfToFloat(bias[o]) : 0.0f;
        const half* wrow = weight + o*C;
        for (int i = 0; i < C; i++) {
          val += halfToFloat(inp_bt[i]) * halfToFloat(wrow[i]);
        }
        out_bt[o] = val;
      }
    }
  }
}

static const char *kShaderMatmul1 = R"(
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@@ -47,7 +71,7 @@ inline KernelCode createMatmul1(const char *shaderTemplate, const size_t M,
{"{{M}}", toString(M)},
{"{{K}}", toString(K)},
{"{{N}}", toString(N)}});
return {codeString, workgroupSize};
return {codeString, workgroupSize, precision};
}

// Shared memory cache-blocking
@@ -108,7 +132,7 @@ inline KernelCode createMatmul2(const char *shaderTemplate, const size_t M,
{"{{N}}", toString(N)},
{"{{tileSize}}",
toString(static_cast<size_t>(sqrt(workgroupSize[0])))}});
return {codeString, workgroupSize};
return {codeString, workgroupSize, precision};
}

/* 1D block-tiling
@@ -224,9 +248,9 @@ inline KernelCode createMatmul3(const char *shaderTemplate, const size_t M,
if (unrolling) {
std::string unrolledCode = loopUnrolling(codeString);
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
return {unrolledCode, workgroupSize};
return {unrolledCode, workgroupSize, precision};
} else {
return {codeString, workgroupSize};
return {codeString, workgroupSize, precision};
}
}

@@ -340,9 +364,9 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
if (unrolling) {
std::string unrolledCode = loopUnrolling(codeString);
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
return {unrolledCode, workgroupSize};
return {unrolledCode, workgroupSize, precision};
} else {
return {codeString, workgroupSize};
return {codeString, workgroupSize, precision};
}
}

@@ -462,9 +486,9 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
if (unrolling) {
std::string unrolledCode = loopUnrolling(codeString);
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
return {unrolledCode, workgroupSize};
return {unrolledCode, workgroupSize, precision};
} else {
return {codeString, workgroupSize};
return {codeString, workgroupSize, precision};
}
}

@@ -582,7 +606,7 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
});
std::string unrolledCode = loopUnrolling(codeString);
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
return {unrolledCode, workgroupSize};
return {unrolledCode, workgroupSize, precision};
}

/**
@@ -604,7 +628,7 @@ inline KernelCode createNoOp(const char *shaderTemplate,
std::string codeString(shaderTemplate);
replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
{"{{precision}}", toString(precision)}});
return {codeString, workgroupSize};
return {codeString, workgroupSize, precision};
}

void initData(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
@@ -619,23 +643,41 @@ void initData(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
show<float>(weightsPtr.get(), N, K, "Weights").c_str());
}

void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
std::unique_ptr<float[]> &weightsPtr,
std::unique_ptr<float[]> &outputPtr) {
void initData(size_t M, size_t K, size_t N, std::unique_ptr<half[]> &inputPtr,
std::unique_ptr<half[]> &weightsPtr) {
std::mt19937 gen(314159);
randn(inputPtr.get(), M * K, gen);
randn(weightsPtr.get(), N * K, gen);
// randint(inputPtr.get(), M * K, gen, 1, 2);
// randint(weightsPtr.get(), N * K, gen, 1, 2);
LOG(kDefLog, kInfo, "%s", show<half>(inputPtr.get(), M, K, "Input").c_str());
LOG(kDefLog, kInfo, "%s",
show<half>(weightsPtr.get(), N, K, "Weights").c_str());
}

template<class precision=float>
void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr<precision[]> &inputPtr,
std::unique_ptr<precision[]> &weightsPtr,
std::unique_ptr<precision[]> &outputPtr) {
LOG(kDefLog, kInfo, "Computing CPU reference implementation");
std::unique_ptr<float[]> outputRefPtr = std::make_unique<float[]>(M * N);
ref::matmul_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
nullptr, 1, M, K, N);
std::unique_ptr<precision[]> outputRefPtr = std::make_unique<precision[]>(M * N);
if constexpr (std::is_same<precision, float>::value) {
ref::matmul_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
nullptr, 1, M, K, N);
} else if constexpr (std::is_same<precision, half>::value) {
matmulf16_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
nullptr, 1, M, K, N);
}
LOG(kDefLog, kInfo, "Reference Output: %s",
show<float>(outputRefPtr.get(), M, N, "Output (Reference)").c_str());
show<precision>(outputRefPtr.get(), M, N, "Output (Reference)").c_str());
LOG(kDefLog, kInfo,
isclose(outputPtr.get(), outputRefPtr.get(), M * N) ? "CPU Check: PASS"
: "CPU Check: FAIL");
}

Kernel selectMatmul(Context &ctx, int version,
const Bindings</* input, weights, output */ 3> &bindings,
size_t M, size_t K, size_t N) {
size_t M, size_t K, size_t N, NumType numtype) {
Kernel kernel;
if (version == 1) {
Shape wgSize = {256, 1, 1};
@@ -647,13 +689,13 @@ Kernel selectMatmul(Context &ctx, int version,
Shape wgSize = {16, 16, 1};
LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
KernelCode matmul =
createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize);
createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
} else if (version == 3) {
static constexpr size_t tileSize = 16;
KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
/*wgSize*/ {tileSize * tileSize, 1, 1});
/*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
kernel =
createKernel(ctx, matmul, bindings,
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
@@ -672,7 +714,7 @@ Kernel selectMatmul(Context &ctx, int version,
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
KernelCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM,
/*wgSize*/ wgSize,
kf32,
numtype,
/*Loop unrolling*/ version == 6 ? true: false);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
@@ -690,11 +732,11 @@ Kernel selectMatmul(Context &ctx, int version,
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
KernelCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
/*wgSize*/ wgSize,
kf32,
numtype,
/*Loop unrolling*/ version == 7 ? true: false);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
} else if (version == 8) {
} else if (version == 8 || version == 10) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 8;
static constexpr size_t BN = 64;
@@ -708,11 +750,11 @@ Kernel selectMatmul(Context &ctx, int version,
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
KernelCode matmul = createMatmulWithVectorization(kShaderMatmulWithVectorization, M, K, N, BM, BK, BN, TM, TN,
/*wgSize*/ wgSize,
kf32,
numtype,
/*Loop unrolling*/ true);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
} else if (version == 9) {
} else if (version == 9 || version == 11) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 8;
static constexpr size_t BN = 64;
@@ -726,23 +768,36 @@ Kernel selectMatmul(Context &ctx, int version,
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
KernelCode matmul = createMatmulWithTranspose(kShaderMatmulWithTranspose, M, K, N, BM, BK, BN, TM, TN,
/*wgSize*/ wgSize,
kf32);
numtype);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
}
return kernel;
}

template<class precision=float>
Review comment (Contributor):

After learning my lesson from type-templated early iterations of gemma.cpp, my bias here is to implement the core implementations with precision types at the value level as a default, and to deviate from that only when there's a rationale and a strong case for it.

Templated overloads are okay as more type-safe convenience API sugar: it's easy to wrap a value-level core implementation in a more strongly templated wrapper. It's a bigger refactoring to go in the other direction and unwind a core implementation into value-level code.

My suggestion, unless there's some value I'm missing, would be (see the sketch after this comment):

  • Either have runTest take a precision argument or have a separate runTest (e.g. runtest16) for alternative precision values.
  • Avoid having getPrecision pull values out of types; remove it entirely unless there's something we need it for.
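
A minimal sketch of the value-level style this comment describes (illustration only, not part of this PR's diff). The names runTestValueLevel, elementSize, and runTestF16 are hypothetical, the local NumType enum and half alias are stand-ins so the sketch is self-contained, and the GPU dispatch itself is elided because it would match the templated runTest below:

#include <cstddef>
#include <cstdint>

// Stand-ins for this sketch; in the real code NumType, kf32, kf16 come from
// gpu.h and half from numeric_types/half.h.
enum NumType { kf32, kf16 };
using half = uint16_t;

// Precision travels as a runtime value, so callers need no template parameter.
size_t elementSize(NumType t) { return t == kf16 ? sizeof(half) : sizeof(float); }

// Value-level core: buffers are type-erased; the NumType value selects element
// sizes and kernel precision.
void runTestValueLevel(int version, size_t M, size_t K, size_t N,
                       void *input, void *weights, void *output,
                       NumType numtype) {
  size_t outBytes = M * N * elementSize(numtype);
  // Create tensors with `numtype`, select and dispatch kernel `version`, and
  // copy `outBytes` bytes back into `output`, as the templated runTest below does.
  (void)version; (void)input; (void)weights; (void)output; (void)outBytes;
}

// Thin typed wrappers remain as convenience sugar over the value-level core.
void runTestF16(int version, size_t M, size_t K, size_t N,
                half *in, half *w, half *out) {
  runTestValueLevel(version, M, K, N, in, w, out, kf16);
}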

void runTest(int version, size_t M, size_t K, size_t N,
std::unique_ptr<float[]> &inputPtr,
std::unique_ptr<float[]> &weightsPtr,
std::unique_ptr<float[]> &outputPtr) {
std::unique_ptr<precision[]> &inputPtr,
std::unique_ptr<precision[]> &weightsPtr,
std::unique_ptr<precision[]> &outputPtr,
NumType numtype) {
if constexpr (std::is_same<precision, float>::value) {
assert(numtype == kf32);
} else if constexpr (std::is_same<precision, half>::value) {
assert(numtype == kf16);
}

// Allocate GPU buffers and copy data
Context ctx = createContext();
Tensor input = createTensor(ctx, Shape{M, K}, kf32, inputPtr.get());
Tensor weights =
createTensor(ctx, Shape{N, K}, kf32, weightsPtr.get()); // column-major
Context ctx = createContext(
{}, {},
/*device descriptor, enabling f16 in WGSL*/
{
.requiredFeatureCount = 1,
.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data(),
});

Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major

constexpr size_t nIter = 30;

@@ -756,8 +811,8 @@ void runTest(int version, size_t M, size_t K, size_t N,
std::array<Tensor, nIter> outputs;
for (int i = 0; i < nIter; i++) {
futures[i] = promises[i].get_future();
outputs[i] = createTensor(ctx, Shape{M, N}, kf32);
kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N);
outputs[i] = createTensor(ctx, Shape{M, N}, numtype);
kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N, numtype);
}

printf("[ Press enter to start tests ... ]\n");
@@ -785,9 +840,9 @@ void runTest(int version, size_t M, size_t K, size_t N,
1000000000.0 * static_cast<float>(nIter);

LOG(kDefLog, kInfo, "Copying result to CPU");
toCPU(ctx, outputs[0], outputPtr.get(), M * N * sizeof(float));
toCPU(ctx, outputs[0], outputPtr.get(), M * N * sizeof(precision));
LOG(kDefLog, kInfo, "%s",
show<float>(outputPtr.get(), M, N, "Output[0]").c_str());
show<precision>(outputPtr.get(), M, N, "Output[0]").c_str());

LOG(kDefLog, kInfo, "\n\n===================================================================="
"============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations "
@@ -798,33 +853,62 @@ void runTest(int version, size_t M, size_t K, size_t N,
M, K, N, nIter, duration.count() / static_cast<double>(nIter) / 1000.0 /* us -> ms */, gflops);
}
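
As a sanity check on the reported GFLOPS figure (standard matmul FLOP counting, not something computed in this diff): one M×K by K×N matmul with accumulation performs 2·M·K·N floating-point operations, so for M = 4096, K = 4096, N = 8192 that is 2 · 4096 · 4096 · 8192 ≈ 2.75 × 10^11 FLOPs, roughly 275 GFLOP per iteration; at 100 ms per iteration this corresponds to about 2.75 TFLOPS.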

template<class precision=float>
void runTestWithCheck(int version, size_t M, size_t K, size_t N,
bool transposedInput, int kTestSize, NumType numtype) {
std::unique_ptr<precision[]> inputPtr = std::make_unique<precision[]>(M * K);
std::unique_ptr<precision[]> weightsPtr = std::make_unique<precision[]>(N * K);
std::unique_ptr<precision[]> outputPtr = std::make_unique<precision[]>(M * N);

initData(M, K, N, inputPtr, weightsPtr);
if (transposedInput) {
std::unique_ptr<precision[]> transposedWeightPtr = std::make_unique<precision[]>(K * N);
transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr, numtype);
} else {
runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr, numtype);
}

if (kTestSize <= 1) {
// Check result with CPU reference implementation for tiny/small tests
checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
}
}

const std::string versionToStr(int version){
switch (version) {
case 1: return "No-Op";
case 2: return "naive matmul";
case 3: return "tiling";
case 4: return "1D blocktiling";
case 5: return "2D blocktiling";
case 6: return "1D blocktiling with loop unrolling";
case 7: return "2D blocktiling with loop unrolling";
case 8: return "2D blocktiling with loop unrolling and vectorization";
case 9: return "2D blocktiling with loop unrolling, vectorization and transpose";
case 1: return "f32: No-Op";
case 2: return "f32: naive matmul";
case 3: return "f32: tiling";
case 4: return "f32: 1D blocktiling";
case 5: return "f32: 2D blocktiling";
case 6: return "f32: 1D blocktiling with loop unrolling";
case 7: return "f32: 2D blocktiling with loop unrolling";
case 8: return "f32: 2D blocktiling with loop unrolling and vectorization";
case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
case 10: return "f16: 2D blocktiling with loop unrolling and vectorization";
case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
default: return "Not specified";
}
}

int main() {
char* version_str = getenv("MATMUL_VERSION");
int version = version_str == NULL ? 9 : atoi(version_str);
// 1 == No-Op
// 2 == naive matmul
// 3 == tiling
// 4 == 1D blocktiling
// 5 == 2D blocktiling
// 6 == 1D blocktiling with loop unrolling
// 7 == 2D blocktiling with loop unrolling
// 8 == 2D blocktiling with loop unrolling and vectorization
// 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default)
int version = version_str == NULL ? 10 : atoi(version_str);
// 1 == f32: No-Op
// 2 == f32: naive matmul
// 3 == f32: tiling
// 4 == f32: 1D blocktiling
// 5 == f32: 2D blocktiling
// 6 == f32: 1D blocktiling with loop unrolling
// 7 == f32: 2D blocktiling with loop unrolling
// 8 == f32: 2D blocktiling with loop unrolling and vectorization
// 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
// 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
// 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
bool enableF16 = version == 10 || version == 11;
bool transposedInput = version == 9 || version == 11;
NumType numtype = enableF16 ? kf16 : kf32;

size_t M, K, N; // Matrix dimensions
char* kTestSize_str = getenv("MATMUL_SIZE");
@@ -846,24 +930,10 @@ int main() {
N = 2 * 4096;
}

std::unique_ptr<float[]> inputPtr = std::make_unique<float[]>(M * K);
std::unique_ptr<float[]> weightsPtr = std::make_unique<float[]>(N * K);
std::unique_ptr<float[]> outputPtr = std::make_unique<float[]>(M * N);
bool transposedInput = version == 9;

initData(M, K, N, inputPtr, weightsPtr);
if (transposedInput) {
std::unique_ptr<float[]> transposedWeightPtr = std::make_unique<float[]>(K * N);
transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr);
if (enableF16) {
runTestWithCheck<half>(version, M, K, N, transposedInput, kTestSize, numtype);
} else {
runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
}


if (kTestSize <= 1) {
// Check result with CPU reference implementation for tiny/small tests
checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
runTestWithCheck<float>(version, M, K, N, transposedInput, kTestSize, numtype);
}

LOG(kDefLog, kInfo, "Done.");
3 changes: 3 additions & 0 deletions gpu.h
@@ -318,6 +318,9 @@ struct KernelCode {
const Shape &workgroupSize = {256, 1, 1},
NumType precision = kf32)
: data(pData), workgroupSize(workgroupSize), precision(precision) {
if (precision == kf16) {
data = "enable f16;\n" + data;
}
replaceAll(data, "{{workgroupSize}}", toString(workgroupSize));
replaceAll(data, "{{precision}}", toString(precision));
LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str());
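
For illustration (not part of this PR's diff), here is how the constructor change above plays out for a trivial copy shader. It assumes, as the matmul kernels do, that {{precision}} is substituted with "f16" for kf16 and {{workgroupSize}} with the workgroup Shape; the shader name kCopyShader and the variable code are made up for the example:

static const char *kCopyShader = R"(
@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
@compute @workgroup_size({{workgroupSize}})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  out[gid.x] = inp[gid.x];
}
)";

// With kf16 the constructor prepends the WGSL directive and substitutes the
// precision placeholder, so `code.data` begins with "enable f16;" and the
// buffers become array<f16>. The device must then be created with
// WGPUFeatureName_ShaderF16, as runTest does above.
KernelCode code(kCopyShader, /*workgroupSize*/ {256, 1, 1}, kf16);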