From 46b8ab8353966f2590ed2667028b220b57f963ae Mon Sep 17 00:00:00 2001 From: kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> Date: Thu, 24 Oct 2024 13:42:28 -0500 Subject: [PATCH] wasm: Add compile target option when creating slang session (#5403) * wasm: Add compile target option when creating slang session Also add a new interface to return spirv code which is binary, because 'std::string ComponentType::getEntryPointCode' is not suitable for returning the binary data. We use a more standard way that wrap the binary data by using emscripten::val as the return type. * Add target of metal --- source/slang-wasm/slang-wasm-bindings.cpp | 16 +++++- source/slang-wasm/slang-wasm.cpp | 63 ++++++++++++++++++++++- source/slang-wasm/slang-wasm.h | 16 +++++- 3 files changed, 91 insertions(+), 4 deletions(-) diff --git a/source/slang-wasm/slang-wasm-bindings.cpp b/source/slang-wasm/slang-wasm-bindings.cpp index f8175180a3..d033f3846c 100644 --- a/source/slang-wasm/slang-wasm-bindings.cpp +++ b/source/slang-wasm/slang-wasm-bindings.cpp @@ -17,6 +17,11 @@ EMSCRIPTEN_BINDINGS(slang) "getLastError", &slang::wgsl::getLastError); + function( + "getCompileTargets", + &slang::wgsl::getCompileTargets, + return_value_policy::take_ownership()); + class_("GlobalSession") .function( "createSession", @@ -40,7 +45,10 @@ EMSCRIPTEN_BINDINGS(slang) return_value_policy::take_ownership()) .function( "getEntryPointCode", - &slang::wgsl::ComponentType::getEntryPointCode); + &slang::wgsl::ComponentType::getEntryPointCode) + .function( + "getEntryPointCodeSpirv", + &slang::wgsl::ComponentType::getEntryPointCodeSpirv); class_>("Module") .function( @@ -59,5 +67,11 @@ EMSCRIPTEN_BINDINGS(slang) class_>("EntryPoint"); + class_("CompileTargets") + .function( + "findCompileTarget", + &slang::wgsl::CompileTargets::findCompileTarget, + return_value_policy::take_ownership()); + register_vector("ComponentTypeList"); } diff --git a/source/slang-wasm/slang-wasm.cpp b/source/slang-wasm/slang-wasm.cpp index a679a5f3d1..6fbe2dc6c4 100644 --- a/source/slang-wasm/slang-wasm.cpp +++ b/source/slang-wasm/slang-wasm.cpp @@ -14,6 +14,7 @@ namespace wgsl { Error g_error; +CompileTargets g_compileTargets; Error getLastError() { @@ -22,6 +23,11 @@ Error getLastError() return currentError; } +CompileTargets* getCompileTargets() +{ + return &g_compileTargets; +} + GlobalSession* createGlobalSession() { IGlobalSession* globalSession = nullptr; @@ -38,7 +44,33 @@ GlobalSession* createGlobalSession() return new GlobalSession(globalSession); } -Session* GlobalSession::createSession() +CompileTargets::CompileTargets() +{ +#define MAKE_PAIR(x) { #x, SLANG_##x } + + m_compileTargetMap = { + MAKE_PAIR(GLSL), + MAKE_PAIR(HLSL), + MAKE_PAIR(WGSL), + MAKE_PAIR(SPIRV), + MAKE_PAIR(METAL), + }; +} + +int CompileTargets::findCompileTarget(const std::string& name) +{ + auto res = m_compileTargetMap.find(name); + if ( res != m_compileTargetMap.end()) + { + return res->second; + } + else + { + return SLANG_TARGET_UNKNOWN; + } +} + +Session* GlobalSession::createSession(int compileTarget) { ISession* session = nullptr; { @@ -46,7 +78,7 @@ Session* GlobalSession::createSession() sessionDesc.structureSize = sizeof(sessionDesc); constexpr SlangInt targetCount = 1; TargetDesc target = {}; - target.format = SLANG_WGSL; + target.format = (SlangCompileTarget)compileTarget; sessionDesc.targets = ⌖ sessionDesc.targetCount = targetCount; SlangResult result = m_interface->createSession(sessionDesc, &session); @@ -202,5 +234,32 @@ std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetInde return {}; } +// Since spirv code is binary, we can't return it as a string, we will need to use emscripten::val +// to wrap it and return it to the javascript side. +emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int targetIndex) +{ + Slang::ComPtr kernelBlob; + Slang::ComPtr diagnosticBlob; + SlangResult result = interface()->getEntryPointCode( + entryPointIndex, + targetIndex, + kernelBlob.writeRef(), + diagnosticBlob.writeRef()); + if (result != SLANG_OK) + { + g_error.type = std::string("USER"); + g_error.result = result; + g_error.message = std::string( + (char*)diagnosticBlob->getBufferPointer(), + (char*)diagnosticBlob->getBufferPointer() + + diagnosticBlob->getBufferSize()); + return {}; + } + + const uint8_t* ptr = (uint8_t*)kernelBlob->getBufferPointer(); + return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(), + ptr)); +} + } // namespace wgsl } // namespace slang diff --git a/source/slang-wasm/slang-wasm.h b/source/slang-wasm/slang-wasm.h index c329716e8e..a54cfe1ead 100644 --- a/source/slang-wasm/slang-wasm.h +++ b/source/slang-wasm/slang-wasm.h @@ -1,6 +1,8 @@ #pragma once #include +#include +#include namespace slang { @@ -20,6 +22,17 @@ class Error Error getLastError(); +class CompileTargets +{ +public: + CompileTargets(); + int findCompileTarget(const std::string& name); +private: + std::unordered_map m_compileTargetMap; +}; + +CompileTargets* getCompileTargets(); + class ComponentType { public: @@ -30,6 +43,7 @@ class ComponentType ComponentType* link(); std::string getEntryPointCode(int entryPointIndex, int targetIndex); + emscripten::val getEntryPointCodeSpirv(int entryPointIndex, int targetIndex); slang::IComponentType* interface() const {return m_interface;} @@ -93,7 +107,7 @@ class GlobalSession GlobalSession(slang::IGlobalSession* interface) : m_interface(interface) {} - Session* createSession(); + Session* createSession(int compileTarget); slang::IGlobalSession* interface() const {return m_interface;}