diff --git a/llvm/include/llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h b/llvm/include/llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h new file mode 100644 index 00000000000000..5de0da1f52d0db --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h @@ -0,0 +1,106 @@ +//===- JITLinkRedirectableSymbolManager.h - JITLink redirection -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Redirectable Symbol Manager implementation using JITLink +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_JITLINKREDIRECABLEMANAGER_H +#define LLVM_EXECUTIONENGINE_ORC_JITLINKREDIRECABLEMANAGER_H + +#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/RedirectionManager.h" +#include "llvm/Support/StringSaver.h" + +namespace llvm { +namespace orc { + +class JITLinkRedirectableSymbolManager : public RedirectableSymbolManager, + public ResourceManager { +public: + /// Create redirection manager that uses JITLink based implementaion. + static Expected> + Create(ExecutionSession &ES, ObjectLinkingLayer &ObjLinkingLayer, + JITDylib &JD) { + Error Err = Error::success(); + auto RM = std::unique_ptr( + new JITLinkRedirectableSymbolManager(ES, ObjLinkingLayer, JD, Err)); + if (Err) + return Err; + return std::move(RM); + } + + void emitRedirectableSymbols(std::unique_ptr R, + const SymbolAddrMap &InitialDests) override; + + Error redirect(JITDylib &TargetJD, const SymbolAddrMap &NewDests) override; + + Error handleRemoveResources(JITDylib &TargetJD, ResourceKey K) override; + + void handleTransferResources(JITDylib &TargetJD, ResourceKey DstK, + ResourceKey SrcK) override; + +private: + using StubHandle = unsigned; + constexpr static unsigned StubBlockSize = 256; + constexpr static StringRef JumpStubPrefix = "$__IND_JUMP_STUBS"; + constexpr static StringRef StubPtrPrefix = "$IND_JUMP_PTR_"; + constexpr static StringRef JumpStubTableName = "$IND_JUMP_"; + constexpr static StringRef StubPtrTableName = "$__IND_JUMP_PTRS"; + + JITLinkRedirectableSymbolManager(ExecutionSession &ES, + ObjectLinkingLayer &ObjLinkingLayer, + JITDylib &JD, Error &Err) + : ES(ES), ObjLinkingLayer(ObjLinkingLayer), JD(JD), + AnonymousPtrCreator( + jitlink::getAnonymousPointerCreator(ES.getTargetTriple())), + PtrJumpStubCreator( + jitlink::getPointerJumpStubCreator(ES.getTargetTriple())) { + if (!AnonymousPtrCreator || !PtrJumpStubCreator) + Err = make_error("Architecture not supported", + inconvertibleErrorCode()); + if (Err) + return; + ES.registerResourceManager(*this); + } + + ~JITLinkRedirectableSymbolManager() { ES.deregisterResourceManager(*this); } + + StringRef JumpStubSymbolName(unsigned I) { + return *ES.intern((JumpStubPrefix + Twine(I)).str()); + } + + StringRef StubPtrSymbolName(unsigned I) { + return *ES.intern((StubPtrPrefix + Twine(I)).str()); + } + + unsigned GetNumAvailableStubs() const { return AvailableStubs.size(); } + + Error redirectInner(JITDylib &TargetJD, const SymbolAddrMap &NewDests); + Error grow(unsigned Need); + + ExecutionSession &ES; + ObjectLinkingLayer &ObjLinkingLayer; + JITDylib &JD; + jitlink::AnonymousPointerCreator AnonymousPtrCreator; + jitlink::PointerJumpStubCreator PtrJumpStubCreator; + + std::vector AvailableStubs; + using SymbolToStubMap = DenseMap; + DenseMap SymbolToStubs; + std::vector JumpStubs; + std::vector StubPointers; + DenseMap> TrackedResources; + + std::mutex Mutex; +}; + +} // namespace orc +} // namespace llvm + +#endif diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RedirectionManager.h b/llvm/include/llvm/ExecutionEngine/Orc/RedirectionManager.h new file mode 100644 index 00000000000000..87a81b0e529ccd --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/RedirectionManager.h @@ -0,0 +1,101 @@ +//===- RedirectionManager.h - Redirection manager interface -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Redirection manager interface that redirects a call to symbol to another. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_REDIRECTIONMANAGER_H +#define LLVM_EXECUTIONENGINE_ORC_REDIRECTIONMANAGER_H + +#include "llvm/ExecutionEngine/Orc/Core.h" + +namespace llvm { +namespace orc { + +/// Base class for performing redirection of call to symbol to another symbol in +/// runtime. +class RedirectionManager { +public: + /// Symbol name to symbol definition map. + using SymbolAddrMap = DenseMap; + + virtual ~RedirectionManager() = default; + /// Change the redirection destination of given symbols to new destination + /// symbols. + virtual Error redirect(JITDylib &JD, const SymbolAddrMap &NewDests) = 0; + + /// Change the redirection destination of given symbol to new destination + /// symbol. + virtual Error redirect(JITDylib &JD, SymbolStringPtr Symbol, + ExecutorSymbolDef NewDest) { + return redirect(JD, {{Symbol, NewDest}}); + } + +private: + virtual void anchor(); +}; + +/// Base class for managing redirectable symbols in which a call +/// gets redirected to another symbol in runtime. +class RedirectableSymbolManager : public RedirectionManager { +public: + /// Create redirectable symbols with given symbol names and initial + /// desitnation symbol addresses. + Error createRedirectableSymbols(ResourceTrackerSP RT, + const SymbolMap &InitialDests); + + /// Create a single redirectable symbol with given symbol name and initial + /// desitnation symbol address. + Error createRedirectableSymbol(ResourceTrackerSP RT, SymbolStringPtr Symbol, + ExecutorSymbolDef InitialDest) { + return createRedirectableSymbols(RT, {{Symbol, InitialDest}}); + } + + /// Emit redirectable symbol + virtual void + emitRedirectableSymbols(std::unique_ptr MR, + const SymbolMap &InitialDests) = 0; +}; + +class RedirectableMaterializationUnit : public MaterializationUnit { +public: + RedirectableMaterializationUnit(RedirectableSymbolManager &RM, + const SymbolMap &InitialDests) + : MaterializationUnit(convertToFlags(InitialDests)), RM(RM), + InitialDests(InitialDests) {} + + StringRef getName() const override { + return "RedirectableSymbolMaterializationUnit"; + } + + void materialize(std::unique_ptr R) override { + RM.emitRedirectableSymbols(std::move(R), std::move(InitialDests)); + } + + void discard(const JITDylib &JD, const SymbolStringPtr &Name) override { + InitialDests.erase(Name); + } + +private: + static MaterializationUnit::Interface + convertToFlags(const SymbolMap &InitialDests) { + SymbolFlagsMap Flags; + for (auto [K, V] : InitialDests) + Flags[K] = V.getFlags(); + return MaterializationUnit::Interface(Flags, {}); + } + + RedirectableSymbolManager &RM; + SymbolMap InitialDests; +}; + +} // namespace orc +} // namespace llvm + +#endif diff --git a/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt b/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt index 5dfd621781e446..0ee056e9f63a19 100644 --- a/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt +++ b/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt @@ -50,6 +50,8 @@ add_llvm_component_library(LLVMOrcJIT ExecutorProcessControl.cpp TaskDispatch.cpp ThreadSafeModule.cpp + RedirectionManager.cpp + JITLinkRedirectableSymbolManager.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/ExecutionEngine/Orc diff --git a/llvm/lib/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.cpp b/llvm/lib/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.cpp new file mode 100644 index 00000000000000..6d76d50271b9f2 --- /dev/null +++ b/llvm/lib/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.cpp @@ -0,0 +1,179 @@ +//===-- JITLinkRedirectableSymbolManager.cpp - JITLink redirection in Orc -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h" +#include "llvm/ExecutionEngine/Orc/Core.h" + +#define DEBUG_TYPE "orc" + +using namespace llvm; +using namespace llvm::orc; + +void JITLinkRedirectableSymbolManager::emitRedirectableSymbols( + std::unique_ptr R, + const SymbolAddrMap &InitialDests) { + std::unique_lock Lock(Mutex); + if (GetNumAvailableStubs() < InitialDests.size()) + if (auto Err = grow(InitialDests.size() - GetNumAvailableStubs())) { + ES.reportError(std::move(Err)); + R->failMaterialization(); + return; + } + + JITDylib &TargetJD = R->getTargetJITDylib(); + SymbolMap NewSymbolDefs; + std::vector Symbols; + for (auto &[K, V] : InitialDests) { + StubHandle StubID = AvailableStubs.back(); + if (SymbolToStubs[&TargetJD].count(K)) { + ES.reportError(make_error( + "Tried to create duplicate redirectable symbols", + inconvertibleErrorCode())); + R->failMaterialization(); + return; + } + dbgs() << *K << "\n"; + SymbolToStubs[&TargetJD][K] = StubID; + NewSymbolDefs[K] = JumpStubs[StubID]; + NewSymbolDefs[K].setFlags(V.getFlags()); + Symbols.push_back(K); + AvailableStubs.pop_back(); + } + + if (auto Err = R->replace(absoluteSymbols(NewSymbolDefs))) { + ES.reportError(std::move(Err)); + R->failMaterialization(); + return; + } + + if (auto Err = redirectInner(TargetJD, InitialDests)) { + ES.reportError(std::move(Err)); + R->failMaterialization(); + return; + } + + auto Err = R->withResourceKeyDo([&](ResourceKey Key) { + TrackedResources[Key].insert(TrackedResources[Key].end(), Symbols.begin(), + Symbols.end()); + }); + if (Err) { + ES.reportError(std::move(Err)); + R->failMaterialization(); + return; + } +} + +Error JITLinkRedirectableSymbolManager::redirect( + JITDylib &TargetJD, const SymbolAddrMap &NewDests) { + std::unique_lock Lock(Mutex); + return redirectInner(TargetJD, NewDests); +} + +Error JITLinkRedirectableSymbolManager::redirectInner( + JITDylib &TargetJD, const SymbolAddrMap &NewDests) { + std::vector PtrWrites; + for (auto &[K, V] : NewDests) { + if (!SymbolToStubs[&TargetJD].count(K)) + return make_error( + "Tried to redirect non-existent redirectalbe symbol", + inconvertibleErrorCode()); + StubHandle StubID = SymbolToStubs[&TargetJD].at(K); + PtrWrites.push_back({StubPointers[StubID].getAddress(), V.getAddress()}); + } + if (auto Err = ES.getExecutorProcessControl().getMemoryAccess().writePointers( + PtrWrites)) + return Err; + return Error::success(); +} + +Error JITLinkRedirectableSymbolManager::grow(unsigned Need) { + unsigned OldSize = JumpStubs.size(); + unsigned NumNewStubs = alignTo(Need, StubBlockSize); + unsigned NewSize = OldSize + NumNewStubs; + + JumpStubs.resize(NewSize); + StubPointers.resize(NewSize); + AvailableStubs.reserve(NewSize); + + SymbolLookupSet LookupSymbols; + DenseMap NewDefsMap; + + Triple TT = ES.getTargetTriple(); + auto G = std::make_unique( + "", TT, TT.isArch64Bit() ? 8 : 4, + TT.isLittleEndian() ? support::little : support::big, + jitlink::getGenericEdgeKindName); + auto &PointerSection = + G->createSection(StubPtrTableName, MemProt::Write | MemProt::Read); + auto &StubsSection = + G->createSection(JumpStubTableName, MemProt::Exec | MemProt::Read); + + for (size_t I = OldSize; I < NewSize; I++) { + auto Pointer = AnonymousPtrCreator(*G, PointerSection, nullptr, 0); + if (auto Err = Pointer.takeError()) + return Err; + + StringRef PtrSymName = StubPtrSymbolName(I); + Pointer->setName(PtrSymName); + Pointer->setScope(jitlink::Scope::Default); + LookupSymbols.add(ES.intern(PtrSymName)); + NewDefsMap[ES.intern(PtrSymName)] = &StubPointers[I]; + + auto Stub = PtrJumpStubCreator(*G, StubsSection, *Pointer); + if (auto Err = Stub.takeError()) + return Err; + + StringRef JumpStubSymName = JumpStubSymbolName(I); + Stub->setName(JumpStubSymName); + Stub->setScope(jitlink::Scope::Default); + LookupSymbols.add(ES.intern(JumpStubSymName)); + NewDefsMap[ES.intern(JumpStubSymName)] = &JumpStubs[I]; + } + + if (auto Err = ObjLinkingLayer.add(JD, std::move(G))) + return Err; + + auto LookupResult = ES.lookup(makeJITDylibSearchOrder(&JD), LookupSymbols); + if (auto Err = LookupResult.takeError()) + return Err; + + for (auto &[K, V] : *LookupResult) + *NewDefsMap.at(K) = V; + + for (size_t I = OldSize; I < NewSize; I++) + AvailableStubs.push_back(I); + + return Error::success(); +} + +Error JITLinkRedirectableSymbolManager::handleRemoveResources( + JITDylib &TargetJD, ResourceKey K) { + std::unique_lock Lock(Mutex); + for (auto &Symbol : TrackedResources[K]) { + if (!SymbolToStubs[&TargetJD].count(Symbol)) + return make_error( + "Tried to remove non-existent redirectable symbol", + inconvertibleErrorCode()); + AvailableStubs.push_back(SymbolToStubs[&TargetJD].at(Symbol)); + SymbolToStubs[&TargetJD].erase(Symbol); + if (SymbolToStubs[&TargetJD].empty()) + SymbolToStubs.erase(&TargetJD); + } + TrackedResources.erase(K); + + return Error::success(); +} + +void JITLinkRedirectableSymbolManager::handleTransferResources( + JITDylib &TargetJD, ResourceKey DstK, ResourceKey SrcK) { + std::unique_lock Lock(Mutex); + TrackedResources[DstK].insert(TrackedResources[DstK].end(), + TrackedResources[SrcK].begin(), + TrackedResources[SrcK].end()); + TrackedResources.erase(SrcK); +} diff --git a/llvm/lib/ExecutionEngine/Orc/RedirectionManager.cpp b/llvm/lib/ExecutionEngine/Orc/RedirectionManager.cpp new file mode 100644 index 00000000000000..cbc77c5034303a --- /dev/null +++ b/llvm/lib/ExecutionEngine/Orc/RedirectionManager.cpp @@ -0,0 +1,24 @@ +//===---- RedirectionManager.cpp - Redirection manager interface in Orc ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ExecutionEngine/Orc/RedirectionManager.h" + +#define DEBUG_TYPE "orc" + +using namespace llvm; +using namespace llvm::orc; + +void RedirectionManager::anchor() {} + +Error RedirectableSymbolManager::createRedirectableSymbols( + ResourceTrackerSP RT, const SymbolMap &InitialDests) { + auto &JD = RT->getJITDylib(); + return JD.define( + std::make_unique(*this, InitialDests), + RT); +} diff --git a/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt b/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt index dc3380d35fda9d..98c86d80730249 100644 --- a/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt +++ b/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt @@ -43,6 +43,7 @@ add_llvm_unittest(OrcJITTests TaskDispatchTest.cpp ThreadSafeModuleTest.cpp WrapperFunctionUtilsTest.cpp + JITLinkRedirectionManagerTest.cpp EXPORT_SYMBOLS ) diff --git a/llvm/unittests/ExecutionEngine/Orc/JITLinkRedirectionManagerTest.cpp b/llvm/unittests/ExecutionEngine/Orc/JITLinkRedirectionManagerTest.cpp new file mode 100644 index 00000000000000..0f87c1b7433238 --- /dev/null +++ b/llvm/unittests/ExecutionEngine/Orc/JITLinkRedirectionManagerTest.cpp @@ -0,0 +1,100 @@ +#include "OrcTestCommon.h" +#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace llvm::orc; +using namespace llvm::jitlink; + +static int initialTarget() { return 42; } +static int middleTarget() { return 13; } +static int finalTarget() { return 53; } + +class JITLinkRedirectionManagerTest : public testing::Test { +public: + ~JITLinkRedirectionManagerTest() { + if (ES) + if (auto Err = ES->endSession()) + ES->reportError(std::move(Err)); + } + +protected: + void SetUp() override { + auto JTMB = JITTargetMachineBuilder::detectHost(); + // Bail out if we can not detect the host. + if (!JTMB) { + consumeError(JTMB.takeError()); + GTEST_SKIP(); + } + + ES = std::make_unique( + std::make_unique( + nullptr, nullptr, JTMB->getTargetTriple().getTriple())); + JD = &ES->createBareJITDylib("main"); + ObjLinkingLayer = std::make_unique( + *ES, std::make_unique(4096)); + DL = std::make_unique( + cantFail(JTMB->getDefaultDataLayoutForTarget())); + } + JITDylib *JD{nullptr}; + std::unique_ptr ES; + std::unique_ptr ObjLinkingLayer; + std::unique_ptr DL; +}; + +TEST_F(JITLinkRedirectionManagerTest, BasicRedirectionOperation) { + auto RM = + JITLinkRedirectableSymbolManager::Create(*ES, *ObjLinkingLayer, *JD); + // Bail out if we can not create + if (!RM) { + consumeError(RM.takeError()); + GTEST_SKIP(); + } + + auto DefineTarget = [&](StringRef TargetName, ExecutorAddr Addr) { + SymbolStringPtr Target = ES->intern(TargetName); + cantFail(JD->define(std::make_unique( + SymbolFlagsMap({{Target, JITSymbolFlags::Exported}}), + [&](std::unique_ptr R) -> void { + // No dependencies registered, can't fail. + cantFail( + R->notifyResolved({{Target, {Addr, JITSymbolFlags::Exported}}})); + cantFail(R->notifyEmitted()); + }))); + return cantFail(ES->lookup({JD}, TargetName)); + }; + + auto InitialTarget = + DefineTarget("InitialTarget", ExecutorAddr::fromPtr(&initialTarget)); + auto MiddleTarget = + DefineTarget("MiddleTarget", ExecutorAddr::fromPtr(&middleTarget)); + auto FinalTarget = + DefineTarget("FinalTarget", ExecutorAddr::fromPtr(&finalTarget)); + + auto RedirectableSymbol = ES->intern("RedirectableTarget"); + EXPECT_THAT_ERROR( + (*RM)->createRedirectableSymbols(JD->getDefaultResourceTracker(), + {{RedirectableSymbol, InitialTarget}}), + Succeeded()); + auto RTDef = cantFail(ES->lookup({JD}, RedirectableSymbol)); + + auto RTPtr = RTDef.getAddress().toPtr(); + auto Result = RTPtr(); + EXPECT_EQ(Result, 42) << "Failed to call initial target"; + + EXPECT_THAT_ERROR((*RM)->redirect(*JD, {{RedirectableSymbol, MiddleTarget}}), + Succeeded()); + Result = RTPtr(); + EXPECT_EQ(Result, 13) << "Failed to call middle redirected target"; + + EXPECT_THAT_ERROR((*RM)->redirect(*JD, {{RedirectableSymbol, FinalTarget}}), + Succeeded()); + Result = RTPtr(); + EXPECT_EQ(Result, 53) << "Failed to call redirected target"; +}