Skip to content

Commit

Permalink
More wasm binding for playground. (shader-slang#5420)
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe authored Oct 28, 2024
1 parent a3276e2 commit 0432907
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 14 deletions.
25 changes: 21 additions & 4 deletions source/slang-wasm/slang-wasm-bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,14 @@ EMSCRIPTEN_BINDINGS(slang)
"getEntryPointCode",
&slang::wgsl::ComponentType::getEntryPointCode)
.function(
"getEntryPointCodeSpirv",
&slang::wgsl::ComponentType::getEntryPointCodeSpirv);
"getEntryPointCodeBlob",
&slang::wgsl::ComponentType::getEntryPointCodeBlob)
.function(
"getTargetCodeBlob",
&slang::wgsl::ComponentType::getTargetCodeBlob)
.function(
"getTargetCode",
&slang::wgsl::ComponentType::getTargetCode);

class_<slang::wgsl::Module, base<slang::wgsl::ComponentType>>("Module")
.function(
Expand All @@ -58,14 +64,25 @@ EMSCRIPTEN_BINDINGS(slang)
.function(
"findAndCheckEntryPoint",
&slang::wgsl::Module::findAndCheckEntryPoint,
return_value_policy::take_ownership());
return_value_policy::take_ownership())
.function(
"getDefinedEntryPoint",
&slang::wgsl::Module::getDefinedEntryPoint,
return_value_policy::take_ownership())
.function(
"getDefinedEntryPointCount",
&slang::wgsl::Module::getDefinedEntryPointCount);

value_object<slang::wgsl::Error>("Error")
.field("type", &slang::wgsl::Error::type)
.field("result", &slang::wgsl::Error::result)
.field("message", &slang::wgsl::Error::message);

class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint");
class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint")
.function(
"getName",
&slang::wgsl::EntryPoint::getName,
allow_raw_pointers());

class_<slang::wgsl::CompileTargets>("CompileTargets")
.function(
Expand Down
96 changes: 90 additions & 6 deletions source/slang-wasm/slang-wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,15 @@ Session* GlobalSession::createSession(int compileTarget)
return new Session(session);
}

Module* Session::loadModuleFromSource(const std::string& slangCode)
Module* Session::loadModuleFromSource(const std::string& slangCode, const std::string& name, const std::string& path)
{
Slang::ComPtr<IModule> module;
{
const char * name = "";
const char * path = "";
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
Slang::ComPtr<ISlangBlob> slangCodeBlob = Slang::RawBlob::create(
slangCode.c_str(), slangCode.size());
module = m_interface->loadModuleFromSource(
name, path, slangCodeBlob, diagnosticsBlob.writeRef());
name.c_str(), path.c_str(), slangCodeBlob, diagnosticsBlob.writeRef());
if (!module)
{
g_error.type = std::string("USER");
Expand Down Expand Up @@ -161,6 +159,38 @@ EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage)
return new EntryPoint(entryPoint);
}

int Module::getDefinedEntryPointCount()
{
return moduleInterface()->getDefinedEntryPointCount();
}

EntryPoint* Module::getDefinedEntryPoint(int index)
{
if (moduleInterface()->getDefinedEntryPointCount() <= index)
return nullptr;

Slang::ComPtr<IEntryPoint> entryPoint;
{
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
SlangResult result = moduleInterface()->getDefinedEntryPoint(index, entryPoint.writeRef());
if (!SLANG_SUCCEEDED(result))
{
g_error.type = std::string("USER");
g_error.result = result;

if (diagnosticsBlob->getBufferSize())
{
char* diagnostics = (char*)diagnosticsBlob->getBufferPointer();
g_error.message = std::string(diagnostics);
}
return nullptr;
}
}

return new EntryPoint(entryPoint);
}


ComponentType* Session::createCompositeComponentType(
const std::vector<ComponentType*>& components)
{
Expand Down Expand Up @@ -235,9 +265,9 @@ std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetInde
return {};
}

// Since spirv code is binary, we can't return it as a string, we will need to use emscripten::val
// Since result code is binary, we can't return it as a string, we will need to use emscripten::val
// to wrap it and return it to the javascript side.
emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int targetIndex)
emscripten::val ComponentType::getEntryPointCodeBlob(int entryPointIndex, int targetIndex)
{
Slang::ComPtr<IBlob> kernelBlob;
Slang::ComPtr<ISlangBlob> diagnosticBlob;
Expand All @@ -262,6 +292,60 @@ emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int t
ptr));
}

std::string ComponentType::getTargetCode(int targetIndex)
{
{
Slang::ComPtr<IBlob> kernelBlob;
Slang::ComPtr<ISlangBlob> diagnosticBlob;
SlangResult result = interface()->getTargetCode(
targetIndex,
kernelBlob.writeRef(),
diagnosticBlob.writeRef());
if (result != SLANG_OK)
{
g_error.type = std::string("USER");
g_error.result = result;
g_error.message = std::string(
(char*)diagnosticBlob->getBufferPointer(),
(char*)diagnosticBlob->getBufferPointer() +
diagnosticBlob->getBufferSize());
return "";
}
std::string targetCode = std::string(
(char*)kernelBlob->getBufferPointer(),
(char*)kernelBlob->getBufferPointer() + kernelBlob->getBufferSize());
return targetCode;
}

return {};
}

// Since result code is binary, we can't return it as a string, we will need to use emscripten::val
// to wrap it and return it to the javascript side.
emscripten::val ComponentType::getTargetCodeBlob(int targetIndex)
{
Slang::ComPtr<IBlob> kernelBlob;
Slang::ComPtr<ISlangBlob> diagnosticBlob;
SlangResult result = interface()->getTargetCode(
targetIndex,
kernelBlob.writeRef(),
diagnosticBlob.writeRef());
if (result != SLANG_OK)
{
g_error.type = std::string("USER");
g_error.result = result;
g_error.message = std::string(
(char*)diagnosticBlob->getBufferPointer(),
(char*)diagnosticBlob->getBufferPointer() +
diagnosticBlob->getBufferSize());
return {};
}

const uint8_t* ptr = (uint8_t*)kernelBlob->getBufferPointer();
return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(),
ptr));
}

namespace lsp
{
Position translate(Slang::LanguageServerProtocol::Position p)
Expand Down
15 changes: 11 additions & 4 deletions source/slang-wasm/slang-wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ class ComponentType
ComponentType* link();

std::string getEntryPointCode(int entryPointIndex, int targetIndex);
emscripten::val getEntryPointCodeSpirv(int entryPointIndex, int targetIndex);
emscripten::val getEntryPointCodeBlob(int entryPointIndex, int targetIndex);
std::string getTargetCode(int targetIndex);
emscripten::val getTargetCodeBlob(int targetIndex);

slang::IComponentType* interface() const {return m_interface;}

Expand All @@ -62,9 +64,11 @@ class ComponentType
class EntryPoint : public ComponentType
{
public:

EntryPoint(slang::IEntryPoint* interface) : ComponentType(interface) {}

std::string getName() const
{
return entryPointInterface()->getFunctionReflection()->getName();
}
private:

slang::IEntryPoint* entryPointInterface() const {
Expand All @@ -80,6 +84,8 @@ class Module : public ComponentType

EntryPoint* findEntryPointByName(const std::string& name);
EntryPoint* findAndCheckEntryPoint(const std::string& name, int stage);
EntryPoint* getDefinedEntryPoint(int index);
int getDefinedEntryPointCount();

slang::IModule* moduleInterface() const {
return static_cast<slang::IModule*>(interface());
Expand All @@ -93,7 +99,8 @@ class Session
Session(slang::ISession* interface)
: m_interface(interface) {}

Module* loadModuleFromSource(const std::string& slangCode);
Module* loadModuleFromSource(
const std::string& slangCode, const std::string& name, const std::string& path);

ComponentType* createCompositeComponentType(
const std::vector<ComponentType*>& components);
Expand Down
8 changes: 8 additions & 0 deletions source/slang/slang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5040,13 +5040,21 @@ IArtifact* ComponentType::getTargetArtifact(Int targetIndex, slang::IBlob** outD
});
List<RefPtr<ComponentType>> components;
components.add(this);
bool entryPointsDiscovered = false;
for (auto module : modules)
{
for (auto entryPoint : module->getEntryPoints())
{
components.add(entryPoint);
entryPointsDiscovered = true;
}
}
// If no entry points were discovered, then we should return nullptr.
if (!entryPointsDiscovered)
{
return nullptr;
}

RefPtr<CompositeComponentType> composite = new CompositeComponentType(linkage, components);
ComPtr<IComponentType> linkedComponentType;
SLANG_RETURN_NULL_ON_FAIL(composite->link(linkedComponentType.writeRef(), outDiagnostics));
Expand Down

0 comments on commit 0432907

Please sign in to comment.