diff --git a/experimental/wasm/Makefile b/experimental/wasm/Makefile index 7ae769e..d64263b 100644 --- a/experimental/wasm/Makefile +++ b/experimental/wasm/Makefile @@ -1,31 +1,38 @@ +FLAGS=--target=wasm32 -nostdlib -Wl,--no-entry -Wl,--export-all -Wl,--import-memory -Wl,--allow-undefined -fexceptions -std=c++17 -O3 +# FLAGS=--target=wasm32-unknown-unknown -stdlib=libc++ -nostdlib++ -Wl,--no-entry -Wl,--export-all -Wl,--import-memory -Wl,--allow-undefined -fexceptions -std=c++17 + .PHONY: all clean dump-obj dump-wasm dependencies server -all: run.wasm dump-obj dump-wasm +all: build/hello.wasm dump-obj dump-wasm + +watch: + ls *.cpp *.h | entr make build/run.wasm -# Compile the C++ source file to LLVM IR -run.ll: run.cpp - clang --target=wasm32 -emit-llvm -c -S run.cpp +build/run.wasm: run.cpp Makefile + clang++ $(FLAGS) -o build/run.wasm run.cpp + +# cpp -> llvm ir +build/hello.ll: hello.cpp + clang --target=wasm32 -emit-llvm -c -S hello.cpp -o build/hello.ll -# Assemble the LLVM IR to a WebAssembly object file -run.o: run.ll - llc -march=wasm32 -filetype=obj run.ll +# llvm ir -> wasm object file +build/hello.o: build/hello.ll + llc -march=wasm32 -filetype=obj build/hello.ll -o build/hello.o # Disassemble the WebAssembly object file dump-obj: - wasm-objdump -x run.o + wasm-objdump -x build/hello.o # Link the WebAssembly object file to a WebAssembly module -# no entry point function -# export all functions -run.wasm: run.o +build/hello.wasm: build/hello.o wasm-ld \ --no-entry \ --export-all \ - -o run.wasm \ - run.o + -o build/hello.wasm \ + build/hello.o dump-wasm: - wasm-objdump -x run.wasm + wasm-objdump -x build/hello.wasm # TODO(avh): this is just a reminder note for now - remove it later dependencies: @@ -33,7 +40,7 @@ dependencies: brew install wabt server: - python3 -m http.server + python3 -m http.server 8000 clean: - rm -f run.ll run.o run.wasm + rm -f build/hello.ll build/hello.o build/hello.wasm build/run.wasm diff --git a/experimental/wasm/build/.gitkeep b/experimental/wasm/build/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/experimental/wasm/favicon.ico b/experimental/wasm/favicon.ico new file mode 100644 index 0000000..d748b70 Binary files /dev/null and b/experimental/wasm/favicon.ico differ diff --git a/experimental/wasm/gpu.js b/experimental/wasm/gpu.js index 921da54..a828bc5 100644 --- a/experimental/wasm/gpu.js +++ b/experimental/wasm/gpu.js @@ -1,5 +1,7 @@ // gpu.js +const gpujs = (function() { + class Shape { static kMaxRank = 8; @@ -7,7 +9,6 @@ class Shape { if (dims.length > Shape.kMaxRank) { throw new Error(`Shape can have at most ${Shape.kMaxRank} dimensions`); } - this.rank = dims.length; // Initialize data with the provided dimensions @@ -19,6 +20,7 @@ class Shape { } } } + class Array { constructor(buffer, usage, size) { this.buffer = buffer; @@ -187,6 +189,7 @@ async function createContext() { } context.device = await context.adapter.requestDevice(); context.queue = context.device.queue; + console.log("Context created"); return context; } @@ -299,7 +302,7 @@ function dispatchKernel(ctx, kernel) { return ctx.device.queue.onSubmittedWorkDone(); } -async function main() { +async function simpleTest() { console.log("Starting main"); const ctx = await createContext(); @@ -338,4 +341,36 @@ async function main() { destroyContext(ctx); } -main().catch(console.error); + // At the end of the file, return an object with all your exports + return { + Shape, + Array, + Tensor, + TensorView, + Bindings, + Context, + TensorPool, + KernelPool, + KernelCode, + Kernel, + NumType, + size, + sizeBytes, + toString, + replaceAll, + cdiv, + cdivShape, + createContext, + destroyContext, + resetCommandBuffer, + createKernel, + createTensor, + toGPU, + toCPU, + dispatchKernel, + simpleTest, + }; +})(); + + +export default gpujs; diff --git a/experimental/wasm/hello.cpp b/experimental/wasm/hello.cpp new file mode 100644 index 0000000..a92ee9a --- /dev/null +++ b/experimental/wasm/hello.cpp @@ -0,0 +1,7 @@ +// Hello world llvm wasm test + +extern "C" { + int add(int a, int b) { return a + b; } + int mul(int a, int b) { return a * b; } + int foo(int a, int b) { return a * a + b + 4; } +} diff --git a/experimental/wasm/index.html b/experimental/wasm/index.html index 8a6ae2f..029345f 100644 --- a/experimental/wasm/index.html +++ b/experimental/wasm/index.html @@ -1,13 +1,77 @@ - + - - - - WebGPU Context Creation - - + + + + gpu.cpp wasm test + +

gpu.js test

-
Initializing WebGPU...
- - +
gpu.cpp -> wasm test
+ + + diff --git a/experimental/wasm/run.cpp b/experimental/wasm/run.cpp index a92ee9a..4e6bbd0 100644 --- a/experimental/wasm/run.cpp +++ b/experimental/wasm/run.cpp @@ -1,7 +1,12 @@ -// Hello world llvm wasm test +#include "wasm.h" -extern "C" { - int add(int a, int b) { return a + b; } - int mul(int a, int b) { return a * b; } - int foo(int a, int b) { return a * a + b + 4; } +int main() { + // Note: This calls createContext but this doesn't work to obtain the return value + // due to async + // Context* ctx = createContext(); + // destroyContext(ctx); + + LOG("Hello, World!"); + + return 0; } diff --git a/experimental/wasm/wasm.h b/experimental/wasm/wasm.h new file mode 100644 index 0000000..335c2d7 --- /dev/null +++ b/experimental/wasm/wasm.h @@ -0,0 +1,122 @@ +#ifndef WASM_H +#define WASM_H + +// #define WASM_IMPORT __attribute__((import_module("env"), +// import_name("memory"))) #define WASM_IMPORT __attribute__((used)) +// __attribute__((visibility("default"))) + +extern "C" { + +// these are normally defined in stdint.h, but we can't include that in wasm +typedef signed char int8_t; +typedef short int16_t; +typedef int int32_t; +typedef long long int64_t; +typedef unsigned char uint8_t; +typedef unsigned short uint16_t; +typedef unsigned int uint32_t; +typedef unsigned long long uint64_t; +typedef unsigned long size_t; + +// Opaque handles to js shim objects +typedef struct Shape Shape; +typedef struct Array Array; +typedef struct Tensor Tensor; +typedef struct TensorView TensorView; +typedef struct Bindings Bindings; +typedef struct Context Context; +typedef struct KernelCode KernelCode; +typedef struct Kernel Kernel; + +// Enum to match JavaScript NumType +typedef enum { kf16, kf32 } NumType; + +// Function declarations that will be implemented in JavaScript + +Shape *createShape(int32_t *dims, int32_t rank); +void destroyShape(Shape *shape); + +Array *createArray(uint64_t bufferPtr, uint32_t usage, uint64_t size); +void destroyArray(Array *array); + +Tensor *createTensor(Array *data, Shape *shape); +void destroyTensor(Tensor *tensor); + +TensorView *createTensorView(Tensor *data, uint64_t offset, uint64_t span); +void destroyTensorView(TensorView *view); + +Bindings *createBindings(Tensor **tensors, int32_t count); +void destroyBindings(Bindings *bindings); + +Context *createContext(); +void destroyContext(Context *ctx); + +KernelCode *createKernelCode(const char *data, Shape *workgroupSize, + NumType precision); +void destroyKernelCode(KernelCode *code); + +Kernel *createKernel(Context *ctx, KernelCode *code, Bindings *dataBindings, + Shape *nWorkgroups, void *params); +void destroyKernel(Kernel *kernel); + +uint64_t size(Shape *shape); +uint64_t sizeBytes(NumType type); + +char *toString(Shape *shape); +char *toStringInt(int32_t value); +char *toStringNumType(NumType type); + +void replaceAll(char *str, const char *from, const char *to); + +int32_t cdiv(int32_t n, int32_t d); +Shape *cdivShape(Shape *total, Shape *group); + +Tensor *createTensorImpl(Context *ctx, Shape *shape, NumType dtype); + +void toGPU(Context *ctx, float *data, Tensor *tensor); +void toCPU(Context *ctx, Tensor *tensor, float *data); + +void dispatchKernel(Context *ctx, Kernel *kernel); + +void resetCommandBuffer(Context *ctx, Kernel *kernel); + +uint8_t *memory; + +void jsLOG(uint8_t *messagePtr); + +int simpleTest(); + +} // extern "C" + +// Simple bump allocator for now + +uint32_t kMemPtr = 0; + +uint8_t* wasmMalloc(size_t size) { + uint8_t* ptr = &memory[kMemPtr]; + kMemPtr += size; + return ptr; +} + +size_t strlen(const char* str) { + size_t len = 0; + while (str[len]) { + len++; + } + return len; +} + +void LOG(const char* message) { + size_t len = strlen(message); + uint8_t* start = (wasmMalloc(len)); + uint8_t* dest = start; + size_t index = 0; + while (*message) { + *dest = *message; + dest++; + message++; + } + jsLOG(start); +} + +#endif // WASM_H