Skip to content

Commit

Permalink
wasm: Add compile target option when creating slang session (#5403)
Browse files Browse the repository at this point in the history
* wasm: Add compile target option when creating slang session

Also add a new interface to return spirv code which is binary,

because 'std::string ComponentType::getEntryPointCode' is not
suitable for returning the binary data.

We use a more standard way that wrap the binary data by using
emscripten::val as the return type.

* Add target of metal
  • Loading branch information
kaizhangNV authored Oct 24, 2024
1 parent ee709cf commit 46b8ab8
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 4 deletions.
16 changes: 15 additions & 1 deletion source/slang-wasm/slang-wasm-bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ EMSCRIPTEN_BINDINGS(slang)
"getLastError",
&slang::wgsl::getLastError);

function(
"getCompileTargets",
&slang::wgsl::getCompileTargets,
return_value_policy::take_ownership());

class_<slang::wgsl::GlobalSession>("GlobalSession")
.function(
"createSession",
Expand All @@ -40,7 +45,10 @@ EMSCRIPTEN_BINDINGS(slang)
return_value_policy::take_ownership())
.function(
"getEntryPointCode",
&slang::wgsl::ComponentType::getEntryPointCode);
&slang::wgsl::ComponentType::getEntryPointCode)
.function(
"getEntryPointCodeSpirv",
&slang::wgsl::ComponentType::getEntryPointCodeSpirv);

class_<slang::wgsl::Module, base<slang::wgsl::ComponentType>>("Module")
.function(
Expand All @@ -59,5 +67,11 @@ EMSCRIPTEN_BINDINGS(slang)

class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint");

class_<slang::wgsl::CompileTargets>("CompileTargets")
.function(
"findCompileTarget",
&slang::wgsl::CompileTargets::findCompileTarget,
return_value_policy::take_ownership());

register_vector<slang::wgsl::ComponentType*>("ComponentTypeList");
}
63 changes: 61 additions & 2 deletions source/slang-wasm/slang-wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace wgsl
{

Error g_error;
CompileTargets g_compileTargets;

Error getLastError()
{
Expand All @@ -22,6 +23,11 @@ Error getLastError()
return currentError;
}

CompileTargets* getCompileTargets()
{
return &g_compileTargets;
}

GlobalSession* createGlobalSession()
{
IGlobalSession* globalSession = nullptr;
Expand All @@ -38,15 +44,41 @@ GlobalSession* createGlobalSession()
return new GlobalSession(globalSession);
}

Session* GlobalSession::createSession()
CompileTargets::CompileTargets()
{
#define MAKE_PAIR(x) { #x, SLANG_##x }

m_compileTargetMap = {
MAKE_PAIR(GLSL),
MAKE_PAIR(HLSL),
MAKE_PAIR(WGSL),
MAKE_PAIR(SPIRV),
MAKE_PAIR(METAL),
};
}

int CompileTargets::findCompileTarget(const std::string& name)
{
auto res = m_compileTargetMap.find(name);
if ( res != m_compileTargetMap.end())
{
return res->second;
}
else
{
return SLANG_TARGET_UNKNOWN;
}
}

Session* GlobalSession::createSession(int compileTarget)
{
ISession* session = nullptr;
{
SessionDesc sessionDesc = {};
sessionDesc.structureSize = sizeof(sessionDesc);
constexpr SlangInt targetCount = 1;
TargetDesc target = {};
target.format = SLANG_WGSL;
target.format = (SlangCompileTarget)compileTarget;
sessionDesc.targets = &target;
sessionDesc.targetCount = targetCount;
SlangResult result = m_interface->createSession(sessionDesc, &session);
Expand Down Expand Up @@ -202,5 +234,32 @@ std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetInde
return {};
}

// Since spirv code is binary, we can't return it as a string, we will need to use emscripten::val
// to wrap it and return it to the javascript side.
emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int targetIndex)
{
Slang::ComPtr<IBlob> kernelBlob;
Slang::ComPtr<ISlangBlob> diagnosticBlob;
SlangResult result = interface()->getEntryPointCode(
entryPointIndex,
targetIndex,
kernelBlob.writeRef(),
diagnosticBlob.writeRef());
if (result != SLANG_OK)
{
g_error.type = std::string("USER");
g_error.result = result;
g_error.message = std::string(
(char*)diagnosticBlob->getBufferPointer(),
(char*)diagnosticBlob->getBufferPointer() +
diagnosticBlob->getBufferSize());
return {};
}

const uint8_t* ptr = (uint8_t*)kernelBlob->getBufferPointer();
return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(),
ptr));
}

} // namespace wgsl
} // namespace slang
16 changes: 15 additions & 1 deletion source/slang-wasm/slang-wasm.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

#include <slang.h>
#include <unordered_map>
#include <emscripten/val.h>

namespace slang
{
Expand All @@ -20,6 +22,17 @@ class Error

Error getLastError();

class CompileTargets
{
public:
CompileTargets();
int findCompileTarget(const std::string& name);
private:
std::unordered_map<std::string, SlangCompileTarget> m_compileTargetMap;
};

CompileTargets* getCompileTargets();

class ComponentType
{
public:
Expand All @@ -30,6 +43,7 @@ class ComponentType
ComponentType* link();

std::string getEntryPointCode(int entryPointIndex, int targetIndex);
emscripten::val getEntryPointCodeSpirv(int entryPointIndex, int targetIndex);

slang::IComponentType* interface() const {return m_interface;}

Expand Down Expand Up @@ -93,7 +107,7 @@ class GlobalSession
GlobalSession(slang::IGlobalSession* interface)
: m_interface(interface) {}

Session* createSession();
Session* createSession(int compileTarget);

slang::IGlobalSession* interface() const {return m_interface;}

Expand Down

0 comments on commit 46b8ab8

Please sign in to comment.